

import numpy as np
import tensorflow as tf
from tensorflow import keras
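
The most common transfer learning workflow in deep learning is:

  1. Take layers from a previously trained model.
  2. Freeze them, so that the information they contain is not destroyed during later training rounds.
  3. Add some new, trainable layers on top of the frozen layers. These learn to turn the old features into predictions on a new dataset.
  4. Train the new layers on your dataset.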


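First, we will go over the Keras trainable API in detail, which underlies most transfer learning and fine-tuning workflows.

Then, we will demonstrate the typical workflow by taking a model pre-trained on the ImageNet dataset and retraining it on the Kaggle Dogs vs. Cats classification dataset.

This workflow is adapted from Deep Learning with Python and the 2016 blog post "Building powerful image classification models using very little data".

Freezing layers: understanding the trainable attribute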


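Layers and models have three weight attributes:

  • weights is the list of all weight variables of the layer.
  • trainable_weights is the list of weights that are updated (via gradient descent) to minimize the loss during training.
  • non_trainable_weights is the list of weights that are not meant to be trained. Typically, they are updated by the model during the forward pass.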

Example: the Dense layer has 2 trainable weights (kernel and bias)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

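In general, all weights are trainable weights. The only built-in layer that has non-trainable weights is the BatchNormalization layer. It uses non-trainable weights to keep track of the mean and variance of its inputs during training. To learn how to use non-trainable weights in your own custom layers, see the guide to writing new layers from scratch.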
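As a rough illustration of a non-trainable weight in a custom layer, the sketch below defines a hypothetical `RunningSum` layer (the name and its behavior are assumptions made for this example, not part of the workflow in this guide). It creates its state with `add_weight(..., trainable=False)` and updates it during the forward pass.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class RunningSum(keras.layers.Layer):
    """Hypothetical layer that accumulates the column sums of its inputs."""

    def __init__(self, input_dim):
        super().__init__()
        # trainable=False: this state is never touched by gradient updates.
        self.total = self.add_weight(
            shape=(input_dim,), initializer="zeros", trainable=False
        )

    def call(self, inputs):
        # The weight is updated as a side effect of the forward pass.
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return self.total


layer = RunningSum(2)
x = np.ones((2, 2), dtype="float32")
layer(x)
layer(x)

# Read the accumulated state and count the weights of each kind.
running_total = layer.total.numpy()
num_trainable = len(layer.trainable_weights)
num_non_trainable = len(layer.non_trainable_weights)
```

The single weight shows up under `non_trainable_weights` only, yet its value changes each time the layer is called.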

Example: the BatchNormalization layer has 2 trainable weights and 2 non-trainable weights

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

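Layers and models also feature a boolean attribute trainable, whose value can be changed. Setting layer.trainable to False moves all of the layer's weights from trainable to non-trainable. This is called "freezing" the layer: the state of a frozen layer is not updated during training (whether you train with fit() or with any custom loop that relies on trainable_weights to apply gradient updates).

Example: setting trainable to False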

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2


# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 1s/step - loss: 0.0926

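Do not confuse the layer.trainable attribute with the training argument in layer.__call__() (which controls whether the layer runs its forward pass in inference mode or in training mode). For more information, see the Keras FAQ.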
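To make the distinction concrete, here is a minimal sketch using a Dropout layer, chosen for this example because its training and inference behavior differ visibly. Freezing the layer with `trainable = False` does not change how it responds to the `training` argument.

```python
import numpy as np
from tensorflow import keras

layer = keras.layers.Dropout(0.5)
layer.trainable = False  # Freezing has no effect on the forward-pass mode

x = np.ones((1, 1000), dtype="float32")

# Inference mode: dropout is a no-op.
out_inference = np.asarray(layer(x, training=False))
# Training mode: roughly half of the units are zeroed out.
out_training = np.asarray(layer(x, training=True))

inference_unchanged = bool(np.array_equal(out_inference, x))
training_drops_units = bool((out_training == 0).any())
```

BatchNormalization is the exception to this independence, as discussed later in this guide.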

Recursive setting of the trainable attribute

If you set trainable = False on a model or on any layer that has sublayers, all child layers become non-trainable as well.


inner_model = keras.Sequential(
    [
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively


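This leads us to how a typical transfer learning workflow can be implemented in Keras:

  1. Instantiate a base model and load pre-trained weights into it.
  2. Freeze all layers in the base model by setting trainable = False.
  3. Create a new model on top of the output of one (or several) layers from the base model.
  4. Train your new model on your new dataset.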


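An alternative, more lightweight workflow is:

  1. Instantiate a base model and load pre-trained weights into it.
  2. Run your new dataset through it and record the output of one (or several) layers from the base model. This is called feature extraction.
  3. Use that output as input data for a new, smaller model.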
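The lightweight workflow above could be sketched as follows. To keep the example small and offline, the base model here is a tiny randomly initialized stand-in and the data is random; both are assumptions for illustration, and in practice you would use a pre-trained model and real images.

```python
import numpy as np
from tensorflow import keras

# Hypothetical stand-in for a pre-trained base model (random weights here).
base_model = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        keras.layers.Conv2D(8, 3, activation="relu"),
        keras.layers.GlobalAveragePooling2D(),
    ]
)

# Random placeholder data: 16 images with binary labels.
images = np.random.random((16, 32, 32, 3)).astype("float32")
labels = np.random.randint(0, 2, size=(16, 1)).astype("float32")

# Step 2: run the data through the base model once and record its outputs.
features = base_model.predict(images, verbose=0)

# Step 3: train a new, smaller model on the recorded features.
classifier = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
classifier.compile(
    optimizer="adam", loss=keras.losses.BinaryCrossentropy(from_logits=True)
)
history = classifier.fit(features, labels, epochs=2, verbose=0)
```

Because the base model runs only once over the data, this is cheap, but the recorded features are fixed: anything that changes the inputs on each epoch, such as random data augmentation, cannot be used with it.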



Here is what the first workflow looks like in Keras:


# First, instantiate a base model with pre-trained weights.
base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

# Then, freeze the base model.
base_model.trainable = False

# Create a new model on top.
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)


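# Train the model on new data.
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)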







# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are taken into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

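Important note about compile() and trainable

Calling compile() on a model is meant to "freeze" the behavior of that model. This implies that the trainable attribute values at the time the model is compiled should be preserved throughout the lifetime of that model, until compile is called again. Hence, if you change any trainable value, make sure to call compile() again on your model for your changes to be taken into account.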
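The pattern described above might look like the following sketch on a toy two-layer model. The layer names `hidden` and `head` and the `"mse"` loss are placeholders chosen for this example.

```python
from tensorflow import keras

hidden = keras.layers.Dense(3, activation="relu")
head = keras.layers.Dense(1)
model = keras.Sequential([keras.Input(shape=(4,)), hidden, head])

# Freeze the hidden layer, then compile: only `head` will be trained.
hidden.trainable = False
model.compile(optimizer="adam", loss="mse")
trainable_when_frozen = len(model.trainable_weights)

# Unfreeze it, then compile again so the change is taken into account.
hidden.trainable = True
model.compile(optimizer="adam", loss="mse")
trainable_after_unfreeze = len(model.trainable_weights)
```

Each Dense layer contributes a kernel and a bias, so the count goes from 2 to 4 once the hidden layer is unfrozen.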

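Important notes about the BatchNormalization layer

Many image models contain BatchNormalization layers. That layer is a special case in almost every respect. Here are a few things to keep in mind.

  • BatchNormalization contains 2 non-trainable weights that get updated during training. These are the variables tracking the mean and variance of the inputs.
  • When you set bn_layer.trainable = False, the BatchNormalization layer will run in inference mode and will not update its mean and variance statistics. This is not the case for other layers in general, since weight trainability and inference/training modes are two orthogonal concepts. But the two are tied in the case of the BatchNormalization layer.
  • When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing training=False when calling the base model. Otherwise, the updates applied to the non-trainable weights will suddenly destroy what the model has learned.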
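The first two points can be checked with a small sketch on a standalone BatchNormalization layer. The input here is random data with a mean of roughly 5, an arbitrary choice made so that the statistics visibly move away from their initial values.

```python
import numpy as np
from tensorflow import keras

bn = keras.layers.BatchNormalization()
bn.build((None, 4))  # Create the weights
x = np.random.normal(loc=5.0, size=(64, 4)).astype("float32")

# While the layer is trainable, a training-mode call updates the moving mean.
mean_before = bn.moving_mean.numpy().copy()
bn(x, training=True)
mean_after = bn.moving_mean.numpy().copy()
stats_updated = not np.allclose(mean_before, mean_after)

# Once frozen, the layer runs in inference mode even if training=True is passed,
# so the moving mean stays where it was.
bn.trainable = False
bn(x, training=True)
mean_frozen = bn.moving_mean.numpy().copy()
stats_frozen = bool(np.allclose(mean_after, mean_frozen))
```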



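If you use your own low-level training loop instead of fit(), the workflow stays essentially the same. You should be careful to only take into account the list model.trainable_weights when applying gradient updates: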

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)

# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))


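An end-to-end example: fine-tuning an image classification model on the Dogs vs. Cats dataset

To solidify these concepts, let's walk through a concrete end-to-end transfer learning and fine-tuning example. We will load the Xception model, pre-trained on ImageNet, and use it on the Kaggle Dogs vs. Cats classification dataset.


First, let's fetch the Dogs vs. Cats dataset using TFDS. If you have your own dataset, you will probably want to use the utility tf.keras.preprocessing.image_dataset_from_directory to generate similar labeled dataset objects from a set of images on disk filed into class-specific folders.

Transfer learning is most useful when working with very small datasets. To keep our dataset small, we will use 40% of the original training data (25,000 images) for training, 10% for validation, and 10% for testing.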

import tensorflow_datasets as tfds


train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

These are the first 9 images in the training dataset. As you can see, they all have different sizes.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))  # Show the integer label above each image
    plt.axis("off")

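We can also see that label 1 is "dog" and label 0 is "cat".


Our raw images have a variety of sizes. In addition, each pixel consists of 3 integer values between 0 and 255 (RGB level values). This isn't a great fit for feeding a neural network. We need to do two things:

  • Standardize to a fixed image size. We pick 150x150.
  • Normalize pixel values between -1 and 1. We'll do this using a Rescaling layer as part of the model itself.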
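The pixel normalization amounts to a linear map of the form `inputs * scale + offset`. A minimal NumPy sketch of that arithmetic, with a few sample pixel values picked for illustration:

```python
import numpy as np

# Sample pixel values spanning the [0, 255] range.
pixels = np.array([0.0, 64.0, 127.5, 255.0], dtype="float32")

# scale = 1 / 127.5 and offset = -1 map [0, 255] onto [-1, 1].
scale, offset = 1 / 127.5, -1
rescaled = pixels * scale + offset
```

With these constants, 0 maps to -1, 127.5 maps to 0, and 255 maps to 1.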



Let's resize the images to 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

# Batch the data, and use caching and prefetching to speed up loading.
batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)



from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)


import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
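        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        # Show the augmented version of the first image in the batch.
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")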
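
Now let's build a model that follows the blueprint explained earlier. Note that:

  • We add a Rescaling layer to scale input values (initially in the [0, 255] range) to the [-1, 1] range.
  • We add a Dropout layer before the classification layer, for regularization.
  • We make sure to pass training=False when calling the base model, so that it runs in inference mode and batchnorm statistics don't get updated even after we unfreeze the base model for fine-tuning.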

base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights require that input be scaled
# from (0, 255) to a range of (-1., +1.). The rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83683744/83683744 [==============================] - 1s 0us/step
Model: "model"
 Layer (type)                Output Shape              Param #   
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
 xception (Functional)       (None, 5, 5, 2048)        20861480  
 global_average_pooling2d (G  (None, 2048)             0         
 dropout (Dropout)           (None, 2048)              0         
 dense_7 (Dense)             (None, 1)                 2049      
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480



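model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)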
Epoch 1/20
291/291 [==============================] - 32s 92ms/step - loss: 0.1673 - binary_accuracy: 0.9260 - val_loss: 0.0769 - val_binary_accuracy: 0.9712
Epoch 2/20
291/291 [==============================] - 25s 85ms/step - loss: 0.1146 - binary_accuracy: 0.9513 - val_loss: 0.0729 - val_binary_accuracy: 0.9712
Epoch 3/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1004 - binary_accuracy: 0.9610 - val_loss: 0.0711 - val_binary_accuracy: 0.9721
Epoch 4/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1040 - binary_accuracy: 0.9578 - val_loss: 0.0701 - val_binary_accuracy: 0.9751
Epoch 5/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0982 - binary_accuracy: 0.9601 - val_loss: 0.0713 - val_binary_accuracy: 0.9721
Epoch 6/20
291/291 [==============================] - 25s 84ms/step - loss: 0.0971 - binary_accuracy: 0.9591 - val_loss: 0.0725 - val_binary_accuracy: 0.9708
Epoch 7/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1028 - binary_accuracy: 0.9587 - val_loss: 0.0776 - val_binary_accuracy: 0.9699
Epoch 8/20
291/291 [==============================] - 24s 83ms/step - loss: 0.1006 - binary_accuracy: 0.9598 - val_loss: 0.0729 - val_binary_accuracy: 0.9699
Epoch 9/20
291/291 [==============================] - 25s 84ms/step - loss: 0.0958 - binary_accuracy: 0.9603 - val_loss: 0.0715 - val_binary_accuracy: 0.9708
Epoch 10/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1009 - binary_accuracy: 0.9608 - val_loss: 0.0730 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 24s 83ms/step - loss: 0.0924 - binary_accuracy: 0.9614 - val_loss: 0.0684 - val_binary_accuracy: 0.9738
Epoch 12/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0931 - binary_accuracy: 0.9618 - val_loss: 0.0691 - val_binary_accuracy: 0.9746
Epoch 13/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0903 - binary_accuracy: 0.9634 - val_loss: 0.0733 - val_binary_accuracy: 0.9703
Epoch 14/20
291/291 [==============================] - 24s 84ms/step - loss: 0.1024 - binary_accuracy: 0.9597 - val_loss: 0.0720 - val_binary_accuracy: 0.9733
Epoch 15/20
291/291 [==============================] - 24s 83ms/step - loss: 0.0930 - binary_accuracy: 0.9631 - val_loss: 0.0744 - val_binary_accuracy: 0.9721
Epoch 16/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0935 - binary_accuracy: 0.9624 - val_loss: 0.0698 - val_binary_accuracy: 0.9733
Epoch 17/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0927 - binary_accuracy: 0.9632 - val_loss: 0.0716 - val_binary_accuracy: 0.9721
Epoch 18/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0913 - binary_accuracy: 0.9622 - val_loss: 0.0765 - val_binary_accuracy: 0.9712
Epoch 19/20
291/291 [==============================] - 24s 84ms/step - loss: 0.0912 - binary_accuracy: 0.9634 - val_loss: 0.0725 - val_binary_accuracy: 0.9721
Epoch 20/20
291/291 [==============================] - 24s 83ms/step - loss: 0.0886 - binary_accuracy: 0.9643 - val_loss: 0.0714 - val_binary_accuracy: 0.9738
<keras.callbacks.History at 0x7fb4202f8520>



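Finally, we unfreeze the base model and train the entire model end-to-end with a low learning rate.

Importantly, although the base model becomes trainable, it still runs in inference mode, because we passed training=False when calling it while building the model. This means that the batch normalization layers inside will not update their batch statistics. If they did, they would wreck the representations learned by the model so far.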

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True

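model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)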

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
 Layer (type)                Output Shape              Param #   
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
 xception (Functional)       (None, 5, 5, 2048)        20861480  
 global_average_pooling2d (G  (None, 2048)             0         
 dropout (Dropout)           (None, 2048)              0         
 dense_7 (Dense)             (None, 1)                 2049      
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
Epoch 1/10
291/291 [==============================] - 79s 192ms/step - loss: 0.0761 - binary_accuracy: 0.9707 - val_loss: 0.0550 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0534 - binary_accuracy: 0.9793 - val_loss: 0.0450 - val_binary_accuracy: 0.9815
Epoch 3/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0405 - binary_accuracy: 0.9832 - val_loss: 0.0461 - val_binary_accuracy: 0.9828
Epoch 4/10
291/291 [==============================] - 55s 190ms/step - loss: 0.0310 - binary_accuracy: 0.9881 - val_loss: 0.0448 - val_binary_accuracy: 0.9832
Epoch 5/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0214 - binary_accuracy: 0.9919 - val_loss: 0.0628 - val_binary_accuracy: 0.9781
Epoch 6/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0211 - binary_accuracy: 0.9926 - val_loss: 0.0442 - val_binary_accuracy: 0.9811
Epoch 7/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0156 - binary_accuracy: 0.9938 - val_loss: 0.0481 - val_binary_accuracy: 0.9841
Epoch 8/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0113 - binary_accuracy: 0.9965 - val_loss: 0.0614 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0095 - binary_accuracy: 0.9970 - val_loss: 0.0556 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0107 - binary_accuracy: 0.9956 - val_loss: 0.0539 - val_binary_accuracy: 0.9854
<keras.callbacks.History at 0x7fb2c44b7790>

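After 10 epochs, fine-tuning gains us a nice improvement here: validation accuracy rises from about 97.4% at the end of the first training phase to about 98.5%.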