Lưu và tải mô hình

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Tiến trình của mô hình có thể được lưu trong và sau khi đào tạo. Điều này có nghĩa là một mô hình có thể tiếp tục lại khi nó đã dừng lại và tránh thời gian đào tạo dài. Lưu cũng có nghĩa là bạn có thể chia sẻ mô hình của mình và những người khác có thể tạo lại tác phẩm của bạn. Khi xuất bản các mô hình và kỹ thuật nghiên cứu, hầu hết các học viên học máy đều chia sẻ:

  • mã để tạo mô hình và
  • trọng lượng hoặc thông số được đào tạo cho mô hình

Chia sẻ dữ liệu này giúp những người khác hiểu cách hoạt động của mô hình và tự mình thử với dữ liệu mới.

Tùy chọn

Có nhiều cách khác nhau để lưu các mô hình TensorFlow tùy thuộc vào API bạn đang sử dụng. Hướng dẫn này sử dụng tf.keras , một API cấp cao để xây dựng và đào tạo các mô hình trong TensorFlow. Đối với các cách tiếp cận khác, hãy xem hướng dẫn Lưu và khôi phục TensorFlow hoặc Lưu trong háo hức .

Thành lập

Cài đặt và nhập

Cài đặt và nhập TensorFlow và các phụ thuộc:

pip install pyyaml h5py  # Required to save models in HDF5 format
import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)
2.8.0-rc1

Lấy một tập dữ liệu mẫu

Để trình bày cách lưu và tải trọng số, bạn sẽ sử dụng tập dữ liệu MNIST . Để tăng tốc những lần chạy này, hãy sử dụng 1000 ví dụ đầu tiên:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

Xác định một mô hình

Bắt đầu bằng cách xây dựng một mô hình tuần tự đơn giản:

# Define a simple sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 512)               401920    
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Lưu các điểm kiểm tra trong quá trình đào tạo

Bạn có thể sử dụng một mô hình đã được đào tạo mà không cần phải đào tạo lại hoặc tiếp tục đào tạo từ nơi bạn đã dừng lại trong trường hợp quá trình đào tạo bị gián đoạn. Lệnh gọi lại tf.keras.callbacks.ModelCheckpoint cho phép bạn liên tục lưu mô hình cả trong và khi kết thúc đào tạo.

Sử dụng cuộc gọi lại của Checkpoint

Tạo lệnh gọi lại tf.keras.callbacks.ModelCheckpoint để tiết kiệm trọng lượng chỉ trong quá trình đào tạo:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10
23/32 [====================>.........] - ETA: 0s - loss: 1.3666 - sparse_categorical_accuracy: 0.6060 
Epoch 1: saving model to training_1/cp.ckpt
32/32 [==============================] - 1s 10ms/step - loss: 1.1735 - sparse_categorical_accuracy: 0.6690 - val_loss: 0.7180 - val_sparse_categorical_accuracy: 0.7750
Epoch 2/10
24/32 [=====================>........] - ETA: 0s - loss: 0.4238 - sparse_categorical_accuracy: 0.8789
Epoch 2: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.4201 - sparse_categorical_accuracy: 0.8810 - val_loss: 0.5621 - val_sparse_categorical_accuracy: 0.8150
Epoch 3/10
24/32 [=====================>........] - ETA: 0s - loss: 0.2795 - sparse_categorical_accuracy: 0.9336
Epoch 3: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.2815 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.4790 - val_sparse_categorical_accuracy: 0.8430
Epoch 4/10
24/32 [=====================>........] - ETA: 0s - loss: 0.2027 - sparse_categorical_accuracy: 0.9427
Epoch 4: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.2016 - sparse_categorical_accuracy: 0.9440 - val_loss: 0.4361 - val_sparse_categorical_accuracy: 0.8610
Epoch 5/10
24/32 [=====================>........] - ETA: 0s - loss: 0.1739 - sparse_categorical_accuracy: 0.9583
Epoch 5: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.1683 - sparse_categorical_accuracy: 0.9610 - val_loss: 0.4640 - val_sparse_categorical_accuracy: 0.8580
Epoch 6/10
23/32 [====================>.........] - ETA: 0s - loss: 0.1116 - sparse_categorical_accuracy: 0.9796
Epoch 6: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.1125 - sparse_categorical_accuracy: 0.9780 - val_loss: 0.4420 - val_sparse_categorical_accuracy: 0.8580
Epoch 7/10
24/32 [=====================>........] - ETA: 0s - loss: 0.0978 - sparse_categorical_accuracy: 0.9831
Epoch 7: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0989 - sparse_categorical_accuracy: 0.9820 - val_loss: 0.4163 - val_sparse_categorical_accuracy: 0.8590
Epoch 8/10
21/32 [==================>...........] - ETA: 0s - loss: 0.0669 - sparse_categorical_accuracy: 0.9911
Epoch 8: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4411 - val_sparse_categorical_accuracy: 0.8600
Epoch 9/10
22/32 [===================>..........] - ETA: 0s - loss: 0.0495 - sparse_categorical_accuracy: 0.9972
Epoch 9: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0516 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.4064 - val_sparse_categorical_accuracy: 0.8650
Epoch 10/10
24/32 [=====================>........] - ETA: 0s - loss: 0.0436 - sparse_categorical_accuracy: 0.9948
Epoch 10: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0437 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4061 - val_sparse_categorical_accuracy: 0.8770
<keras.callbacks.History at 0x7eff8d865390>

Điều này tạo ra một bộ sưu tập các tệp điểm kiểm tra TensorFlow được cập nhật vào cuối mỗi kỷ nguyên:

os.listdir(checkpoint_dir)
['checkpoint', 'cp.ckpt.index', 'cp.ckpt.data-00000-of-00001']

Miễn là hai mô hình có cùng kiến ​​trúc, bạn có thể chia sẻ trọng số giữa chúng. Vì vậy, khi khôi phục một mô hình từ chỉ trọng số, hãy tạo một mô hình có cùng kiến ​​trúc với mô hình ban đầu và sau đó đặt trọng số của nó.

Bây giờ xây dựng lại một mô hình mới, chưa qua đào tạo và đánh giá nó trên bộ thử nghiệm. Một mô hình chưa được đào tạo sẽ hoạt động ở mức cơ hội (độ chính xác ~ 10%):

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.4473 - sparse_categorical_accuracy: 0.0980 - 145ms/epoch - 5ms/step
Untrained model, accuracy:  9.80%

Sau đó tải trọng lượng từ điểm kiểm tra và đánh giá lại:

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4061 - sparse_categorical_accuracy: 0.8770 - 65ms/epoch - 2ms/step
Restored model, accuracy: 87.70%

Tùy chọn gọi lại điểm kiểm tra

Lệnh gọi lại cung cấp một số tùy chọn để cung cấp tên duy nhất cho các điểm kiểm tra và điều chỉnh tần suất điểm kiểm tra.

Đào tạo một mô hình mới và lưu các điểm kiểm tra được đặt tên duy nhất cứ sau năm kỷ nguyên một lần:

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq=5*batch_size)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt

Epoch 10: saving model to training_2/cp-0010.ckpt

Epoch 15: saving model to training_2/cp-0015.ckpt

Epoch 20: saving model to training_2/cp-0020.ckpt

Epoch 25: saving model to training_2/cp-0025.ckpt

Epoch 30: saving model to training_2/cp-0030.ckpt

Epoch 35: saving model to training_2/cp-0035.ckpt

Epoch 40: saving model to training_2/cp-0040.ckpt

Epoch 45: saving model to training_2/cp-0045.ckpt

Epoch 50: saving model to training_2/cp-0050.ckpt
<keras.callbacks.History at 0x7eff807703d0>

Bây giờ, hãy xem các điểm kiểm tra kết quả và chọn điểm kiểm tra mới nhất:

os.listdir(checkpoint_dir)
['cp-0005.ckpt.data-00000-of-00001',
 'cp-0050.ckpt.index',
 'checkpoint',
 'cp-0010.ckpt.index',
 'cp-0035.ckpt.data-00000-of-00001',
 'cp-0000.ckpt.data-00000-of-00001',
 'cp-0050.ckpt.data-00000-of-00001',
 'cp-0010.ckpt.data-00000-of-00001',
 'cp-0020.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.index',
 'cp-0040.ckpt.index',
 'cp-0025.ckpt.data-00000-of-00001',
 'cp-0045.ckpt.index',
 'cp-0020.ckpt.index',
 'cp-0025.ckpt.index',
 'cp-0030.ckpt.data-00000-of-00001',
 'cp-0030.ckpt.index',
 'cp-0000.ckpt.index',
 'cp-0045.ckpt.data-00000-of-00001',
 'cp-0015.ckpt.index',
 'cp-0015.ckpt.data-00000-of-00001',
 'cp-0005.ckpt.index',
 'cp-0040.ckpt.data-00000-of-00001']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'

Để kiểm tra, hãy đặt lại mô hình và tải điểm kiểm tra mới nhất:

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 150ms/epoch - 5ms/step
Restored model, accuracy: 87.70%

Những tập tin này là gì?

Đoạn mã trên lưu trữ các trọng số vào một tập hợp các tệp được định dạng điểm kiểm tra chỉ chứa các trọng số được huấn luyện ở định dạng nhị phân. Các trạm kiểm soát bao gồm:

  • Một hoặc nhiều phân đoạn chứa trọng lượng mô hình của bạn.
  • Một tệp chỉ mục cho biết trọng số nào được lưu trữ trong phân đoạn nào.

Nếu bạn đang đào tạo một mô hình trên một máy duy nhất, bạn sẽ có một phân đoạn với hậu tố: .data-00000-of-00001

Lưu trọng lượng theo cách thủ công

Lưu trọng số theo cách thủ công với phương pháp Model.save_weights . Theo mặc định, tf.keras —và cụ thể là save_weights — sử dụng định dạng điểm kiểm tra TensorFlow với phần mở rộng .ckpt (lưu ở HDF5 với phần mở rộng .h5 được đề cập trong hướng dẫn mô hình Lưu và tuần tự hóa ):

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 143ms/epoch - 4ms/step
Restored model, accuracy: 87.70%

Lưu toàn bộ mô hình

Gọi model.save để lưu kiến ​​trúc, trọng số và cấu hình đào tạo của mô hình trong một tệp / thư mục duy nhất. Điều này cho phép bạn xuất một mô hình để nó có thể được sử dụng mà không cần truy cập vào mã Python gốc *. Vì trạng thái trình tối ưu hóa được khôi phục, bạn có thể tiếp tục đào tạo từ chính xác nơi bạn đã dừng lại.

Toàn bộ mô hình có thể được lưu ở hai định dạng tệp khác nhau ( SavedModelHDF5 ). Định dạng TensorFlow SavedModel là định dạng tệp mặc định trong TF2.x. Tuy nhiên, các mô hình có thể được lưu ở định dạng HDF5 . Chi tiết hơn về cách lưu toàn bộ mô hình ở hai định dạng tệp được mô tả bên dưới.

Lưu một mô hình đầy đủ chức năng là rất hữu ích — bạn có thể tải chúng trong TensorFlow.js ( Mô hình đã lưu , HDF5 ), sau đó đào tạo và chạy chúng trong trình duyệt web hoặc chuyển đổi chúng để chạy trên thiết bị di động bằng TensorFlow Lite ( Mô hình đã lưu , HDF5 )

* Các đối tượng tùy chỉnh (ví dụ như các mô hình hoặc lớp phân lớp) yêu cầu đặc biệt chú ý khi lưu và tải. Xem phần Lưu đối tượng tùy chỉnh bên dưới

Định dạng SavedModel

Định dạng SavedModel là một cách khác để tuần tự hóa các mô hình. Các mô hình đã lưu ở định dạng này có thể được khôi phục bằng tf.keras.models.load_model và tương thích với TensorFlow Serving. Hướng dẫn SavedModel đi sâu vào chi tiết về cách phân phối / kiểm tra SavedModel. Phần bên dưới minh họa các bước để lưu và khôi phục mô hình.

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1988 - sparse_categorical_accuracy: 0.6550
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4180 - sparse_categorical_accuracy: 0.8930
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2900 - sparse_categorical_accuracy: 0.9220
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2070 - sparse_categorical_accuracy: 0.9540
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1593 - sparse_categorical_accuracy: 0.9630
2022-01-26 07:30:22.888387: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
INFO:tensorflow:Assets written to: saved_model/my_model/assets

