Migrate checkpoint saving


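Continually saving the "best" model or model weights/parameters has many benefits. These include being able to track training progress and to load saved models from different saved states.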

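In TensorFlow 1, to configure checkpoint saving during training/validation with the tf.estimator.Estimator APIs, you specify a schedule in tf.estimator.RunConfig or use tf.estimator.CheckpointSaverHook. This guide demonstrates how to migrate from this workflow to the TensorFlow 2 Keras APIs.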

In TensorFlow 2, you can configure tf.keras.callbacks.ModelCheckpoint in a number of ways:

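  • Save the "best" version according to a metric monitored using the save_best_only=True parameter, where monitor can be, for example, 'loss', 'val_loss', 'accuracy', or 'val_accuracy'.
  • Save continually at a certain frequency (using the save_freq argument).
  • Save only the weights/parameters instead of the whole model by setting save_weights_only to True.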
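With save_best_only=True, a checkpoint is written only when the monitored metric improves on the best value seen so far. The sketch below is a simplified pure-Python model of that rule, not the Keras source. It assumes the comparison direction is inferred roughly the way mode='auto' does it (a metric name containing "acc" is maximized, anything else is minimized), and it replays the per-epoch validation metrics printed by the Keras run later in this guide. The function names are hypothetical.

```python
# Simplified model (an assumption, not Keras source) of how
# ModelCheckpoint(save_best_only=True, mode="auto") decides whether to save.
import math


def infer_mode(monitor):
    """Guess the comparison direction roughly the way mode='auto' does."""
    return "max" if "acc" in monitor else "min"


def epochs_that_save(monitor, history):
    """Return the 1-based epochs at which a 'best only' checkpoint is written."""
    mode = infer_mode(monitor)
    best = -math.inf if mode == "max" else math.inf
    saved = []
    for epoch, value in enumerate(history, start=1):
        improved = value > best if mode == "max" else value < best
        if improved:  # strictly better than every earlier epoch
            best = value
            saved.append(epoch)
    return saved


# Per-epoch validation metrics copied from the Keras training log in this guide.
val_loss = [0.1039, 0.0815, 0.0832, 0.0610, 0.0664,
            0.0718, 0.0668, 0.0778, 0.0734, 0.0653]
val_accuracy = [0.9670, 0.9761, 0.9730, 0.9807, 0.9798,
                0.9788, 0.9814, 0.9790, 0.9815, 0.9846]

print(epochs_that_save("val_loss", val_loss))          # [1, 2, 4]
print(epochs_that_save("val_accuracy", val_accuracy))  # [1, 2, 4, 7, 9, 10]
```

Under these assumptions, monitoring 'val_loss' saves only at epochs 1, 2, and 4, because no later epoch beats epoch 4's 0.0610; monitoring 'val_accuracy' would also save at epochs 7, 9, and 10.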

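For more details, refer to the tf.keras.callbacks.ModelCheckpoint API docs and the Save checkpoints during training section in the Save and load models tutorial. Learn more about checkpoint formats in the TF Checkpoint format section of the Save and load Keras models guide. In addition, to add fault tolerance, you can use tf.keras.callbacks.BackupAndRestore or tf.train.Checkpoint for manual checkpointing. Learn more in the Fault tolerance migration guide.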

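Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. Learn more in the Next steps section at the end of the guide.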
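To make the callback idea concrete, here is a hypothetical toy loop that calls an on_epoch_end hook after every epoch. That is the kind of point at which ModelCheckpoint, with its default save_freq='epoch', writes a checkpoint. The Callback, RecordingCheckpoint, and toy_fit names are illustrative stand-ins, not Keras classes; real callbacks subclass tf.keras.callbacks.Callback and receive many more hooks.

```python
# A toy training loop (hypothetical, not the Keras implementation) showing how
# a fit()-style loop invokes callback hooks at fixed points.


class Callback:
    """Hypothetical stand-in for tf.keras.callbacks.Callback with two hooks."""

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass


class RecordingCheckpoint(Callback):
    """Records (1-based epoch, loss) where a real callback would save a checkpoint."""

    def __init__(self):
        self.saved = []

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index to the hooks.
        self.saved.append((epoch + 1, logs["loss"]))


def toy_fit(losses, callbacks):
    """Pretend to train one epoch per entry in `losses`, invoking the hooks."""
    for epoch, loss in enumerate(losses):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        logs = {"loss": loss}  # a real fit() computes this from training
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)


# Training losses from the first three epochs of the Keras log in this guide.
ckpt = RecordingCheckpoint()
toy_fit([0.2189, 0.0955, 0.0702], callbacks=[ckpt])
print(ckpt.saved)
```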

Setup

Start with imports and a simple dataset for demonstration purposes:

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

TensorFlow 1: Save checkpoints with tf.estimator APIs

This TensorFlow 1 example shows how to configure tf.estimator.RunConfig to save checkpoints at every step during training/evaluation with the tf.estimator.Estimator APIs:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)

path = tempfile.mkdtemp()

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)

test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)

train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)

tf1.estimator.train_and_evaluate(estimator=classifier,
                                train_spec=train_spec,
                                eval_spec=eval_spec)
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp9t2mhb2d', '_tf_random_seed': None, '_save_summary_steps': 1, '_save_checkpoints_steps': 1, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_44833/3980459272.py:18: The name tf.estimator.inputs is deprecated. Please use tf.compat.v1.estimator.inputs instead.

WARNING:tensorflow:From /tmpfs/tmp/ipykernel_44833/3980459272.py:18: The name tf.estimator.inputs.numpy_input_fn is deprecated. Please use tf.compat.v1.estimator.inputs.numpy_input_fn instead.

INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 1 or save_checkpoints_secs None.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:910: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:42:56
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-1
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.29589s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:42:56
INFO:tensorflow:Saving dict for global step 1: accuracy = 0.14765625, average_loss = 2.3110993, global_step = 1, loss = 295.8207
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 1: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-1
INFO:tensorflow:loss = 118.449585, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:42:57
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-2
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.29818s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:42:57
INFO:tensorflow:Saving dict for global step 2: accuracy = 0.1921875, average_loss = 2.248653, global_step = 2, loss = 287.82758
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-2
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:42:57
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.29778s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:42:58
INFO:tensorflow:Saving dict for global step 3: accuracy = 0.234375, average_loss = 2.202866, global_step = 3, loss = 281.96686
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-3
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4...
INFO:tensorflow:Saving checkpoints for 4 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:42:58
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-4
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.30090s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:42:58
INFO:tensorflow:Saving dict for global step 4: accuracy = 0.31640625, average_loss = 2.162916, global_step = 4, loss = 276.85324
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 4: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-4
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5...
INFO:tensorflow:Saving checkpoints for 5 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/saver.py:1064: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:42:59
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-5
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.29170s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:42:59
INFO:tensorflow:Saving dict for global step 5: accuracy = 0.37578124, average_loss = 2.1222653, global_step = 5, loss = 271.64996
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-5
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6...
INFO:tensorflow:Saving checkpoints for 6 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:42:59
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-6
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.29652s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:42:59
INFO:tensorflow:Saving dict for global step 6: accuracy = 0.42890626, average_loss = 2.0763998, global_step = 6, loss = 265.77917
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-6
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7...
INFO:tensorflow:Saving checkpoints for 7 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:43:00
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-7
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.29004s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:43:00
INFO:tensorflow:Saving dict for global step 7: accuracy = 0.45234376, average_loss = 2.0237553, global_step = 7, loss = 259.04068
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 7: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-7
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8...
INFO:tensorflow:Saving checkpoints for 8 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:43:00
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-8
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.29159s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:43:01
INFO:tensorflow:Saving dict for global step 8: accuracy = 0.4609375, average_loss = 1.9725058, global_step = 8, loss = 252.48074
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-8
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9...
INFO:tensorflow:Saving checkpoints for 9 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:43:01
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-9
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.30247s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:43:01
INFO:tensorflow:Saving dict for global step 9: accuracy = 0.46171874, average_loss = 1.9281521, global_step = 9, loss = 246.80347
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 9: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-9
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmp9t2mhb2d/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T20:43:02
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.28789s
INFO:tensorflow:Finished evaluation at 2022-12-14-20:43:02
INFO:tensorflow:Saving dict for global step 10: accuracy = 0.4625, average_loss = 1.8868959, global_step = 10, loss = 241.52267
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmpfs/tmp/tmp9t2mhb2d/model.ckpt-10
INFO:tensorflow:Loss for final step: 94.11531.
({'accuracy': 0.4625,
  'average_loss': 1.8868959,
  'loss': 241.52267,
  'global_step': 10},
 [])
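The evaluation log reports both average_loss and loss. The code does not set batch_size for test_input_fn, so assuming it uses the numpy_input_fn default of 128, and assuming the TF1 DNNClassifier sums its loss over each batch, loss should be about 128 times average_loss. A quick arithmetic check against values copied from the log above:

```python
# Arithmetic check. Assumptions: test_input_fn uses numpy_input_fn's default
# batch_size=128, and the reported `loss` is summed over each batch.
EVAL_BATCH_SIZE = 128  # assumed default of tf1.estimator.inputs.numpy_input_fn

# (global_step, average_loss, loss) triples copied from the evaluation log above.
eval_results = [
    (1, 2.3110993, 295.8207),
    (5, 2.1222653, 271.64996),
    (10, 1.8868959, 241.52267),
]

for step, average_loss, loss in eval_results:
    implied = average_loss * EVAL_BATCH_SIZE
    print(f"step {step}: {average_loss} * {EVAL_BATCH_SIZE} = {implied:.4f} (logged {loss})")
```

The products agree with the logged loss values to within rounding, which is consistent with both assumptions.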
%ls {classifier.model_dir}
checkpoint
eval/
events.out.tfevents.1671050575.kokoro-gcp-ubuntu-prod-129375217
graph.pbtxt
model.ckpt-10.data-00000-of-00001
model.ckpt-10.index
model.ckpt-10.meta
model.ckpt-6.data-00000-of-00001
model.ckpt-6.index
model.ckpt-6.meta
model.ckpt-7.data-00000-of-00001
model.ckpt-7.index
model.ckpt-7.meta
model.ckpt-8.data-00000-of-00001
model.ckpt-8.index
model.ckpt-8.meta
model.ckpt-9.data-00000-of-00001
model.ckpt-9.index
model.ckpt-9.meta
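Although the log shows a checkpoint saved at every step from 0 to 10, the directory holds only model.ckpt-6 through model.ckpt-10. The RunConfig printed above contains '_keep_checkpoint_max': 5, so older checkpoints are deleted as new ones arrive (the remove_checkpoint deprecation warning first appears when checkpoint 5 is saved and checkpoint 0 is removed). A minimal sketch of that retention rule, as a simplification rather than the tf.compat.v1.train.Saver implementation; retained_checkpoints is a hypothetical helper:

```python
# Minimal sketch (assumption: only the newest `keep_checkpoint_max`
# checkpoints are kept, and the oldest one is deleted first).
from collections import deque


def retained_checkpoints(saved_steps, keep_checkpoint_max):
    """Return the checkpoint prefixes left on disk after all saves."""
    kept = deque()
    for step in saved_steps:
        kept.append(f"model.ckpt-{step}")
        if len(kept) > keep_checkpoint_max:
            kept.popleft()  # drop the oldest checkpoint
    return list(kept)


# save_checkpoints_steps=1 with max_steps=10 saved steps 0 through 10 in the log.
print(retained_checkpoints(range(0, 11), keep_checkpoint_max=5))
```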

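TensorFlow 2: Save checkpoints with a Keras callback for Model.fit

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tf.keras.callbacks.ModelCheckpoint and then pass it to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the API docs and in the Using callbacks section of the Training and evaluation with the built-in methods guide.)

In the example below, you will use a tf.keras.callbacks.ModelCheckpoint callback to store checkpoints in a temporary directory: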

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])

model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=10)

log_dir = tempfile.mkdtemp()

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=log_dir)

model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])
Epoch 1/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.2198 - accuracy: 0.9358INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 5s 2ms/step - loss: 0.2189 - accuracy: 0.9360 - val_loss: 0.1039 - val_accuracy: 0.9670
Epoch 2/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0959 - accuracy: 0.9707INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0955 - accuracy: 0.9708 - val_loss: 0.0815 - val_accuracy: 0.9761
Epoch 3/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0704 - accuracy: 0.9782INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0702 - accuracy: 0.9783 - val_loss: 0.0832 - val_accuracy: 0.9730
Epoch 4/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0516 - accuracy: 0.9831INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0517 - accuracy: 0.9831 - val_loss: 0.0610 - val_accuracy: 0.9807
Epoch 5/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0437 - accuracy: 0.9854INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0437 - accuracy: 0.9854 - val_loss: 0.0664 - val_accuracy: 0.9798
Epoch 6/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0354 - accuracy: 0.9884INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0355 - accuracy: 0.9884 - val_loss: 0.0718 - val_accuracy: 0.9788
Epoch 7/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0306 - accuracy: 0.9897INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0305 - accuracy: 0.9897 - val_loss: 0.0668 - val_accuracy: 0.9814
Epoch 8/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0282 - accuracy: 0.9905INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0282 - accuracy: 0.9905 - val_loss: 0.0778 - val_accuracy: 0.9790
Epoch 9/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0246 - accuracy: 0.9916INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0246 - accuracy: 0.9916 - val_loss: 0.0734 - val_accuracy: 0.9815
Epoch 10/10
1860/1875 [============================>.] - ETA: 0s - loss: 0.0230 - accuracy: 0.9920INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpu4b5x9h0/assets
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0230 - accuracy: 0.9920 - val_loss: 0.0653 - val_accuracy: 0.9846
<keras.callbacks.History at 0x7f0430034f70>
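The progress bar counts 1875 steps per epoch. The Model.fit call above passes no batch_size, so Keras uses its default of 32, and the 60,000 MNIST training images split into 60000 / 32 = 1875 batches. For comparison, the Estimator run stopped after max_steps=10 batches of batch_size=50, i.e. only 500 examples, which largely accounts for its much lower accuracy:

```python
import math

num_train_examples = 60_000  # MNIST training split size, len(x_train)
default_batch_size = 32      # Model.fit's batch_size when none is passed

keras_steps_per_epoch = math.ceil(num_train_examples / default_batch_size)

# The TF1 Estimator run: TrainSpec(max_steps=10) with numpy_input_fn(batch_size=50).
estimator_examples_seen = 10 * 50

print(keras_steps_per_epoch)    # 1875
print(estimator_examples_seen)  # 500
```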
%ls {model_checkpoint_callback.filepath}
assets/  fingerprint.pb  keras_metadata.pb  saved_model.pb  variables/
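Because filepath=log_dir contains no placeholders, every epoch overwrites the same SavedModel, which is why the listing shows a single saved_model.pb. ModelCheckpoint also accepts a filepath template with named formatting options such as {epoch:02d} and {val_loss:.2f}, producing one file per save. The sketch below assumes the callback fills the template with Python's str.format, using the 1-based epoch number plus that epoch's logs; the template itself is a hypothetical example:

```python
# Sketch (assumption: ModelCheckpoint formats `filepath` with str.format using
# epoch=<1-based epoch> and the epoch's logs). The template is hypothetical.
template = "ckpt/weights.{epoch:02d}-{val_loss:.2f}.hdf5"

# Epoch 4 metrics copied from the Keras training log above.
logs = {"loss": 0.0517, "accuracy": 0.9831, "val_loss": 0.0610, "val_accuracy": 0.9807}

print(template.format(epoch=4, **logs))  # ckpt/weights.04-0.06.hdf5
```

Combining such a template with save_best_only=False keeps every epoch's checkpoint instead of only the latest one.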

Next steps

Learn more about checkpointing in:

Learn more about callbacks in:

You may also find the following migration-related resources useful: