Continually saving the "best" model or its weights/parameters has several benefits, including the ability to track training progress and to load saved models from different saved states.
In TensorFlow 1, to configure checkpoint saving during training/validation with the tf.estimator.Estimator APIs, you specify a schedule in tf.estimator.RunConfig or use tf.estimator.CheckpointSaverHook. This guide demonstrates how to migrate from this workflow to the TensorFlow 2 Keras APIs.
In TensorFlow 2, you can configure tf.keras.callbacks.ModelCheckpoint in a number of ways:
- Save the "best" version according to a monitored metric by setting save_best_only=True, where monitor can be, for example, 'loss', 'val_loss', 'accuracy', or 'val_accuracy'.
- Save continually at a certain frequency (using the save_freq argument).
- Save only the weights/parameters instead of the whole model by setting save_weights_only to True.
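The options above can be sketched on a tiny synthetic dataset. This is a minimal illustration, not part of the original guide: the random data, the small model, and the file names best.keras and latest.weights.h5 are assumptions chosen for the example, and the file extensions that ModelCheckpoint accepts can differ between Keras versions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Tiny synthetic dataset (an assumption for this sketch, not the MNIST data).
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.integers(0, 2, size=(64,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

ckpt_dir = tempfile.mkdtemp()

# Keep only the best whole model, judged by the monitored validation loss.
best_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "best.keras"),
    monitor="val_loss",
    save_best_only=True)

# Save only the weights, every 2 batches (an integer save_freq counts batches).
weights_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "latest.weights.h5"),
    save_weights_only=True,
    save_freq=2)

model.fit(x, y,
          epochs=2,
          batch_size=16,
          validation_data=(x, y),
          callbacks=[best_cb, weights_cb],
          verbose=0)

saved = sorted(os.listdir(ckpt_dir))
print(saved)
```

Both callbacks can be passed to the same Model.fit call, so a "best so far" model and a frequently refreshed weights file can be kept side by side.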
For more details, refer to the tf.keras.callbacks.ModelCheckpoint API docs and the Save checkpoints during training section in the Save and load models tutorial. Learn more about the checkpoint format in the TF Checkpoint format section in the Save and load Keras models guide. In addition, to add fault tolerance, you can use tf.keras.callbacks.BackupAndRestore or tf.train.Checkpoint for manual checkpointing. Learn more in the Fault tolerance migration guide.
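As a rough sketch of manual checkpointing, the snippet below pairs tf.train.Checkpoint with tf.train.CheckpointManager. The two variables stand in for real model state, and the loop, the values, and max_to_keep=3 are assumptions made for illustration only.

```python
import tempfile

import tensorflow as tf

# Stand-ins for real training state (assumptions for this sketch).
step = tf.Variable(0, dtype=tf.int64)
weight = tf.Variable([1.0, 2.0])

checkpoint = tf.train.Checkpoint(step=step, weight=weight)
manager = tf.train.CheckpointManager(
    checkpoint, directory=tempfile.mkdtemp(), max_to_keep=3)

# Pretend to train for 5 steps, saving a checkpoint after each one.
for _ in range(5):
    weight.assign_add([0.5, 0.5])
    step.assign_add(1)
    manager.save(checkpoint_number=int(step.numpy()))

# Overwrite the state, then restore it from the most recent checkpoint.
weight.assign([0.0, 0.0])
checkpoint.restore(manager.latest_checkpoint)

restored_weight = weight.numpy().tolist()
kept_checkpoints = len(manager.checkpoints)
print(restored_weight, kept_checkpoints)
```

CheckpointManager deletes older checkpoints once more than max_to_keep exist, so only the three most recent saves remain on disk here.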
Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit / Model.evaluate / Model.predict APIs. Learn more in the Next steps section at the end of the guide.
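To make "called at different points" concrete, here is a minimal sketch of a custom callback that records which hooks fire. The EventRecorder class, the synthetic data, and the tiny model are hypothetical names and values introduced for this example.

```python
import numpy as np
import tensorflow as tf


class EventRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that records the hooks Keras invokes."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end:{epoch}")

    def on_test_begin(self, logs=None):
        self.events.append("test_begin")

    def on_predict_end(self, logs=None):
        self.events.append("predict_end")


# Synthetic data and a tiny model (assumptions for this sketch).
rng = np.random.default_rng(0)
x = rng.random((16, 4)).astype("float32")
y = rng.integers(0, 2, size=(16,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

recorder = EventRecorder()
model.fit(x, y, epochs=2, callbacks=[recorder], verbose=0)
model.evaluate(x, y, callbacks=[recorder], verbose=0)
model.predict(x, callbacks=[recorder], verbose=0)

print(recorder.events)
```

ModelCheckpoint plugs into the same hooks: it saves from the epoch-end or batch-end hook, depending on save_freq.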
Setup
Start with imports and a simple dataset for demonstration purposes:
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step 11501568/11490434 [==============================] - 0s 0us/step
TensorFlow 1: Save checkpoints with tf.estimator APIs
This TensorFlow 1 example shows how to configure tf.estimator.RunConfig to save checkpoints at every step during training/evaluation with the tf.estimator.Estimator APIs:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

# Save a summary and a checkpoint at every training step.
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)

test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)

train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)

tf1.estimator.train_and_evaluate(estimator=classifier,
                                 train_spec=train_spec,
                                 eval_spec=eval_spec)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmplrkjo9in', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmp/ipykernel_20296/3980459272.py:18: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead. WARNING:tensorflow:From /tmp/ipykernel_20296/3980459272.py:18: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Running training and evaluation locally (non-distributed). INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 1 or save_checkpoints_secs None. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:65: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/monitored_session.py:914: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version. Instructions for updating: To construct input pipelines, use the `tf.data` module. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:47 INFO:tensorflow:Graph was finalized. 
INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-1 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.26374s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:47 INFO:tensorflow:Saving dict for global step 1: accuracy = 0.1765625, average_loss = 2.2546134, global_step = 1, loss = 288.5905 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1: /tmp/tmplrkjo9in/model.ckpt-1 INFO:tensorflow:loss = 118.3231, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:48 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-2 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.36662s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:48 INFO:tensorflow:Saving dict for global step 2: accuracy = 0.2859375, average_loss = 2.1868849, global_step = 2, loss = 279.92126 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2: /tmp/tmplrkjo9in/model.ckpt-2 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:48 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-3 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22792s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:48 INFO:tensorflow:Saving dict for global step 3: accuracy = 0.35078126, average_loss = 2.1220195, global_step = 3, loss = 271.6185 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmp/tmplrkjo9in/model.ckpt-3 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4... 
INFO:tensorflow:Saving checkpoints for 4 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:49 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-4 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22387s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:49 INFO:tensorflow:Saving dict for global step 4: accuracy = 0.40234375, average_loss = 2.0655982, global_step = 4, loss = 264.39658 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 4: /tmp/tmplrkjo9in/model.ckpt-4 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5... INFO:tensorflow:Saving checkpoints for 5 into /tmp/tmplrkjo9in/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1054: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:49 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-5 INFO:tensorflow:Running local_init_op. 
INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22548s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:49 INFO:tensorflow:Saving dict for global step 5: accuracy = 0.42421874, average_loss = 2.0072064, global_step = 5, loss = 256.92242 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5: /tmp/tmplrkjo9in/model.ckpt-5 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6... INFO:tensorflow:Saving checkpoints for 6 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:50 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-6 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22806s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:50 INFO:tensorflow:Saving dict for global step 6: accuracy = 0.43984374, average_loss = 1.9473753, global_step = 6, loss = 249.26404 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6: /tmp/tmplrkjo9in/model.ckpt-6 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7... INFO:tensorflow:Saving checkpoints for 7 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:50 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-7 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.23091s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:50 INFO:tensorflow:Saving dict for global step 7: accuracy = 0.44296876, average_loss = 1.8903366, global_step = 7, loss = 241.96309 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 7: /tmp/tmplrkjo9in/model.ckpt-7 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8... 
INFO:tensorflow:Saving checkpoints for 8 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:51 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-8 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22453s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:51 INFO:tensorflow:Saving dict for global step 8: accuracy = 0.44453126, average_loss = 1.8294731, global_step = 8, loss = 234.17256 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8: /tmp/tmplrkjo9in/model.ckpt-8 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9... INFO:tensorflow:Saving checkpoints for 9 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:51 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-9 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.22271s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:51 INFO:tensorflow:Saving dict for global step 9: accuracy = 0.47734374, average_loss = 1.7674354, global_step = 9, loss = 226.23174 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9: /tmp/tmplrkjo9in/model.ckpt-9 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmplrkjo9in/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-14T02:28:52 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmplrkjo9in/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.38483s INFO:tensorflow:Finished evaluation at 2022-01-14-02:28:52 INFO:tensorflow:Saving dict for global step 10: accuracy = 0.5140625, average_loss = 1.7108486, global_step = 10, loss = 218.98862 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmplrkjo9in/model.ckpt-10 INFO:tensorflow:Loss for final step: 96.2236. 
({'accuracy': 0.5140625, 'average_loss': 1.7108486, 'loss': 218.98862, 'global_step': 10}, [])
%ls {classifier.model_dir}
checkpoint eval/ events.out.tfevents.1642127326.kokoro-gcp-ubuntu-prod-837339153 graph.pbtxt model.ckpt-10.data-00000-of-00001 model.ckpt-10.index model.ckpt-10.meta model.ckpt-6.data-00000-of-00001 model.ckpt-6.index model.ckpt-6.meta model.ckpt-7.data-00000-of-00001 model.ckpt-7.index model.ckpt-7.meta model.ckpt-8.data-00000-of-00001 model.ckpt-8.index model.ckpt-8.meta model.ckpt-9.data-00000-of-00001 model.ckpt-9.index model.ckpt-9.meta
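The listing above contains only model.ckpt-6 through model.ckpt-10, even though the log shows checkpoints being written for steps 0 through 10. This matches the '_keep_checkpoint_max': 5 entry in the logged configuration. The pure-Python sketch below imitates that retention rule with a bounded queue; it is an illustration of the behavior, not the Estimator's actual implementation.

```python
from collections import deque

# Value taken from the logged RunConfig ('_keep_checkpoint_max': 5).
keep_checkpoint_max = 5

# A bounded queue drops the oldest entry once it is full.
kept = deque(maxlen=keep_checkpoint_max)

# The log shows checkpoints saved for steps 0 through 10.
for step in range(0, 11):
    kept.append(f"model.ckpt-{step}")

remaining = list(kept)
print(remaining)
```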
TensorFlow 2: Save checkpoints with a Keras callback for Model.fit
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tf.keras.callbacks.ModelCheckpoint and then pass it to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the API docs and the Using callbacks section in the Training and evaluation with the built-in methods guide.)
In the example below, you will use a tf.keras.callbacks.ModelCheckpoint callback to store checkpoints in a temporary directory:
def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])

model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

# With default settings, the callback saves the whole model at the end of every epoch.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=log_dir)

model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])
Epoch 1/10 1840/1875 [============================>.] - ETA: 0s - loss: 0.2224 - accuracy: 0.9348 2022-01-14 02:28:56.714889: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 4s 2ms/step - loss: 0.2208 - accuracy: 0.9354 - val_loss: 0.1132 - val_accuracy: 0.9669 Epoch 2/10 1870/1875 [============================>.] - ETA: 0s - loss: 0.0961 - accuracy: 0.9706INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0962 - accuracy: 0.9706 - val_loss: 0.0784 - val_accuracy: 0.9753 Epoch 3/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0696 - accuracy: 0.9781INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 2ms/step - loss: 0.0695 - accuracy: 0.9782 - val_loss: 0.0684 - val_accuracy: 0.9788 Epoch 4/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0529 - accuracy: 0.9826INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0531 - accuracy: 0.9826 - val_loss: 0.0671 - val_accuracy: 0.9791 Epoch 5/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0423 - accuracy: 0.9860INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0424 - accuracy: 0.9860 - val_loss: 0.0772 - val_accuracy: 0.9757 Epoch 6/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0345 - accuracy: 0.9888INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0345 - accuracy: 0.9888 - val_loss: 0.0669 - val_accuracy: 0.9811 Epoch 7/10 1860/1875 [============================>.] 
- ETA: 0s - loss: 0.0314 - accuracy: 0.9895INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0313 - accuracy: 0.9895 - val_loss: 0.0718 - val_accuracy: 0.9800 Epoch 8/10 1870/1875 [============================>.] - ETA: 0s - loss: 0.0298 - accuracy: 0.9899INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0298 - accuracy: 0.9899 - val_loss: 0.0632 - val_accuracy: 0.9825 Epoch 9/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0230 - accuracy: 0.9925INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0231 - accuracy: 0.9924 - val_loss: 0.0748 - val_accuracy: 0.9800 Epoch 10/10 1860/1875 [============================>.] - ETA: 0s - loss: 0.0220 - accuracy: 0.9920INFO:tensorflow:Assets written to: /tmp/tmpb85suru4/assets 1875/1875 [==============================] - 3s 1ms/step - loss: 0.0222 - accuracy: 0.9920 - val_loss: 0.0703 - val_accuracy: 0.9825 <keras.callbacks.History at 0x7f638c204410>
%ls {model_checkpoint_callback.filepath}
assets/ keras_metadata.pb saved_model.pb variables/
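Once a checkpoint exists, it can be loaded back and compared against the live model. The sketch below does this on synthetic data. It is an illustration under stated assumptions: the data, the small model, and the model.keras file name are not from the example above, which saves to a directory instead, and the accepted file formats can differ between Keras versions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Synthetic data and a tiny model (assumptions for this sketch).
rng = np.random.default_rng(0)
x = rng.random((32, 4)).astype("float32")
y = rng.integers(0, 3, size=(32,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Save the whole model at the end of every epoch (the default save_freq).
checkpoint_path = os.path.join(tempfile.mkdtemp(), "model.keras")
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path)
model.fit(x, y, epochs=2, batch_size=8, callbacks=[checkpoint_cb], verbose=0)

# Load the last saved model and compare its predictions with the live model.
restored = tf.keras.models.load_model(checkpoint_path)
original_pred = model.predict(x, verbose=0)
restored_pred = restored.predict(x, verbose=0)

predictions_match = bool(np.allclose(original_pred, restored_pred, atol=1e-6))
print(predictions_match)
```

Because the last save happens at the end of the final epoch, the restored model should reproduce the trained model's predictions.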
Next steps
Learn more about checkpointing in:
- API docs: tf.keras.callbacks.ModelCheckpoint
- Tutorial: Save and load models (the Save checkpoints during training section)
- Guide: Save and load Keras models (the TF Checkpoint format section)
Learn more about callbacks in:
- API docs: tf.keras.callbacks.Callback
- Guide: Writing your own callbacks
- Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
- The Fault tolerance migration guide: tf.keras.callbacks.BackupAndRestore for Model.fit, or the tf.train.Checkpoint and tf.train.CheckpointManager APIs for a custom training loop
- The Early stopping migration guide: tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
- The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
- The LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide
- The SessionRunHook to Keras callbacks guide