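Once you have migrated your model from TensorFlow 1 graphs and sessions to TensorFlow 2 APIs such as tf.function, tf.Module, and tf.keras.Model, you can migrate the model saving and loading code. This notebook provides examples of how to save and load in the SavedModel format in TensorFlow 1 and TensorFlow 2. Here is a quick overview of the related API changes for migration from TensorFlow 1 to TensorFlow 2: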
| | TensorFlow 1 | Migration to TensorFlow 2 |
|---|---|---|
| Saving | tf.compat.v1.saved_model.Builder, tf.compat.v1.saved_model.simple_save | tf.saved_model.save; Keras: tf.keras.models.save_model |
| Loading | tf.compat.v1.saved_model.load | tf.saved_model.load; Keras: tf.keras.models.load_model |
| Signatures: a set of input and output tensors that can be used to run the model | Generated with the *.signature_def utility functions (for example, tf.compat.v1.saved_model.predict_signature_def) | Write a tf.function and export it using the signatures argument of tf.saved_model.save. |
| Classification and regression: special types of signatures | Generated with tf.compat.v1.saved_model.classification_signature_def, tf.compat.v1.saved_model.regression_signature_def, and certain Estimator exports. | These two signature types have been removed from TensorFlow 2. If the serving library requires these method names, use tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater. |
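For a more in-depth explanation of the mapping, refer to the Changes from TensorFlow 1 to TensorFlow 2 section below.
Setup
The examples below show how to export and load the same dummy TensorFlow model (defined as add_two below) in the SavedModel format using the TensorFlow 1 and TensorFlow 2 APIs. Start by setting up the imports and utility functions: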
import shutil

import tensorflow as tf
import tensorflow.compat.v1 as tf1

def remove_dir(path):
  # Delete the directory tree; a missing directory is silently ignored.
  shutil.rmtree(path, ignore_errors=True)

def add_two(input):
  return input + 2
TensorFlow 1: Save and export a SavedModel
In TensorFlow 1, you use the tf.compat.v1.saved_model.Builder, tf.compat.v1.saved_model.simple_save, and tf.estimator.Estimator.export_saved_model APIs to build, save, and export the TensorFlow graph and session:
1. Save the graph as a SavedModel with SavedModelBuilder
remove_dir("saved-model-builder")

with tf.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=[])
    output = add_two(input)
    print("add two output: ", sess.run(output, {input: 3.}))

    # Save with SavedModelBuilder
    builder = tf1.saved_model.Builder('saved-model-builder')
    sig_def = tf1.saved_model.predict_signature_def(
        inputs={'input': input},
        outputs={'output': output})
    builder.add_meta_graph_and_variables(
        sess, tags=["serve"], signature_def_map={
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
        })
    builder.save()
add two output: 5.0
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: saved-model-builder/saved_model.pb
!saved_model_cli run --dir saved-model-builder --tag_set serve \
--signature_def serving_default --input_exprs input=10
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
Result for output key output: 12.0
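The `!saved_model_cli run ...` lines in this notebook use notebook shell syntax, so they will not run in a plain Python script. As a rough sketch, the same command line can be assembled with the standard library. The helper name `saved_model_cli_run_args` is hypothetical, and only the flags that appear in this notebook are used:

```python
import shlex


def saved_model_cli_run_args(model_dir, input_exprs,
                             tag_set="serve",
                             signature_def="serving_default"):
    """Build the argv list for `saved_model_cli run` (hypothetical helper)."""
    return [
        "saved_model_cli", "run",
        "--dir", model_dir,
        "--tag_set", tag_set,
        "--signature_def", signature_def,
        "--input_exprs", input_exprs,
    ]


args = saved_model_cli_run_args("saved-model-builder", "input=10")
# Print a shell-safe version of the command for inspection.
print(shlex.join(args))
```

To actually execute the command, pass `args` to `subprocess.run(args, check=True)`; this assumes `saved_model_cli` is on the PATH, which is normally the case once TensorFlow is installed.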
2. Build a SavedModel for serving
remove_dir("simple-save")

with tf.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=[])
    output = add_two(input)
    print("add_two output: ", sess.run(output, {input: 3.}))

    tf1.saved_model.simple_save(
        sess, 'simple-save',
        inputs={'input': input},
        outputs={'output': output})
add_two output: 5.0
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: simple-save/saved_model.pb
!saved_model_cli run --dir simple-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
Result for output key output: 12.0
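3. Export the Estimator inference graph as a SavedModel
In the definition of the Estimator model_fn (defined below), you can define signatures in your model by returning export_outputs in the tf.estimator.EstimatorSpec. There are different types of outputs:
tf.estimator.export.ClassificationOutput
tf.estimator.export.RegressionOutput
tf.estimator.export.PredictOutput
These produce classification, regression, and prediction signature types, respectively.
When the Estimator is exported with tf.estimator.Estimator.export_saved_model, these signatures are saved with the model.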
def model_fn(features, labels, mode):
  output = add_two(features['input'])
  step = tf1.train.get_global_step()
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=output,
      train_op=step.assign_add(1),
      loss=tf.constant(0.),
      export_outputs={
          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              tf.estimator.export.PredictOutput({'output': output})})

est = tf.estimator.Estimator(model_fn, 'estimator-checkpoints')

# Train for one step to create a checkpoint.
def train_fn():
  return tf.data.Dataset.from_tensors({'input': 3.})
est.train(train_fn, steps=1)

# This utility function `build_raw_serving_input_receiver_fn` takes in raw
# tensor features and builds an "input serving receiver function", which
# creates placeholder inputs to the model.
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
    {'input': tf.constant(3.)})  # Pass in a dummy input batch.
estimator_path = est.export_saved_model('exported-estimator', serving_input_fn)

# Estimator's export_saved_model creates a time stamped directory. Move this
# to a set path so it can be inspected with `saved_model_cli` in the cell below.
!rm -rf estimator-model
shutil.move(estimator_path, 'estimator-model')
INFO:tensorflow:loss = 0.0, step = 1
INFO:tensorflow:Saving checkpoints for 1 into estimator-checkpoints/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Restoring parameters from estimator-checkpoints/model.ckpt-1
INFO:tensorflow:SavedModel written to: exported-estimator/temp-1699385331/saved_model.pb
'estimator-model'
!saved_model_cli run --dir estimator-model --tag_set serve \
--signature_def serving_default --input_exprs input=[10]
INFO:tensorflow:Restoring parameters from estimator-model/variables/variables
Result for output key output: [12.]
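The export cell above moves the time-stamped directory returned by export_saved_model to a fixed path. When the return value is no longer at hand (for example, in a separate script), one possible approach is to pick the newest export by name. The sketch below assumes that finished exports are subdirectories whose names consist only of digits and that leftover temporary directories carry a `temp-` prefix, as the log above suggests. The helper name `latest_export_dir` is hypothetical.

```python
import tempfile
from pathlib import Path


def latest_export_dir(export_base):
    """Return the newest all-digit (time-stamped) subdirectory, or None."""
    candidates = [
        p for p in Path(export_base).iterdir()
        if p.is_dir() and p.name.isdigit()
    ]
    return max(candidates, key=lambda p: int(p.name), default=None)


# Demo on a throwaway directory that mimics two finished exports
# and one leftover temporary directory.
demo_base = Path(tempfile.mkdtemp())
for name in ("1699385331", "1699385400", "temp-1699385500"):
    (demo_base / name).mkdir()

latest = latest_export_dir(demo_base)
print(latest.name)
```

The numeric comparison avoids relying on string ordering, and the `temp-` directory is skipped because its name is not purely numeric.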
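TensorFlow 2: Save and export a SavedModel
Save and export a SavedModel defined with tf.Module
To export your model in TensorFlow 2, you must define a tf.Module or a tf.keras.Model to hold all of your model's variables and functions. Then, you can call tf.saved_model.save to create a SavedModel. Refer to the Saving a custom model section in the Using the SavedModel format guide to learn more.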
class MyModel(tf.Module):
  @tf.function
  def __call__(self, input):
    return add_two(input)

model = MyModel()

@tf.function
def serving_default(input):
  return {'output': model(input)}

signature_function = serving_default.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.float32))
tf.saved_model.save(
    model, 'tf2-save', signatures={
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_function})
INFO:tensorflow:Assets written to: tf2-save/assets
!saved_model_cli run --dir tf2-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
INFO:tensorflow:Restoring parameters from tf2-save/variables/variables
Result for output key output: 12.0
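Save and export a SavedModel defined with Keras
Deprecated: For Keras objects, it is recommended to use the new high-level .keras format and tf.keras.Model.export, as shown in the guide here. The low-level SavedModel format continues to be supported for existing code.
The Keras APIs for saving and exporting (Model.save or tf.keras.models.save_model) can export a SavedModel from a tf.keras.Model. Check out Save and load Keras models for more details.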
inp = tf.keras.Input(3)
out = add_two(inp)
model = tf.keras.Model(inputs=inp, outputs=out)

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def serving_default(input):
  return {'output': model(input)}

model.save('keras-model', save_format='tf', signatures={
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_default})
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: keras-model/assets
!saved_model_cli run --dir keras-model --tag_set serve \
--signature_def serving_default --input_exprs input=10
INFO:tensorflow:Restoring parameters from keras-model/variables/variables
Result for output key output: 12.0
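Load a SavedModel
A SavedModel saved with any of the above APIs can be loaded using either the TensorFlow 1 or the TensorFlow 2 APIs.
A TensorFlow 1 SavedModel can generally be used for inference when loaded into TensorFlow 2, but training (generating gradients) is only possible if the SavedModel contains resource variables. You can check the dtype of the variables: if a variable's dtype contains "_ref", it is a reference variable.
A TensorFlow 2 SavedModel can be loaded and executed from TensorFlow 1 as long as the SavedModel was saved with signatures.
The sections below contain code samples showing how to load the SavedModels saved in the previous sections and call the exported signature.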
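The "_ref" check mentioned above can be sketched on plain strings. This is only an illustration: it assumes the dtype names have already been collected (for example, from each variable's `dtype.name`), and the variable names in the example mapping are made up.

```python
def split_by_ref_dtype(dtype_names):
    """Split variable names by whether their dtype name contains "_ref".

    dtype_names: mapping of variable name -> dtype name string.
    Returns (reference_variables, other_variables), both sorted.
    """
    reference = sorted(n for n, d in dtype_names.items() if "_ref" in d)
    other = sorted(n for n, d in dtype_names.items() if "_ref" not in d)
    return reference, other


# Hypothetical dtype names, for illustration only.
example = {
    "dense/kernel": "float32_ref",
    "dense/bias": "float32_ref",
    "global_step": "int64",
}
reference_vars, other_vars = split_by_ref_dtype(example)
print("reference variables:", reference_vars)
print("other variables:", other_vars)
```

If the first list is non-empty, the loaded model may still be usable for inference in TensorFlow 2, but according to the note above it cannot be trained there.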
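TensorFlow 1: Load a SavedModel with tf.compat.v1.saved_model.load
In TensorFlow 1, you can import a SavedModel directly into the current graph and session using tf.compat.v1.saved_model.load. You can then call Session.run on the tensor input and output names: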
def load_tf1(path, input):
  print('Loading from', path)
  with tf.Graph().as_default() as g:
    with tf1.Session() as sess:
      meta_graph = tf1.saved_model.load(sess, ["serve"], path)
      sig_def = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
      input_name = sig_def.inputs['input'].name
      output_name = sig_def.outputs['output'].name
      print('  Output with input', input, ': ',
            sess.run(output_name, feed_dict={input_name: input}))

load_tf1('saved-model-builder', 5.)
load_tf1('simple-save', 5.)
load_tf1('estimator-model', [5.])  # Estimator's input must be batched.
load_tf1('tf2-save', 5.)
load_tf1('keras-model', 5.)
Loading from saved-model-builder
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
  Output with input 5.0 : 7.0
Loading from simple-save
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
  Output with input 5.0 : 7.0
Loading from estimator-model
INFO:tensorflow:Restoring parameters from estimator-model/variables/variables
  Output with input [5.0] : [7.]
Loading from tf2-save
INFO:tensorflow:Restoring parameters from tf2-save/variables/variables
  Output with input 5.0 : 7.0
Loading from keras-model
INFO:tensorflow:Restoring parameters from keras-model/variables/variables
  Output with input 5.0 : 7.0
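TensorFlow 2: Load a model saved with tf.saved_model
In TensorFlow 2, a SavedModel is loaded into a Python object that stores the variables and functions. This is compatible with models saved from TensorFlow 1.
Check out the tf.saved_model.load API docs and the Loading and using a custom model section in the Using the SavedModel format guide for details.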
def load_tf2(path, input):
  print('Loading from', path)
  loaded = tf.saved_model.load(path)
  out = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY](
      tf.constant(input))['output']
  print('  Output with input', input, ': ', out)

load_tf2('saved-model-builder', 5.)
load_tf2('simple-save', 5.)
load_tf2('estimator-model', [5.])  # Estimator's input must be batched.
load_tf2('tf2-save', 5.)
load_tf2('keras-model', 5.)
Loading from saved-model-builder
  Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32)
Loading from simple-save
  Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32)
Loading from estimator-model
  Output with input [5.0] : tf.Tensor([7.], shape=(1,), dtype=float32)
Loading from tf2-save
  Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32)
Loading from keras-model
  Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32)
Models saved with the TensorFlow 2 API also give access to the tf.functions and variables attached to the model, not only to those exported as signatures. For example:
loaded = tf.saved_model.load('tf2-save')
print('restored __call__:', loaded.__call__)
print('output with input 5.', loaded(5))
restored __call__: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7f61a8804f10>
output with input 5. tf.Tensor(7.0, shape=(), dtype=float32)
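TensorFlow 2: Load a model saved with Keras
Deprecated: For Keras objects, it is recommended to use the new high-level .keras format and tf.keras.Model.export, as shown in the guide here. The low-level SavedModel format continues to be supported for existing code.
The Keras loading API (tf.keras.models.load_model) lets you reload a saved model back into a Keras model object. Note that this only works for SavedModels saved with Keras (Model.save or tf.keras.models.save_model).
Models saved with tf.saved_model.save should be loaded with tf.saved_model.load. You can load a Keras model saved with Model.save using tf.saved_model.load, but you will only get the TensorFlow graph. Refer to the tf.keras.models.load_model API docs and the Save and load Keras models guide for details.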
loaded_model = tf.keras.models.load_model('keras-model')
loaded_model.predict_on_batch(tf.constant([1, 3, 4]))
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
array([3., 5., 6.], dtype=float32)
GraphDef and MetaGraphDef
There is no direct way to load a raw GraphDef or MetaGraphDef into TF2. However, you can convert the TF1 code that imports the graph into a TF2 concrete_function using tf.compat.v1.wrap_function.
First, save a MetaGraphDef:
# Save a simple multiplication computation:
with tf.Graph().as_default() as g:
  x = tf1.placeholder(tf.float32, shape=[], name='x')
  v = tf.Variable(3.0, name='v')
  y = tf.multiply(x, v, name='y')
  with tf1.Session() as sess:
    sess.run(v.initializer)
    print(sess.run(y, feed_dict={x: 5}))
    s = tf1.train.Saver()
    s.export_meta_graph('multiply.pb', as_text=True)
    s.save(sess, 'multiply_values.ckpt')
15.0
Using the TF1 APIs, you can use tf1.train.import_meta_graph to import the graph and restore the values:
with tf.Graph().as_default() as g:
  meta = tf1.train.import_meta_graph('multiply.pb')
  x = g.get_tensor_by_name('x:0')
  y = g.get_tensor_by_name('y:0')
  with tf1.Session() as sess:
    meta.restore(sess, 'multiply_values.ckpt')
    print(sess.run(y, feed_dict={x: 5}))
INFO:tensorflow:Restoring parameters from multiply_values.ckpt
15.0
There are no TF2 APIs for loading the graph, but you can still import it into a concrete function that can be executed in eager mode:
def import_multiply():
  # Any graph-building code is allowed here.
  tf1.train.import_meta_graph('multiply.pb')

# Creates a tf.function with all the imported elements in the function graph.
wrapped_import = tf1.wrap_function(import_multiply, [])
import_graph = wrapped_import.graph
x = import_graph.get_tensor_by_name('x:0')
y = import_graph.get_tensor_by_name('y:0')

# Restore the variable values.
tf1.train.Saver(wrapped_import.variables).restore(
    sess=None, save_path='multiply_values.ckpt')

# Create a concrete function by pruning the wrap_function (similar to sess.run).
multiply_fn = wrapped_import.prune(feeds=x, fetches=y)

# Run this function
multiply_fn(tf.constant(5.))  # inputs to concrete functions must be Tensors.
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from multiply_values.ckpt
<tf.Tensor: shape=(), dtype=float32, numpy=15.0>
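Changes from TensorFlow 1 to TensorFlow 2
This section lists the key saving and loading terms from TensorFlow 1, their TensorFlow 2 equivalents, and what has changed.
SavedModel
SavedModel is a format that stores a complete TensorFlow program, including its parameters and computation. It contains the signatures that serving platforms use to run the model.
The file format itself has not changed significantly, so SavedModels can be loaded and served using either the TensorFlow 1 or the TensorFlow 2 APIs.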
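Because the on-disk layout is shared by both API generations, a quick structural check of an export directory can be done without importing TensorFlow at all. The sketch below is only a rough illustration: the entry names (`saved_model.pb`, `variables`, `assets`) are taken from the paths that the save and load logs in this notebook report, and the helper name `describe_saved_model_dir` is hypothetical.

```python
import tempfile
from pathlib import Path


def describe_saved_model_dir(path):
    """Report which typical SavedModel entries exist under `path`."""
    root = Path(path)
    return {
        "graph": (root / "saved_model.pb").is_file(),
        "variables": (root / "variables").is_dir(),
        "assets": (root / "assets").is_dir(),
    }


# Demo on a throwaway directory that mimics an export without assets.
demo = Path(tempfile.mkdtemp())
(demo / "saved_model.pb").write_bytes(b"")
(demo / "variables").mkdir()

layout = describe_saved_model_dir(demo)
print(layout)
```

This only checks that the expected entries are present; it does not validate the contents of the protocol buffer or the checkpoint files.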
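Differences between TensorFlow 1 and TensorFlow 2
Aside from the API changes, the serving and inference use cases have not been updated in TensorFlow 2. The improvement lies in the ability to reuse and compose models loaded from a SavedModel.
In TensorFlow 2, the program is represented by objects such as tf.Variable, tf.Module, or higher-level Keras models (tf.keras.Model) and layers (tf.keras.layers). There are no more global variables whose values are stored in a session, and the graph now exists in different tf.functions. Consequently, during a model export, SavedModel saves each component and function graph separately.
When you write a TensorFlow program with the TensorFlow Python APIs, you must build an object to manage the variables, functions, and other resources. Generally, this is accomplished with the Keras APIs, but you can also build the object by creating or subclassing tf.Module.
Keras models (tf.keras.Model) and tf.Module automatically track the variables and functions attached to them. SavedModel saves the connections between modules, variables, and functions so that they can be restored when loading.
Signatures
Signatures are the endpoints of a SavedModel: they tell the user how to run the model and which inputs are needed.
In TensorFlow 1, signatures are created by listing the input and output tensors. In TensorFlow 2, signatures are generated by passing in concrete functions. (Read more about TensorFlow functions in the Introduction to graphs and tf.function guide, particularly the Polymorphism: one Function, many graphs section.) In short, a concrete function is generated from a tf.function: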
# Option 1: Specify an input signature.
@tf.function(input_signature=[...])
def fn(...):
  ...
  return outputs

tf.saved_model.save(model, path, signatures={
    'name': fn
})

# Option 2: Call `get_concrete_function`
@tf.function
def fn(...):
  ...
  return outputs

tf.saved_model.save(model, path, signatures={
    'name': fn.get_concrete_function(...)
})
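As a concrete illustration, the two options can be filled in for the add-two computation used throughout this notebook. This is a minimal sketch rather than the notebook's own code: the class name `AddTwo`, the method names, the signature keys `option_1` and `option_2`, and the temporary export directory are all assumptions made for the example, and it requires TensorFlow 2 to be installed.

```python
import tempfile

import tensorflow as tf


class AddTwo(tf.Module):
    # Option 1: the input signature is fixed on the tf.function itself.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
    def with_signature(self, x):
        return {"output": x + 2}

    # Option 2: no input signature; a concrete function is requested
    # explicitly when the model is saved.
    @tf.function
    def without_signature(self, x):
        return {"output": x + 2}


model = AddTwo()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir, signatures={
    "option_1": model.with_signature,
    "option_2": model.without_signature.get_concrete_function(
        tf.TensorSpec(shape=[], dtype=tf.float32)),
})

# Both signatures are restored as callables keyed by name.
loaded = tf.saved_model.load(export_dir)
result_1 = loaded.signatures["option_1"](x=tf.constant(5.0))["output"]
result_2 = loaded.signatures["option_2"](x=tf.constant(5.0))["output"]
print(float(result_1), float(result_2))
```

Either way, the SavedModel ends up with a concrete function per signature key; Option 1 simply lets tf.saved_model.save derive it from the declared input signature.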
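Session.run
In TensorFlow 1, you could call Session.run with the imported graph as long as you already knew the tensor names. This let you retrieve the restored variable values or run parts of the model that were not exported in the signatures.
In TensorFlow 2, you can directly access a variable, such as a weights matrix (kernel):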
model = tf.Module()
model.dense_layer = tf.keras.layers.Dense(...)
# tf.saved_model.save takes the object to save first, then the export path.
tf.saved_model.save(model, 'my_saved_model')
loaded = tf.saved_model.load('my_saved_model')
loaded.dense_layer.kernel
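or call the tf.functions attached to the model object, for example loaded.__call__.
Unlike TF1, there is no way to extract parts of a function and access intermediate values. You must export all of the needed functionality in the saved object.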
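TensorFlow Serving migration notes
SavedModel was originally created to work with TensorFlow Serving. This platform offers different types of prediction requests: classify, regress, and predict.
The TensorFlow 1 API lets you create these types of signatures with the utility functions:
tf.compat.v1.saved_model.classification_signature_def
tf.compat.v1.saved_model.regression_signature_def
tf.compat.v1.saved_model.predict_signature_def
Classification (classification_signature_def) and regression (regression_signature_def) restrict the inputs and outputs: the inputs must be a tf.Example, and the outputs must be classes, scores, or prediction. The predict signature (predict_signature_def), by contrast, has no restrictions.
SavedModels exported with the TensorFlow 2 API are compatible with TensorFlow Serving, but they only contain prediction signatures. The classification and regression signatures have been removed.
If you need the classification and regression signatures, you can modify the exported SavedModel using tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater.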
Next steps
To learn more about SavedModels in TensorFlow 2, check out the following guides:
If you are using TensorFlow Hub, you may find these guides useful: