Google I / O là một kết quả hoàn hảo! Cập nhật các phiên TensorFlow Xem phiên

Chuyển giao việc học và tinh chỉnh

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Thành lập

import numpy as np
import tensorflow as tf
from tensorflow import keras

Giới thiệu

Học tập chuyển giao bao gồm tham gia các tính năng đã học về một vấn đề, và tận dụng chúng trên một mới, vấn đề tương tự. Ví dụ, các tính năng từ một mô hình đã học để xác định các loài chim có thể hữu ích để khởi động một mô hình dùng để xác định tanukis.

Học chuyển giao thường được thực hiện cho các nhiệm vụ mà tập dữ liệu của bạn có quá ít dữ liệu để đào tạo một mô hình quy mô đầy đủ từ đầu.

Hiện thân phổ biến nhất của học chuyển tiếp trong bối cảnh học sâu là quy trình làm việc sau:

  1. Lấy các lớp từ một mô hình đã được đào tạo trước đó.
  2. Đóng băng chúng, để tránh phá hủy bất kỳ thông tin nào mà chúng chứa trong các đợt huấn luyện trong tương lai.
  3. Thêm một số lớp mới, có thể đào tạo lên trên các lớp đã đóng băng. Họ sẽ học cách biến các tính năng cũ thành dự đoán trên một tập dữ liệu mới.
  4. Đào tạo các lớp mới trên tập dữ liệu của bạn.

A, bước không bắt buộc cuối cùng, là tinh chỉnh, trong đó bao gồm unfreezing toàn bộ mô hình mà bạn thu được ở trên (hoặc một phần của nó), và tái đào tạo nó trên các dữ liệu mới với một tỷ lệ học rất thấp. Điều này có thể đạt được những cải tiến có ý nghĩa, bằng cách từng bước điều chỉnh các tính năng được đào tạo trước cho phù hợp với dữ liệu mới.

Đầu tiên, chúng ta sẽ đi qua Keras trainable API cụ thể, làm nền tảng cho hầu hết các quy trình công việc học tập chuyển & tinh chỉnh.

Sau đó, chúng tôi sẽ chứng minh quy trình làm việc điển hình bằng cách lấy một mô hình được đào tạo trước trên tập dữ liệu ImageNet và đào tạo lại nó trên tập dữ liệu phân loại Kaggle "mèo vs chó".

Này được chuyển thể từ sâu Learning với Python và bài 2016 trên blog "xây dựng mô hình phân loại hình ảnh mạnh mẽ sử dụng rất ít dữ liệu" .

Lớp đóng băng: hiểu biết về trainable thuộc tính

Lớp & mô hình có ba thuộc tính trọng lượng:

  • weights là danh sách của tất cả các trọng lượng biến của lớp.
  • trainable_weights là danh sách những người có nghĩa là để được cập nhật (thông qua gradient descent) để giảm thiểu thiệt hại trong quá trình đào tạo.
  • non_trainable_weights là danh sách những người không có nghĩa là để được đào tạo. Thông thường, chúng được cập nhật bởi mô hình trong quá trình chuyển tiếp.

Ví dụ: Dense lớp có 2 trọng khả năng huấn luyện (kernel & bias)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

Nói chung, tất cả các trọng lượng đều là trọng lượng có thể huấn luyện được. Chỉ được xây dựng trong lớp mà có trọng lượng không khả năng huấn luyện là BatchNormalization lớp. Nó sử dụng trọng lượng không thể đào tạo để theo dõi giá trị trung bình và phương sai của các đầu vào trong quá trình đào tạo. Để tìm hiểu cách sử dụng trọng lượng không khả năng huấn luyện trong lớp tùy chỉnh của riêng bạn, hãy xem các hướng dẫn để viết các lớp mới từ đầu .

Ví dụ: BatchNormalization lớp có 2 trọng khả năng huấn luyện và 2 tạ phi khả năng huấn luyện

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Lớp & mô hình cũng có một thuộc tính boolean trainable . Giá trị của nó có thể được thay đổi. Thiết layer.trainable để False di chuyển tất cả trọng lượng của lớp từ khả năng huấn luyện để không dễ huấn luyện. Đây gọi là "đóng băng" các lớp: thực trạng của một lớp đông lạnh sẽ không được cập nhật trong thời gian đào tạo (hoặc khi đào tạo với fit() hoặc khi tập luyện với bất kỳ vòng lặp tùy chỉnh dựa trên trainable_weights để áp dụng bản cập nhật gradient).

Ví dụ: thiết lập trainable để False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Khi một trọng lượng có thể huấn luyện trở nên không thể huấn luyện, giá trị của nó sẽ không còn được cập nhật trong quá trình huấn luyện.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

Đừng nhầm lẫn giữa layer.trainable thuộc tính với lập luận training trong layer.__call__() (mà điều khiển cho dù lớp nên chạy qua phía trước của nó trong chế độ suy luận hoặc chế độ đào tạo). Để biết thêm thông tin, vui lòng xem Keras FAQ .

Thiết lập đệ quy của trainable thuộc tính

Nếu bạn thiết lập trainable = False trên một mô hình hoặc trên bất kỳ lớp có lớp con, tất cả trẻ em lớp trở thành phi khả năng huấn luyện là tốt.

Thí dụ:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

Quy trình học tập chuyển giao điển hình

Điều này dẫn chúng ta đến cách một quy trình học tập chuyển giao điển hình có thể được thực hiện trong Keras:

  1. Khởi tạo mô hình cơ sở và tải các trọng lượng đã được đào tạo trước vào đó.
  2. Freeze tất cả các lớp trong mô hình cơ sở bằng cách thiết lập trainable = False .
  3. Tạo một mô hình mới trên đầu ra của một (hoặc một số) lớp từ mô hình cơ sở.
  4. Đào tạo mô hình mới của bạn trên tập dữ liệu mới của bạn.

Lưu ý rằng một quy trình thay thế, nhẹ hơn cũng có thể là:

  1. Khởi tạo mô hình cơ sở và tải các trọng lượng đã được đào tạo trước vào đó.
  2. Chạy tập dữ liệu mới của bạn thông qua nó và ghi lại kết quả đầu ra của một (hoặc một số) lớp từ mô hình cơ sở. Đây được gọi là khai thác tính năng.
  3. Sử dụng đầu ra đó làm dữ liệu đầu vào cho một mô hình mới, nhỏ hơn.

Ưu điểm chính của quy trình làm việc thứ hai đó là bạn chỉ chạy mô hình cơ sở một lần trên dữ liệu của mình, thay vì một lần cho mỗi kỷ nguyên đào tạo. Vì vậy, nó nhanh hơn và rẻ hơn rất nhiều.

Tuy nhiên, một vấn đề với quy trình làm việc thứ hai đó là nó không cho phép bạn tự động sửa đổi dữ liệu đầu vào của mô hình mới trong quá trình đào tạo, ví dụ như bắt buộc khi thực hiện tăng dữ liệu. Học chuyển giao thường được sử dụng cho các nhiệm vụ khi tập dữ liệu mới của bạn có quá ít dữ liệu để đào tạo mô hình quy mô đầy đủ từ đầu và trong các tình huống như vậy, việc tăng dữ liệu là rất quan trọng. Vì vậy, trong những gì tiếp theo, chúng tôi sẽ tập trung vào quy trình làm việc đầu tiên.

Đây là quy trình làm việc đầu tiên trông như thế nào trong Keras:

Đầu tiên, khởi tạo một mô hình cơ sở với các trọng lượng được đào tạo trước.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Sau đó, đóng băng mô hình cơ sở.

base_model.trainable = False

Tạo một mô hình mới trên đầu trang.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Đào tạo mô hình trên dữ liệu mới.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Tinh chỉnh

Khi mô hình của bạn đã hội tụ trên dữ liệu mới, bạn có thể cố gắng giải phóng toàn bộ hoặc một phần của mô hình cơ sở và đào tạo lại toàn bộ mô hình từ đầu đến cuối với tỷ lệ học tập rất thấp.

Đây là bước cuối cùng tùy chọn có thể mang lại cho bạn những cải tiến gia tăng. Nó cũng có thể dẫn đến tình trạng quá tải nhanh chóng - hãy ghi nhớ điều đó.

Nó là rất quan trọng để chỉ thực hiện bước này sau khi mô hình với lớp đông lạnh đã được đào tạo để hội tụ. Nếu bạn trộn các lớp có thể huấn luyện được khởi tạo ngẫu nhiên với các lớp có thể huấn luyện chứa các tính năng đã được huấn luyện trước, các lớp được khởi tạo ngẫu nhiên sẽ gây ra các cập nhật gradient rất lớn trong quá trình huấn luyện, điều này sẽ phá hủy các tính năng đã được huấn luyện trước của bạn.

Điều quan trọng là sử dụng tỷ lệ học tập rất thấp ở giai đoạn này, bởi vì bạn đang đào tạo một mô hình lớn hơn nhiều so với trong vòng đào tạo đầu tiên, trên một tập dữ liệu thường rất nhỏ. Do đó, bạn có nguy cơ bị quá sức rất nhanh nếu áp dụng các biện pháp cập nhật trọng lượng lớn. Ở đây, bạn chỉ muốn đọc các trọng số được huấn luyện trước theo cách tăng dần.

Đây là cách thực hiện tinh chỉnh toàn bộ mô hình cơ sở:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Lưu ý quan trọng về compile()trainable

Gọi compile() trên một mô hình có nghĩa là để "đóng băng" các hành vi của mô hình đó. Điều này ngụ ý rằng trainable các giá trị thuộc tính tại thời điểm mô hình được biên dịch cần được bảo tồn trong suốt cuộc đời của mô hình đó, cho đến khi compile được gọi là một lần nữa. Do đó, nếu bạn thay đổi bất kỳ trainable giá trị, hãy chắc chắn để gọi compile() một lần nữa vào mô hình của bạn cho những thay đổi của bạn sẽ được đưa vào tính toán.

Ghi chú quan trọng về BatchNormalization lớp

Nhiều mô hình hình ảnh chứa BatchNormalization lớp. Lớp đó là một trường hợp đặc biệt trên mọi số lượng có thể tưởng tượng được. Dưới đây là một số điều cần ghi nhớ.

  • BatchNormalization chứa 2 tạ phi khả năng huấn luyện mà có được cập nhật trong đào tạo. Đây là các biến theo dõi giá trị trung bình và phương sai của các yếu tố đầu vào.
  • Khi bạn thiết lập bn_layer.trainable = False , các BatchNormalization lớp sẽ chạy ở chế độ suy luận, và sẽ không cập nhật trung bình & sai thống kê của nó. Đây không phải là trường hợp cho các lớp khác nói chung, như trainability cân & suy luận / phương thức đào tạo là hai khái niệm trực giao . Nhưng hai được gắn trong trường hợp của BatchNormalization lớp.
  • Khi bạn unfreeze một mô hình có chứa BatchNormalization lớp để làm tinh chỉnh, bạn nên giữ BatchNormalization lớp trong chế độ suy luận bằng cách thông qua training=False khi gọi mô hình cơ sở. Nếu không, các bản cập nhật được áp dụng cho các trọng lượng không thể đào tạo sẽ đột ngột phá hủy những gì mà mô hình đã học được.

Bạn sẽ thấy mẫu này hoạt động trong ví dụ end-to-end ở cuối hướng dẫn này.

Chuyển giao việc học và tinh chỉnh với một vòng đào tạo tùy chỉnh

Nếu thay vì fit() , bạn đang sử dụng vòng lặp đào tạo ở mức độ thấp của riêng mình, việc nghỉ công việc về cơ bản giống nhau. Bạn nên cẩn thận để chỉ đưa vào tài khoản trong danh sách model.trainable_weights khi áp dụng bản cập nhật gradient:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Tương tự như vậy để tinh chỉnh.

Một ví dụ từ đầu đến cuối: tinh chỉnh mô hình phân loại hình ảnh trên tập dữ liệu mèo và chó

Để củng cố những khái niệm này, hãy hướng dẫn bạn qua một ví dụ cụ thể về học tập và tinh chỉnh chuyển giao từ đầu đến cuối. Chúng tôi sẽ tải mô hình Xception, được đào tạo trước trên ImageNet và sử dụng nó trên tập dữ liệu phân loại Kaggle "mèo so với chó".

Lấy dữ liệu

Đầu tiên, hãy tìm nạp tập dữ liệu mèo và chó bằng TFDS. Nếu bạn có dữ liệu riêng của bạn, có thể bạn sẽ muốn sử dụng tiện ích tf.keras.preprocessing.image_dataset_from_directory để tạo ra dữ liệu được dán nhãn tương tự như các đối tượng từ một tập hợp các hình ảnh trên đĩa nộp vào các thư mục đẳng cấp cụ thể.

Học chuyển giao hữu ích nhất khi làm việc với các tập dữ liệu rất nhỏ. Để giữ cho tập dữ liệu của chúng tôi nhỏ, chúng tôi sẽ sử dụng 40% dữ liệu đào tạo ban đầu (25.000 hình ảnh) để đào tạo, 10% để xác thực và 10% để kiểm tra.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Đây là 9 hình ảnh đầu tiên trong tập dữ liệu đào tạo - như bạn có thể thấy, chúng đều có kích thước khác nhau.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Chúng ta cũng có thể thấy rằng nhãn 1 là "chó" và nhãn 0 là "mèo".

Chuẩn hóa dữ liệu

Hình ảnh thô của chúng tôi có nhiều kích cỡ khác nhau. Ngoài ra, mỗi pixel bao gồm 3 giá trị nguyên từ 0 đến 255 (giá trị mức RGB). Đây không phải là một sự phù hợp tuyệt vời để nuôi mạng nơ-ron. Chúng ta cần làm 2 việc:

  • Chuẩn hóa thành kích thước hình ảnh cố định. Chúng tôi chọn 150x150.
  • Giá trị pixel Normalize giữa -1 và 1. Chúng tôi sẽ thực hiện điều này bằng cách sử dụng Normalization lớp như là một phần của mô hình riêng của mình.

Nói chung, bạn nên phát triển các mô hình lấy dữ liệu thô làm đầu vào, trái ngược với các mô hình lấy dữ liệu đã được xử lý trước. Lý do là, nếu mô hình của bạn yêu cầu dữ liệu được xử lý trước, bất kỳ khi nào bạn xuất mô hình của mình để sử dụng ở nơi khác (trong trình duyệt web, trong ứng dụng dành cho thiết bị di động), bạn sẽ cần thực hiện lại cùng một quy trình xử lý trước. Điều này trở nên rất phức tạp rất nhanh chóng. Vì vậy, chúng ta nên thực hiện ít tiền xử lý nhất có thể trước khi đưa vào mô hình.

Ở đây, chúng tôi sẽ thực hiện thay đổi kích thước hình ảnh trong đường ống dữ liệu (vì mạng nơ-ron sâu chỉ có thể xử lý các lô dữ liệu liền kề) và chúng tôi sẽ thực hiện điều chỉnh tỷ lệ giá trị đầu vào như một phần của mô hình, khi chúng tôi tạo nó.

Hãy thay đổi kích thước hình ảnh thành 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Bên cạnh đó, hãy tập hợp dữ liệu và sử dụng bộ nhớ đệm & tìm nạp trước để tối ưu hóa tốc độ tải.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Sử dụng tăng dữ liệu ngẫu nhiên

Khi bạn không có tập dữ liệu hình ảnh lớn, bạn nên đưa vào đa dạng mẫu một cách giả tạo bằng cách áp dụng các phép biến đổi ngẫu nhiên nhưng thực tế cho hình ảnh huấn luyện, chẳng hạn như lật ngang ngẫu nhiên hoặc các phép quay ngẫu nhiên nhỏ. Điều này giúp mô hình hiển thị các khía cạnh khác nhau của dữ liệu đào tạo trong khi làm chậm quá trình trang bị quá mức.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Hãy hình dung hình ảnh đầu tiên của lô đầu tiên trông như thế nào sau nhiều lần biến đổi ngẫu nhiên:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Xây dựng một mô hình

Bây giờ chúng ta hãy xây dựng một mô hình theo kế hoạch chi tiết mà chúng tôi đã giải thích trước đó.

Lưu ý rằng:

  • Chúng tôi thêm một Rescaling lớp để các giá trị đầu vào quy mô (ban đầu trong [0, 255] range) đến [-1, 1] phạm vi.
  • Chúng tôi thêm một Dropout lớp trước lớp phân loại, đối với quy tắc.
  • Chúng tôi đảm bảo để vượt qua training=False khi gọi mô hình cơ sở, do đó nó chạy trong chế độ suy luận, để thống kê batchnorm không được cập nhật ngay cả sau khi chúng tôi giải toả các mô hình cơ sở cho việc tinh chỉnh.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Đào tạo lớp trên cùng

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

Thực hiện một vòng tinh chỉnh toàn bộ mô hình

Cuối cùng, hãy giải phóng mô hình cơ sở và đào tạo toàn bộ mô hình từ đầu đến cuối với tỷ lệ học tập thấp.

Quan trọng hơn, mặc dù các mô hình cơ sở trở nên dễ huấn luyện, nó vẫn chạy trong chế độ suy luận vì chúng ta thông qua training=False khi gọi đó là khi chúng ta xây dựng mô hình. Điều này có nghĩa là các lớp chuẩn hóa hàng loạt bên trong sẽ không cập nhật thống kê hàng loạt của chúng. Nếu họ làm vậy, họ sẽ phá hủy các đại diện mà mô hình đã học cho đến nay.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

Sau 10 kỷ, việc tinh chỉnh mang lại cho chúng tôi một cải tiến tốt đẹp ở đây.