Xem trên TensorFlow.org | Chạy trong Google Colab | Xem nguồn trên GitHub | Tải xuống sổ ghi chép |
Tiến trình của mô hình có thể được lưu trong và sau khi đào tạo. Điều này có nghĩa là một mô hình có thể tiếp tục lại khi nó đã dừng lại và tránh thời gian đào tạo dài. Lưu cũng có nghĩa là bạn có thể chia sẻ mô hình của mình và những người khác có thể tạo lại tác phẩm của bạn. Khi xuất bản các mô hình và kỹ thuật nghiên cứu, hầu hết các học viên học máy đều chia sẻ:
- mã để tạo mô hình và
- trọng lượng hoặc thông số được đào tạo cho mô hình
Chia sẻ dữ liệu này giúp những người khác hiểu cách hoạt động của mô hình và tự mình thử với dữ liệu mới.
Tùy chọn
Có nhiều cách khác nhau để lưu các mô hình TensorFlow tùy thuộc vào API bạn đang sử dụng. Hướng dẫn này sử dụng tf.keras , một API cấp cao để xây dựng và đào tạo các mô hình trong TensorFlow. Đối với các cách tiếp cận khác, hãy xem hướng dẫn Lưu và khôi phục TensorFlow hoặc Lưu trong háo hức .
Thành lập
Cài đặt và nhập
Cài đặt và nhập TensorFlow và các phụ thuộc:
pip install pyyaml h5py # Required to save models in HDF5 format
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
2.8.0-rc1
Lấy một tập dữ liệu mẫu
Để trình bày cách lưu và tải trọng số, bạn sẽ sử dụng tập dữ liệu MNIST . Để tăng tốc những lần chạy này, hãy sử dụng 1000 ví dụ đầu tiên:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
Xác định một mô hình
Bắt đầu bằng cách xây dựng một mô hình tuần tự đơn giản:
# Define a simple sequential model
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 512) 401920 dropout (Dropout) (None, 512) 0 dense_1 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
Lưu các điểm kiểm tra trong quá trình đào tạo
Bạn có thể sử dụng một mô hình đã được đào tạo mà không cần phải đào tạo lại hoặc tiếp tục đào tạo từ nơi bạn đã dừng lại trong trường hợp quá trình đào tạo bị gián đoạn. Lệnh gọi lại tf.keras.callbacks.ModelCheckpoint
cho phép bạn liên tục lưu mô hình cả trong và khi kết thúc đào tạo.
Sử dụng cuộc gọi lại của Checkpoint
Tạo lệnh gọi lại tf.keras.callbacks.ModelCheckpoint
để tiết kiệm trọng lượng chỉ trong quá trình đào tạo:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]) # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10 23/32 [====================>.........] - ETA: 0s - loss: 1.3666 - sparse_categorical_accuracy: 0.6060 Epoch 1: saving model to training_1/cp.ckpt 32/32 [==============================] - 1s 10ms/step - loss: 1.1735 - sparse_categorical_accuracy: 0.6690 - val_loss: 0.7180 - val_sparse_categorical_accuracy: 0.7750 Epoch 2/10 24/32 [=====================>........] - ETA: 0s - loss: 0.4238 - sparse_categorical_accuracy: 0.8789 Epoch 2: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.4201 - sparse_categorical_accuracy: 0.8810 - val_loss: 0.5621 - val_sparse_categorical_accuracy: 0.8150 Epoch 3/10 24/32 [=====================>........] - ETA: 0s - loss: 0.2795 - sparse_categorical_accuracy: 0.9336 Epoch 3: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.2815 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.4790 - val_sparse_categorical_accuracy: 0.8430 Epoch 4/10 24/32 [=====================>........] - ETA: 0s - loss: 0.2027 - sparse_categorical_accuracy: 0.9427 Epoch 4: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.2016 - sparse_categorical_accuracy: 0.9440 - val_loss: 0.4361 - val_sparse_categorical_accuracy: 0.8610 Epoch 5/10 24/32 [=====================>........] - ETA: 0s - loss: 0.1739 - sparse_categorical_accuracy: 0.9583 Epoch 5: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.1683 - sparse_categorical_accuracy: 0.9610 - val_loss: 0.4640 - val_sparse_categorical_accuracy: 0.8580 Epoch 6/10 23/32 [====================>.........] - ETA: 0s - loss: 0.1116 - sparse_categorical_accuracy: 0.9796 Epoch 6: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.1125 - sparse_categorical_accuracy: 0.9780 - val_loss: 0.4420 - val_sparse_categorical_accuracy: 0.8580 Epoch 7/10 24/32 [=====================>........] - ETA: 0s - loss: 0.0978 - sparse_categorical_accuracy: 0.9831 Epoch 7: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0989 - sparse_categorical_accuracy: 0.9820 - val_loss: 0.4163 - val_sparse_categorical_accuracy: 0.8590 Epoch 8/10 21/32 [==================>...........] - ETA: 0s - loss: 0.0669 - sparse_categorical_accuracy: 0.9911 Epoch 8: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 6ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4411 - val_sparse_categorical_accuracy: 0.8600 Epoch 9/10 22/32 [===================>..........] - ETA: 0s - loss: 0.0495 - sparse_categorical_accuracy: 0.9972 Epoch 9: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0516 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.4064 - val_sparse_categorical_accuracy: 0.8650 Epoch 10/10 24/32 [=====================>........] - ETA: 0s - loss: 0.0436 - sparse_categorical_accuracy: 0.9948 Epoch 10: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0437 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4061 - val_sparse_categorical_accuracy: 0.8770 <keras.callbacks.History at 0x7eff8d865390>
Điều này tạo ra một bộ sưu tập các tệp điểm kiểm tra TensorFlow được cập nhật vào cuối mỗi kỷ nguyên:
os.listdir(checkpoint_dir)
['checkpoint', 'cp.ckpt.index', 'cp.ckpt.data-00000-of-00001']
Miễn là hai mô hình có cùng kiến trúc, bạn có thể chia sẻ trọng số giữa chúng. Vì vậy, khi khôi phục một mô hình từ chỉ trọng số, hãy tạo một mô hình có cùng kiến trúc với mô hình ban đầu và sau đó đặt trọng số của nó.
Bây giờ xây dựng lại một mô hình mới, chưa qua đào tạo và đánh giá nó trên bộ thử nghiệm. Một mô hình chưa được đào tạo sẽ hoạt động ở mức cơ hội (độ chính xác ~ 10%):
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.4473 - sparse_categorical_accuracy: 0.0980 - 145ms/epoch - 5ms/step Untrained model, accuracy: 9.80%
Sau đó tải trọng lượng từ điểm kiểm tra và đánh giá lại:
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4061 - sparse_categorical_accuracy: 0.8770 - 65ms/epoch - 2ms/step Restored model, accuracy: 87.70%
Tùy chọn gọi lại điểm kiểm tra
Lệnh gọi lại cung cấp một số tùy chọn để cung cấp tên duy nhất cho các điểm kiểm tra và điều chỉnh tần suất điểm kiểm tra.
Đào tạo một mô hình mới và lưu các điểm kiểm tra được đặt tên duy nhất cứ sau năm kỷ nguyên một lần:
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=5*batch_size)
# Create a new model instance
model = create_model()
# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=50,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt Epoch 10: saving model to training_2/cp-0010.ckpt Epoch 15: saving model to training_2/cp-0015.ckpt Epoch 20: saving model to training_2/cp-0020.ckpt Epoch 25: saving model to training_2/cp-0025.ckpt Epoch 30: saving model to training_2/cp-0030.ckpt Epoch 35: saving model to training_2/cp-0035.ckpt Epoch 40: saving model to training_2/cp-0040.ckpt Epoch 45: saving model to training_2/cp-0045.ckpt Epoch 50: saving model to training_2/cp-0050.ckpt <keras.callbacks.History at 0x7eff807703d0>
Bây giờ, hãy xem các điểm kiểm tra kết quả và chọn điểm kiểm tra mới nhất:
os.listdir(checkpoint_dir)
['cp-0005.ckpt.data-00000-of-00001', 'cp-0050.ckpt.index', 'checkpoint', 'cp-0010.ckpt.index', 'cp-0035.ckpt.data-00000-of-00001', 'cp-0000.ckpt.data-00000-of-00001', 'cp-0050.ckpt.data-00000-of-00001', 'cp-0010.ckpt.data-00000-of-00001', 'cp-0020.ckpt.data-00000-of-00001', 'cp-0035.ckpt.index', 'cp-0040.ckpt.index', 'cp-0025.ckpt.data-00000-of-00001', 'cp-0045.ckpt.index', 'cp-0020.ckpt.index', 'cp-0025.ckpt.index', 'cp-0030.ckpt.data-00000-of-00001', 'cp-0030.ckpt.index', 'cp-0000.ckpt.index', 'cp-0045.ckpt.data-00000-of-00001', 'cp-0015.ckpt.index', 'cp-0015.ckpt.data-00000-of-00001', 'cp-0005.ckpt.index', 'cp-0040.ckpt.data-00000-of-00001']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'
Để kiểm tra, hãy đặt lại mô hình và tải điểm kiểm tra mới nhất:
# Create a new model instance
model = create_model()
# Load the previously saved weights
model.load_weights(latest)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 150ms/epoch - 5ms/step Restored model, accuracy: 87.70%
Những tập tin này là gì?
Đoạn mã trên lưu trữ các trọng số vào một tập hợp các tệp được định dạng điểm kiểm tra chỉ chứa các trọng số được huấn luyện ở định dạng nhị phân. Các trạm kiểm soát bao gồm:
- Một hoặc nhiều phân đoạn chứa trọng lượng mô hình của bạn.
- Một tệp chỉ mục cho biết trọng số nào được lưu trữ trong phân đoạn nào.
Nếu bạn đang đào tạo một mô hình trên một máy duy nhất, bạn sẽ có một phân đoạn với hậu tố: .data-00000-of-00001
Lưu trọng lượng theo cách thủ công
Lưu trọng số theo cách thủ công với phương pháp Model.save_weights
. Theo mặc định, tf.keras
—và cụ thể là save_weights
— sử dụng định dạng điểm kiểm tra TensorFlow với phần mở rộng .ckpt
(lưu ở HDF5 với phần mở rộng .h5
được đề cập trong hướng dẫn mô hình Lưu và tuần tự hóa ):
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 143ms/epoch - 4ms/step Restored model, accuracy: 87.70%
Lưu toàn bộ mô hình
Gọi model.save
để lưu kiến trúc, trọng số và cấu hình đào tạo của mô hình trong một tệp / thư mục duy nhất. Điều này cho phép bạn xuất một mô hình để nó có thể được sử dụng mà không cần truy cập vào mã Python gốc *. Vì trạng thái trình tối ưu hóa được khôi phục, bạn có thể tiếp tục đào tạo từ chính xác nơi bạn đã dừng lại.
Toàn bộ mô hình có thể được lưu ở hai định dạng tệp khác nhau ( SavedModel
và HDF5
). Định dạng TensorFlow SavedModel
là định dạng tệp mặc định trong TF2.x. Tuy nhiên, các mô hình có thể được lưu ở định dạng HDF5
. Chi tiết hơn về cách lưu toàn bộ mô hình ở hai định dạng tệp được mô tả bên dưới.
Lưu một mô hình đầy đủ chức năng là rất hữu ích — bạn có thể tải chúng trong TensorFlow.js ( Mô hình đã lưu , HDF5 ), sau đó đào tạo và chạy chúng trong trình duyệt web hoặc chuyển đổi chúng để chạy trên thiết bị di động bằng TensorFlow Lite ( Mô hình đã lưu , HDF5 )
* Các đối tượng tùy chỉnh (ví dụ như các mô hình hoặc lớp phân lớp) yêu cầu đặc biệt chú ý khi lưu và tải. Xem phần Lưu đối tượng tùy chỉnh bên dưới
Định dạng SavedModel
Định dạng SavedModel là một cách khác để tuần tự hóa các mô hình. Các mô hình đã lưu ở định dạng này có thể được khôi phục bằng tf.keras.models.load_model
và tương thích với TensorFlow Serving. Hướng dẫn SavedModel đi sâu vào chi tiết về cách phân phối / kiểm tra SavedModel. Phần bên dưới minh họa các bước để lưu và khôi phục mô hình.
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1988 - sparse_categorical_accuracy: 0.6550 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4180 - sparse_categorical_accuracy: 0.8930 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2900 - sparse_categorical_accuracy: 0.9220 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2070 - sparse_categorical_accuracy: 0.9540 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1593 - sparse_categorical_accuracy: 0.9630 2022-01-26 07:30:22.888387: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function. WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function. WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate INFO:tensorflow:Assets written to: saved_model/my_model/assets
Định dạng SavedModel là một thư mục chứa tệp nhị phân protobuf và điểm kiểm tra TensorFlow. Kiểm tra thư mục mô hình đã lưu:
# my_model directory
ls saved_model
# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model assets keras_metadata.pb saved_model.pb variables
Tải lại mô hình Keras mới từ mô hình đã lưu:
new_model = tf.keras.models.load_model('saved_model/my_model')
# Check its architecture
new_model.summary()
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_10 (Dense) (None, 512) 401920 dropout_5 (Dropout) (None, 512) 0 dense_11 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
Mô hình khôi phục được biên dịch với các đối số giống như mô hình ban đầu. Thử chạy đánh giá và dự đoán với mô hình đã tải:
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4577 - sparse_categorical_accuracy: 0.8430 - 156ms/epoch - 5ms/step Restored model, accuracy: 84.30% (1000, 10)
Định dạng HDF5
Keras cung cấp một định dạng lưu cơ bản sử dụng tiêu chuẩn HDF5 .
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1383 - sparse_categorical_accuracy: 0.6970 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4094 - sparse_categorical_accuracy: 0.8920 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2936 - sparse_categorical_accuracy: 0.9160 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2050 - sparse_categorical_accuracy: 0.9460 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1485 - sparse_categorical_accuracy: 0.9690
Bây giờ, hãy tạo lại mô hình từ tệp đó:
# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')
# Show the model architecture
new_model.summary()
Model: "sequential_6" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_12 (Dense) (None, 512) 401920 dropout_6 (Dropout) (None, 512) 0 dense_13 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
Kiểm tra độ chính xác của nó:
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4266 - sparse_categorical_accuracy: 0.8620 - 141ms/epoch - 4ms/step Restored model, accuracy: 86.20%
Keras lưu các mô hình bằng cách kiểm tra kiến trúc của chúng. Kỹ thuật này tiết kiệm mọi thứ:
- Các giá trị trọng lượng
- Kiến trúc của mô hình
- Cấu hình đào tạo của mô hình (những gì bạn truyền cho phương thức
.compile()
) - Trình tối ưu hóa và trạng thái của nó, nếu có (điều này cho phép bạn bắt đầu lại quá trình đào tạo từ nơi bạn đã dừng lại)
Keras không thể lưu các trình tối ưu hóa v1.x
(từ tf.compat.v1.train
) vì chúng không tương thích với các trạm kiểm soát. Đối với trình tối ưu hóa v1.x, bạn cần phải biên dịch lại mô hình sau khi tải — làm mất trạng thái của trình tối ưu hóa.
Lưu các đối tượng tùy chỉnh
Nếu bạn đang sử dụng định dạng SavedModel, bạn có thể bỏ qua phần này. Sự khác biệt chính giữa HDF5 và SavedModel là HDF5 sử dụng cấu hình đối tượng để lưu kiến trúc mô hình, trong khi SavedModel lưu đồ thị thực thi. Do đó, SavedModels có thể lưu các đối tượng tùy chỉnh như mô hình lớp con và lớp tùy chỉnh mà không yêu cầu mã gốc.
Để lưu các đối tượng tùy chỉnh vào HDF5, bạn phải làm như sau:
- Xác định phương thức
get_config
trong đối tượng của bạn và tùy chọn mộtfrom_config
.-
get_config(self)
trả về một từ điển JSON-serializable các tham số cần thiết để tạo lại đối tượng. -
from_config(cls, config)
sử dụng cấu hình trả về từget_config
để tạo một đối tượng mới. Theo mặc định, hàm này sẽ sử dụng cấu hình làm kwargs khởi tạo (return cls(**config)
).
-
- Chuyển đối tượng đến đối số
custom_objects
khi tải mô hình. Đối số phải là một từ điển ánh xạ tên lớp chuỗi với lớp Python. Ví dụ:tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
Xem hướng dẫn Viết các lớp và mô hình từ đầu để biết các ví dụ về các đối tượng tùy chỉnh và get_config
.
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.