Transfira o aprendizado e o ajuste fino

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Configurar

import numpy as np
import tensorflow as tf
from tensorflow import keras

Introdução

Aprendizagem de transferência consiste em tomar recursos aprendidos em um problema, e aproveitando-los em um novo problema, similar. Por exemplo, recursos de um modelo que aprendeu a identificar guaxinins podem ser úteis para iniciar um modelo destinado a identificar tanukis.

O aprendizado de transferência geralmente é feito para tarefas em que seu conjunto de dados tem poucos dados para treinar um modelo em escala real do zero.

A encarnação mais comum da aprendizagem por transferência no contexto de aprendizagem profunda é o seguinte fluxo de trabalho:

  1. Pegue camadas de um modelo previamente treinado.
  2. Congele-os para evitar a destruição de qualquer informação que eles contenham durante as próximas rodadas de treinamento.
  3. Adicione algumas novas camadas treináveis ​​no topo das camadas congeladas. Eles aprenderão a transformar os recursos antigos em previsões em um novo conjunto de dados.
  4. Treine as novas camadas em seu conjunto de dados.

A última etapa, opcional, é o ajuste fino, que consiste em descongelar todo o modelo que você obtido acima (ou parte dele), e re treinar-lo sobre os novos dados com uma taxa de aprendizagem muito baixo. Isso pode potencialmente alcançar melhorias significativas, adaptando de forma incremental os recursos pré-treinados aos novos dados.

Primeiro, vamos falar sobre o Keras trainable API em detalhes, que subjaz a maioria dos transferência de aprendizagem e de ajuste fino fluxos de trabalho.

Em seguida, demonstraremos o fluxo de trabalho típico pegando um modelo pré-treinado no conjunto de dados ImageNet e retreinando-o no conjunto de dados de classificação "gatos vs cães" do Kaggle.

Esta é uma adaptação de profunda aprendizagem com Python e 2016 blog "construir poderosos modelos de classificação imagem usando muito pouco dados" .

Camadas de congelamento: Entendendo o trainable atributo

Camadas e modelos têm três atributos de peso:

  • weights é a lista de todos os pesos variáveis da camada.
  • trainable_weights a lista de aqueles que se destinam a ser atualizado (via descida gradiente) para minimizar a perda durante o treinamento.
  • non_trainable_weights é a lista daqueles que não são destinadas a ser treinado. Normalmente, eles são atualizados pelo modelo durante o passe para frente.

Exemplo: a Dense camada tem 2 pesos treináveis (kernel & viés)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

Em geral, todos os pesos são treináveis. A única embutido camada que tem pesos não treináveis é o BatchNormalization camada. Ele usa pesos não treináveis ​​para controlar a média e a variância de suas entradas durante o treinamento. Para aprender a usar pesos não treináveis em suas próprias camadas personalizados, consulte o guia para escrever novas camadas a partir do zero .

Exemplo: o BatchNormalization camada tem 2 pesos orientável e 2 pesos não treináveis

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Camadas e modelos também possuem um atributo booleano trainable . Seu valor pode ser alterado. Definir layer.trainable de False movimentos todos os pesos da camada de treinável para não-treinável. Isso é chamado de "congelamento" da camada: o estado de uma camada congelada não será atualizado durante o treinamento (ou quando o treinamento com fit() ou quando o treinamento com qualquer laço personalizado que se baseia em trainable_weights para aplicar as atualizações de gradiente).

Exemplo: Ajuste trainable para False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Quando um peso treinável se torna não treinável, seu valor não é mais atualizado durante o treinamento.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

Não confunda o layer.trainable atributo com o argumento training na layer.__call__() (que controla se a camada deve executar o seu passe para frente no modo de inferência ou modo de treino). Para mais informações, consulte o Keras FAQ .

Definição recursiva do trainable atributo

Se você definir trainable = False em um modelo ou em qualquer camada que tem subcamadas, todas as crianças camadas se tornar não-treinável também.

Exemplo:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

O típico fluxo de trabalho de aprendizagem por transferência

Isso nos leva a como um fluxo de trabalho de aprendizagem por transferência típico pode ser implementado no Keras:

  1. Instancie um modelo básico e carregue pesos pré-treinados nele.
  2. Congelar todas as camadas do modelo de base, definindo trainable = False .
  3. Crie um novo modelo sobre a saída de uma (ou várias) camadas do modelo base.
  4. Treine seu novo modelo em seu novo conjunto de dados.

Observe que um fluxo de trabalho alternativo e mais leve também pode ser:

  1. Instancie um modelo básico e carregue pesos pré-treinados nele.
  2. Execute seu novo conjunto de dados por meio dele e registre a saída de uma (ou várias) camadas do modelo base. Isso é chamado de extração de características.
  3. Use essa saída como dados de entrada para um modelo novo e menor.

Uma vantagem principal desse segundo fluxo de trabalho é que você só executa o modelo básico uma vez nos dados, em vez de uma vez por época de treinamento. Portanto, é muito mais rápido e mais barato.

Um problema com esse segundo fluxo de trabalho, porém, é que ele não permite que você modifique dinamicamente os dados de entrada de seu novo modelo durante o treinamento, o que é necessário ao fazer o aumento de dados, por exemplo. O aprendizado de transferência é normalmente usado para tarefas quando seu novo conjunto de dados tem poucos dados para treinar um modelo em escala real do zero e, em tais cenários, o aumento de dados é muito importante. Portanto, a seguir, vamos nos concentrar no primeiro fluxo de trabalho.

Esta é a aparência do primeiro fluxo de trabalho no Keras:

Primeiro, instancie um modelo básico com pesos pré-treinados.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Em seguida, congele o modelo básico.

base_model.trainable = False

Crie um novo modelo em cima.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Treine o modelo com novos dados.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Afinação

Depois que seu modelo convergir para os novos dados, você pode tentar descongelar todo ou parte do modelo básico e retreinar todo o modelo de ponta a ponta com uma taxa de aprendizado muito baixa.

Esta é uma última etapa opcional que pode oferecer melhorias incrementais. Também pode levar a um rápido sobreajuste - tenha isso em mente.

É fundamental apenas para fazer este passo após o modelo com camadas congeladas foi treinado para a convergência. Se você misturar camadas treináveis ​​inicializadas aleatoriamente com camadas treináveis ​​que contêm recursos pré-treinados, as camadas inicializadas aleatoriamente causarão atualizações de gradiente muito grandes durante o treinamento, o que destruirá seus recursos pré-treinados.

Também é fundamental usar uma taxa de aprendizado muito baixa neste estágio, porque você está treinando um modelo muito maior do que na primeira rodada de treinamento, em um conjunto de dados que normalmente é muito pequeno. Como resultado, você corre o risco de overfitting muito rapidamente se aplicar grandes atualizações de peso. Aqui, você só deseja readaptar os pesos pré-treinados de forma incremental.

Veja como implementar o ajuste fino de todo o modelo básico:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Nota importante sobre compile() e trainable

Chamando compile() em um modelo serve para "congelar" o comportamento desse modelo. Isto implica que os trainable valores de atributos no momento o modelo é compilado deve ser preservada durante toda a vida desse modelo, até compile é chamado novamente. Portanto, se você alterar qualquer trainable valor, certifique-se de chamada compile() novamente no seu modelo para que as alterações sejam tidos em conta.

Notas importantes sobre BatchNormalization camada

Muitos modelos de imagem contêm BatchNormalization camadas. Essa camada é um caso especial em todos os aspectos imagináveis. Aqui estão algumas coisas que você deve ter em mente.

  • BatchNormalization contém 2 pesos não treináveis que são atualizados durante o treinamento. Essas são as variáveis ​​que rastreiam a média e a variância das entradas.
  • Quando você define bn_layer.trainable = False , o BatchNormalization camada será executado no modo de inferência, e não irá atualizar suas estatísticas médias e variância. Este não é o caso de outras camadas, em geral, a partir de peso treinabilidade & inferência modos de treinamento / são dois conceitos ortogonais . Mas os dois estão empatados no caso do BatchNormalization camada.
  • Quando você descongelar um modelo que contém BatchNormalization camadas, a fim de fazer o ajuste fino, você deve manter as BatchNormalization camadas em modo de inferência passando training=False ao chamar o modelo base. Caso contrário, as atualizações aplicadas aos pesos não treináveis ​​destruirão repentinamente o que o modelo aprendeu.

Você verá esse padrão em ação no exemplo de ponta a ponta no final deste guia.

Transfira o aprendizado e o ajuste fino com um loop de treinamento personalizado

Se em vez de fit() , você está usando seu próprio loop de formação de baixo nível, as estadias de fluxo de trabalho essencialmente o mesmo. Você deve ter cuidado para ter apenas em conta a lista model.trainable_weights ao aplicar atualizações gradiente:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Da mesma forma para o ajuste fino.

Um exemplo completo: ajuste fino de um modelo de classificação de imagem em um conjunto de dados de cães e gatos

Para solidificar esses conceitos, vamos conduzi-lo por um exemplo concreto de transferência de aprendizagem e ajuste fino. Vamos carregar o modelo Xception, pré-treinado no ImageNet, e usá-lo no conjunto de dados de classificação "gatos vs. cachorros" do Kaggle.

Obtendo os dados

Primeiro, vamos buscar o conjunto de dados de cães e gatos usando TFDS. Se você tem seu próprio conjunto de dados, você provavelmente vai querer usar o utilitário tf.keras.preprocessing.image_dataset_from_directory para gerar conjunto de dados rotulados semelhante objetos a partir de um conjunto de imagens no disco arquivados em pastas específicas de classe.

A aprendizagem por transferência é mais útil ao trabalhar com conjuntos de dados muito pequenos. Para manter nosso conjunto de dados pequeno, usaremos 40% dos dados de treinamento originais (25.000 imagens) para treinamento, 10% para validação e 10% para teste.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Estas são as primeiras 9 imagens no conjunto de dados de treinamento - como você pode ver, são todas de tamanhos diferentes.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Também podemos ver que o marcador 1 é "cachorro" e o marcador 0 é "gato".

Padronizando os dados

Nossas imagens brutas têm uma variedade de tamanhos. Além disso, cada pixel consiste em 3 valores inteiros entre 0 e 255 (valores de nível RGB). Este não é um ótimo ajuste para alimentar uma rede neural. Precisamos fazer 2 coisas:

  • Padronize para um tamanho de imagem fixo. Escolhemos 150 x 150.
  • Valores de pixel Normalize entre -1 e 1. Nós vamos fazer isso usando uma Normalization camada como parte do próprio modelo.

Em geral, é uma boa prática desenvolver modelos que usam dados brutos como entrada, ao contrário de modelos que usam dados já pré-processados. A razão é que, se seu modelo espera dados pré-processados, sempre que exportar seu modelo para usá-lo em outro lugar (em um navegador da web, em um aplicativo móvel), você precisará reimplementar exatamente o mesmo pipeline de pré-processamento. Isso fica muito complicado muito rapidamente. Portanto, devemos fazer o mínimo possível de pré-processamento antes de chegar ao modelo.

Aqui, faremos o redimensionamento da imagem no pipeline de dados (porque uma rede neural profunda só pode processar lotes contíguos de dados) e faremos o escalonamento do valor de entrada como parte do modelo, quando o criarmos.

Vamos redimensionar as imagens para 150 x 150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Além disso, vamos agrupar os dados e usar cache e pré-busca para otimizar a velocidade de carregamento.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Usando aumento de dados aleatórios

Quando você não tem um grande conjunto de dados de imagens, é uma boa prática introduzir artificialmente a diversidade da amostra, aplicando transformações aleatórias, porém realistas, às imagens de treinamento, como inversão horizontal aleatória ou pequenas rotações aleatórias. Isso ajuda a expor o modelo a diferentes aspectos dos dados de treinamento enquanto desacelera o sobreajuste.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Vamos visualizar a aparência da primeira imagem do primeiro lote após várias transformações aleatórias:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Construir um modelo

Agora vamos construir um modelo que segue o projeto que explicamos anteriormente.

Observe que:

  • Nós adicionar um Rescaling camada para valores de entrada de escala (inicialmente no [0, 255] intervalo) para a [-1, 1] gama.
  • Nós adicionamos um Dropout camada antes de a camada de classificação, para a regularização.
  • Temos certeza de passar training=False ao chamar o modelo de base, para que seja executada no modo de inferência, de modo que as estatísticas batchnorm não são atualizados mesmo depois de descongelar o modelo base para o ajuste fino.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Treine a camada superior

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

Faça uma rodada de ajuste fino de todo o modelo

Por fim, vamos descongelar o modelo básico e treinar todo o modelo de ponta a ponta com uma baixa taxa de aprendizado.

Importante, embora o modelo de base se torna treinável, ainda é executado no modo de inferência, uma vez que passou training=False ao chamar-lo quando construímos o modelo. Isso significa que as camadas de normalização de lote internas não atualizarão suas estatísticas de lote. Se o fizessem, eles destruiriam as representações aprendidas pelo modelo até agora.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

Após 10 épocas, o ajuste fino nos proporciona uma boa melhoria aqui.