Định dạng SavedModel là một thư mục chứa tệp nhị phân protobuf và điểm kiểm tra TensorFlow. Kiểm tra thư mục mô hình đã lưu:

# my_model directory
ls saved_model

# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model
assets  keras_metadata.pb  saved_model.pb  variables

Tải lại mô hình Keras mới từ mô hình đã lưu:

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_10 (Dense)            (None, 512)               401920    
                                                                 
 dropout_5 (Dropout)         (None, 512)               0         
                                                                 
 dense_11 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Mô hình khôi phục được biên dịch với các đối số giống như mô hình ban đầu. Thử chạy đánh giá và dự đoán với mô hình đã tải:

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4577 - sparse_categorical_accuracy: 0.8430 - 156ms/epoch - 5ms/step
Restored model, accuracy: 84.30%
(1000, 10)

Định dạng HDF5

Keras cung cấp một định dạng lưu cơ bản sử dụng tiêu chuẩn HDF5 .

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1383 - sparse_categorical_accuracy: 0.6970
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4094 - sparse_categorical_accuracy: 0.8920
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2936 - sparse_categorical_accuracy: 0.9160
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2050 - sparse_categorical_accuracy: 0.9460
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1485 - sparse_categorical_accuracy: 0.9690

Bây giờ, hãy tạo lại mô hình từ tệp đó:

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_12 (Dense)            (None, 512)               401920    
                                                                 
 dropout_6 (Dropout)         (None, 512)               0         
                                                                 
 dense_13 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

Kiểm tra độ chính xác của nó:

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4266 - sparse_categorical_accuracy: 0.8620 - 141ms/epoch - 4ms/step
Restored model, accuracy: 86.20%

Keras lưu các mô hình bằng cách kiểm tra kiến ​​trúc của chúng. Kỹ thuật này tiết kiệm mọi thứ:

  • Các giá trị trọng lượng
  • Kiến trúc của mô hình
  • Cấu hình đào tạo của mô hình (những gì bạn truyền cho phương thức .compile() )
  • Trình tối ưu hóa và trạng thái của nó, nếu có (điều này cho phép bạn bắt đầu lại quá trình đào tạo từ nơi bạn đã dừng lại)

Keras không thể lưu các trình tối ưu hóa v1.x (từ tf.compat.v1.train ) vì chúng không tương thích với các trạm kiểm soát. Đối với trình tối ưu hóa v1.x, bạn cần phải biên dịch lại mô hình sau khi tải — làm mất trạng thái của trình tối ưu hóa.

Lưu các đối tượng tùy chỉnh

Nếu bạn đang sử dụng định dạng SavedModel, bạn có thể bỏ qua phần này. Sự khác biệt chính giữa HDF5 và SavedModel là HDF5 sử dụng cấu hình đối tượng để lưu kiến ​​trúc mô hình, trong khi SavedModel lưu đồ thị thực thi. Do đó, SavedModels có thể lưu các đối tượng tùy chỉnh như mô hình lớp con và lớp tùy chỉnh mà không yêu cầu mã gốc.

Để lưu các đối tượng tùy chỉnh vào HDF5, bạn phải làm như sau:

  1. Xác định phương thức get_config trong đối tượng của bạn và tùy chọn một from_config .
    • get_config(self) trả về một từ điển JSON-serializable các tham số cần thiết để tạo lại đối tượng.
    • from_config(cls, config) sử dụng cấu hình trả về từ get_config để tạo một đối tượng mới. Theo mặc định, hàm này sẽ sử dụng cấu hình làm kwargs khởi tạo ( return cls(**config) ).
  2. Chuyển đối tượng đến đối số custom_objects khi tải mô hình. Đối số phải là một từ điển ánh xạ tên lớp chuỗi với lớp Python. Ví dụ: tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})

Xem hướng dẫn Viết các lớp và mô hình từ đầu để biết các ví dụ về các đối tượng tùy chỉnh và get_config .

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.