Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Transferir aprendizaje y ajuste fino

Ver en TensorFlow.org Ver fuente en GitHub Descargar cuaderno

Preparar

import numpy as np
import tensorflow as tf
from tensorflow import keras

Introducción

El aprendizaje por transferencia consiste en tomar las características aprendidas en un problema y aprovecharlas en un problema nuevo similar. Por ejemplo, las características de un modelo que ha aprendido a identificar mapaches pueden ser útiles para poner en marcha un modelo destinado a identificar tanukis.

El aprendizaje de transferencia generalmente se realiza para tareas en las que su conjunto de datos tiene muy pocos datos para entrenar un modelo a gran escala desde cero.

La encarnación más común del aprendizaje por transferencia en el contexto del aprendizaje profundo es el siguiente flujo de trabajo:

  1. Tome capas de un modelo previamente entrenado.
  2. Congélelos para evitar destruir la información que contienen durante las rondas de entrenamiento futuras.
  3. Agregue algunas capas nuevas y entrenables sobre las capas congeladas. Aprenderán a convertir las características antiguas en predicciones en un nuevo conjunto de datos.
  4. Entrene las nuevas capas en su conjunto de datos.

Un último paso, opcional, es el ajuste fino , que consiste en descongelar todo el modelo que obtuvo anteriormente (o parte de él) y volver a entrenarlo con los nuevos datos con una tasa de aprendizaje muy baja. Potencialmente, esto puede lograr mejoras significativas al adaptar gradualmente las características previamente entrenadas a los nuevos datos.

Primero, trainable API trainable Keras en detalle, que subyace en la mayoría de los flujos de trabajo de aprendizaje de transferencia y ajuste.

Luego, demostraremos el flujo de trabajo típico tomando un modelo previamente entrenado en el conjunto de datos de ImageNet y reentrenándolo en el conjunto de datos de clasificación "gatos contra perros" de Kaggle.

Esto está adaptado de Deep Learning con Python y la publicación del blog de 2016 "Creación de modelos potentes de clasificación de imágenes utilizando muy pocos datos" .

Congelación de capas: comprensión del atributo trainable

Las capas y los modelos tienen tres atributos de peso:

  • weights es la lista de todas las variables de pesos de la capa.
  • trainable_weights es la lista de aquellos que deben actualizarse (a través del descenso de gradiente) para minimizar la pérdida durante el entrenamiento.
  • non_trainable_weights es la lista de aquellos que no están destinados a ser entrenados. Por lo general, el modelo los actualiza durante el pase hacia adelante.

Ejemplo: la capa Dense tiene 2 pesos entrenables (núcleo y sesgo)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

En general, todos los pesos son pesos entrenables. La única capa incorporada que tiene pesos no entrenables es la capa BatchNormalization . Utiliza pesos no entrenables para realizar un seguimiento de la media y la varianza de sus entradas durante el entrenamiento. Para aprender a usar pesos no entrenables en sus propias capas personalizadas, consulte la guía para escribir nuevas capas desde cero .

Ejemplo: la capa BatchNormalization tiene 2 pesos entrenables y 2 pesos no entrenables

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Las capas y modelos también cuentan con un atributo booleano que se puede trainable . Su valor se puede cambiar. Establecer layer.trainable en False mueve todos los pesos de la capa de entrenables a no entrenables. A esto se le llama "congelar" la capa: el estado de una capa congelada no se actualizará durante el entrenamiento (ya sea al entrenar con fit() o al entrenar con cualquier bucle personalizado que se base en trainable_weights para aplicar actualizaciones de gradiente).

Ejemplo: establecer trainable en False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Cuando un peso entrenable se vuelve no entrenable, su valor ya no se actualiza durante el entrenamiento.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 0s 1ms/step - loss: 0.1275

No confunda el atributo layer.trainable con el argumento training en layer.__call__() (que controla si la capa debe ejecutar su pase hacia adelante en modo de inferencia o en modo de entrenamiento). Para obtener más información, consulte las preguntas frecuentes de Keras .

Configuración recursiva del atributo trainable

Si establece trainable = False en un modelo o en cualquier capa que tenga subcapas, todas las capas secundarias tampoco se pueden entrenar.

Ejemplo:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

El típico flujo de trabajo de aprendizaje por transferencia

Esto nos lleva a cómo se puede implementar un flujo de trabajo de aprendizaje por transferencia típico en Keras:

  1. Cree una instancia de un modelo base y cargue pesos previamente entrenados en él.
  2. Congele todas las capas en el modelo base configurando trainable = False .
  3. Cree un nuevo modelo sobre la salida de una (o varias) capas del modelo base.
  4. Entrene su nuevo modelo en su nuevo conjunto de datos.

Tenga en cuenta que un flujo de trabajo alternativo y más ligero también podría ser:

  1. Cree una instancia de un modelo base y cargue pesos previamente entrenados en él.
  2. Ejecute su nuevo conjunto de datos a través de él y registre la salida de una (o varias) capas del modelo base. A esto se le llama extracción de características .
  3. Utilice esa salida como datos de entrada para un modelo nuevo y más pequeño.

Una ventaja clave de ese segundo flujo de trabajo es que solo ejecuta el modelo base una vez en sus datos, en lugar de una vez por época de entrenamiento. Entonces es mucho más rápido y más barato.

Sin embargo, un problema con ese segundo flujo de trabajo es que no le permite modificar dinámicamente los datos de entrada de su nuevo modelo durante el entrenamiento, lo cual es necesario al realizar el aumento de datos, por ejemplo. El aprendizaje por transferencia se usa generalmente para tareas en las que su nuevo conjunto de datos tiene muy pocos datos para entrenar un modelo a gran escala desde cero, y en tales escenarios el aumento de datos es muy importante. Entonces, en lo que sigue, nos centraremos en el primer flujo de trabajo.

Así es como se ve el primer flujo de trabajo en Keras:

Primero, cree una instancia de un modelo base con pesos previamente entrenados.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Luego, congele el modelo base.

base_model.trainable = False

Crea un nuevo modelo en la parte superior.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Entrene el modelo con datos nuevos.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Sintonia FINA

Una vez que su modelo ha convergido en los nuevos datos, puede intentar descongelar todo o parte del modelo base y volver a entrenar todo el modelo de principio a fin con una tasa de aprendizaje muy baja.

Este es un último paso opcional que potencialmente puede brindarle mejoras incrementales. También podría conducir a un sobreajuste rápido, tenlo en cuenta.

Es fundamental realizar este paso solo después de que el modelo con capas congeladas se haya entrenado para la convergencia. Si mezcla capas entrenables inicializadas aleatoriamente con capas entrenables que contienen características entrenadas previamente, las capas inicializadas aleatoriamente causarán actualizaciones de gradiente muy grandes durante el entrenamiento, lo que destruirá sus características entrenadas previamente.

También es fundamental utilizar una tasa de aprendizaje muy baja en esta etapa, porque está entrenando un modelo mucho más grande que en la primera ronda de entrenamiento, en un conjunto de datos que normalmente es muy pequeño. Como resultado, corre el riesgo de sobreajustar muy rápidamente si aplica grandes actualizaciones de peso. Aquí, solo desea readaptar los pesos preentrenados de forma incremental.

Así es como implementar el ajuste fino de todo el modelo base:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Nota importante sobre compile() y trainable

Llamar a compile() en un modelo tiene la intención de "congelar" el comportamiento de ese modelo. Esto implica que los valores de los atributos trainable en el momento en que se compila el modelo deben conservarse durante la vida útil de ese modelo, hasta que se vuelva a llamar a compile . Por lo tanto, si cambia cualquier valor trainable , asegúrese de llamar a compile() nuevamente en su modelo para que se tomen en cuenta los cambios.

Notas importantes sobre la capa BatchNormalization

Muchos modelos de imágenes contienen capas de BatchNormalization . Esa capa es un caso especial en todos los aspectos imaginables. A continuación, se incluyen algunas cosas que debe tener en cuenta.

  • BatchNormalization contiene 2 pesos no entrenables que se actualizan durante el entrenamiento. Estas son las variables que siguen la media y la varianza de las entradas.
  • Cuando configura bn_layer.trainable = False , la capa BatchNormalization se ejecutará en modo de inferencia y no actualizará sus estadísticas de media y varianza. Este no es el caso de otras capas en general, ya que la capacidad de entrenamiento de peso y los modos de inferencia / entrenamiento son dos conceptos ortogonales . Pero los dos están empatados en el caso de la capa BatchNormalization .
  • Cuando descongela un modelo que contiene capas de BatchNormalization para realizar un ajuste fino, debe mantener las capas de BatchNormalization en modo de inferencia pasando training=False al llamar al modelo base. De lo contrario, las actualizaciones aplicadas a los pesos no entrenables destruirán repentinamente lo que ha aprendido el modelo.

Verá este patrón en acción en el ejemplo de extremo a extremo al final de esta guía.

Transfiera el aprendizaje y el ajuste con un ciclo de entrenamiento personalizado

Si en lugar de fit() , está utilizando su propio ciclo de entrenamiento de bajo nivel, el flujo de trabajo permanece esencialmente igual. Debe tener cuidado de solo tener en cuenta la lista model.trainable_weights al aplicar actualizaciones de gradiente:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Lo mismo ocurre con el ajuste fino.

Un ejemplo de principio a fin: ajuste de un modelo de clasificación de imágenes en perros y gatos

conjunto de datos

Para solidificar estos conceptos, lo guiaremos a través de un ejemplo concreto de aprendizaje de transferencia de extremo a extremo y ajuste fino. Cargaremos el modelo Xception, previamente entrenado en ImageNet, y lo usaremos en el conjunto de datos de clasificación "gatos contra perros" de Kaggle.

Obteniendo los datos

Primero, busquemos el conjunto de datos de perros y gatos usando TFDS. Si tiene su propio conjunto de datos, probablemente desee utilizar la utilidad tf.keras.preprocessing.image_dataset_from_directory para generar objetos de conjunto de datos etiquetados similares a partir de un conjunto de imágenes en el disco archivado en carpetas específicas de la clase.

El aprendizaje de transferencia es más útil cuando se trabaja con conjuntos de datos muy pequeños. Para mantener nuestro conjunto de datos pequeño, usaremos el 40% de los datos de entrenamiento originales (25,000 imágenes) para entrenamiento, 10% para validación y 10% para pruebas.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteIL7NQA/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Estas son las primeras 9 imágenes del conjunto de datos de entrenamiento; como puede ver, todas son de diferentes tamaños.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

También podemos ver que la etiqueta 1 es "perro" y la etiqueta 0 es "gato".

Estandarizando los datos

Nuestras imágenes en bruto tienen una variedad de tamaños. Además, cada píxel consta de 3 valores enteros entre 0 y 255 (valores de nivel RGB). Esta no es una buena opción para alimentar una red neuronal. Necesitamos hacer 2 cosas:

  • Estandarice a un tamaño de imagen fijo. Elegimos 150x150.
  • Normalice los valores de los píxeles entre -1 y 1. Lo haremos usando una capa de Normalization como parte del propio modelo.

En general, es una buena práctica desarrollar modelos que toman datos sin procesar como entrada, a diferencia de los modelos que toman datos ya procesados ​​previamente. La razón es que, si su modelo espera datos preprocesados, cada vez que exporte su modelo para usarlo en otro lugar (en un navegador web, en una aplicación móvil), deberá volver a implementar exactamente la misma canalización de preprocesamiento. Esto se vuelve muy complicado muy rápidamente. Por tanto, deberíamos hacer la menor cantidad posible de preprocesamiento antes de utilizar el modelo.

Aquí, cambiaremos el tamaño de la imagen en la canalización de datos (porque una red neuronal profunda solo puede procesar lotes contiguos de datos), y escalaremos el valor de entrada como parte del modelo, cuando lo creemos.

Cambiemos el tamaño de las imágenes a 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Además, vamos a agrupar los datos y usar el almacenamiento en caché y la búsqueda previa para optimizar la velocidad de carga.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Usando el aumento de datos aleatorios

Cuando no tiene un gran conjunto de datos de imágenes, es una buena práctica introducir artificialmente la diversidad de muestras aplicando transformaciones aleatorias pero realistas a las imágenes de entrenamiento, como volteos horizontales aleatorios o pequeñas rotaciones aleatorias. Esto ayuda a exponer el modelo a diferentes aspectos de los datos de entrenamiento mientras ralentiza el sobreajuste.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

Visualicemos cómo se ve la primera imagen del primer lote después de varias transformaciones aleatorias:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[i]))
        plt.axis("off")

png

Construir un modelo

Ahora construyamos un modelo que siga el plan que explicamos anteriormente.

Tenga en cuenta que:

  • Agregamos una capa de Normalization para escalar los valores de entrada (inicialmente en el rango [0, 255] ) al rango [-1, 1] .
  • Dropout una capa de Dropout antes de la capa de clasificación, para regularización.
  • Nos aseguramos de pasar training=False al llamar al modelo base, para que se ejecute en modo de inferencia, de modo que las estadísticas de batchnorm no se actualicen incluso después de descongelar el modelo base para un ajuste fino.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

Entrena la capa superior

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 9s 32ms/step - loss: 0.1758 - binary_accuracy: 0.9226 - val_loss: 0.0897 - val_binary_accuracy: 0.9660
Epoch 2/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1211 - binary_accuracy: 0.9497 - val_loss: 0.0870 - val_binary_accuracy: 0.9686
Epoch 3/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1166 - binary_accuracy: 0.9503 - val_loss: 0.0814 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1125 - binary_accuracy: 0.9534 - val_loss: 0.0825 - val_binary_accuracy: 0.9695
Epoch 5/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1073 - binary_accuracy: 0.9569 - val_loss: 0.0763 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1041 - binary_accuracy: 0.9573 - val_loss: 0.0812 - val_binary_accuracy: 0.9686
Epoch 7/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1023 - binary_accuracy: 0.9567 - val_loss: 0.0820 - val_binary_accuracy: 0.9669
Epoch 8/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1005 - binary_accuracy: 0.9597 - val_loss: 0.0779 - val_binary_accuracy: 0.9695
Epoch 9/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1019 - binary_accuracy: 0.9580 - val_loss: 0.0813 - val_binary_accuracy: 0.9699
Epoch 10/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0940 - binary_accuracy: 0.9651 - val_loss: 0.0762 - val_binary_accuracy: 0.9729
Epoch 11/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0974 - binary_accuracy: 0.9613 - val_loss: 0.0752 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0965 - binary_accuracy: 0.9591 - val_loss: 0.0760 - val_binary_accuracy: 0.9721
Epoch 13/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0962 - binary_accuracy: 0.9598 - val_loss: 0.0785 - val_binary_accuracy: 0.9712
Epoch 14/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0966 - binary_accuracy: 0.9616 - val_loss: 0.0831 - val_binary_accuracy: 0.9699
Epoch 15/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1000 - binary_accuracy: 0.9574 - val_loss: 0.0741 - val_binary_accuracy: 0.9725
Epoch 16/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0940 - binary_accuracy: 0.9628 - val_loss: 0.0781 - val_binary_accuracy: 0.9686
Epoch 17/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0915 - binary_accuracy: 0.9634 - val_loss: 0.0843 - val_binary_accuracy: 0.9678
Epoch 18/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0937 - binary_accuracy: 0.9620 - val_loss: 0.0829 - val_binary_accuracy: 0.9669
Epoch 19/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0988 - binary_accuracy: 0.9601 - val_loss: 0.0862 - val_binary_accuracy: 0.9686
Epoch 20/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0928 - binary_accuracy: 0.9644 - val_loss: 0.0798 - val_binary_accuracy: 0.9703

<tensorflow.python.keras.callbacks.History at 0x7f6104f04518>

Haga una ronda de ajuste fino de todo el modelo

Finalmente, descongelemos el modelo base y entrenemos todo el modelo de un extremo a otro con una tasa de aprendizaje baja.

Es importante destacar que, aunque el modelo base se vuelve entrenable, todavía se está ejecutando en modo de inferencia, ya que pasamos training=False cuando lo llamamos cuando construimos el modelo. Esto significa que las capas de normalización de lotes en el interior no actualizarán sus estadísticas de lotes. Si lo hicieran, causarían estragos en las representaciones aprendidas por el modelo hasta ahora.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
  2/291 [..............................] - ETA: 17s - loss: 0.1439 - binary_accuracy: 0.9219WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.

291/291 [==============================] - 38s 132ms/step - loss: 0.0786 - binary_accuracy: 0.9706 - val_loss: 0.0631 - val_binary_accuracy: 0.9772
Epoch 2/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0553 - binary_accuracy: 0.9790 - val_loss: 0.0537 - val_binary_accuracy: 0.9781
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0442 - binary_accuracy: 0.9829 - val_loss: 0.0532 - val_binary_accuracy: 0.9819
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0369 - binary_accuracy: 0.9858 - val_loss: 0.0460 - val_binary_accuracy: 0.9832
Epoch 5/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0335 - binary_accuracy: 0.9870 - val_loss: 0.0561 - val_binary_accuracy: 0.9794
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0253 - binary_accuracy: 0.9910 - val_loss: 0.0559 - val_binary_accuracy: 0.9819
Epoch 7/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0232 - binary_accuracy: 0.9920 - val_loss: 0.0432 - val_binary_accuracy: 0.9845
Epoch 8/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0185 - binary_accuracy: 0.9930 - val_loss: 0.0396 - val_binary_accuracy: 0.9854
Epoch 9/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0147 - binary_accuracy: 0.9948 - val_loss: 0.0439 - val_binary_accuracy: 0.9832
Epoch 10/10
291/291 [==============================] - 37s 129ms/step - loss: 0.0117 - binary_accuracy: 0.9954 - val_loss: 0.0538 - val_binary_accuracy: 0.9819

<tensorflow.python.keras.callbacks.History at 0x7f611c26e438>

Después de 10 épocas, el ajuste fino nos da una buena mejora aquí.