Voir sur TensorFlow.org | Exécuter dans Google Colab | Voir la source sur GitHub | Télécharger le cahier |
Une fois que vous avez migré votre modèle des graphiques et des sessions de TensorFlow 1 vers les API TensorFlow 2, telles que tf.function
, tf.Module
et tf.keras.Model
, vous pouvez migrer le code d'enregistrement et de chargement du modèle. Ce notebook fournit des exemples d'enregistrement et de chargement au format SavedModel dans TensorFlow 1 et TensorFlow 2. Voici un bref aperçu des modifications d'API associées pour la migration de TensorFlow 1 vers TensorFlow 2 :
TensorFlow 1 | Migration vers TensorFlow 2 | |
---|---|---|
Économie | tf.compat.v1.saved_model.Builder tf.compat.v1.saved_model.simple_save | tf.saved_model.save Keras : tf.keras.models.save_model |
Chargement | tf.compat.v1.saved_model.load | tf.saved_model.load Keras : tf.keras.models.load_model |
Signatures : un ensemble de saisie et des tenseurs de sortie qui peut être utilisé pour exécuter le | Généré à l'aide des *.signature_def (par exemple tf.compat.v1.saved_model.predict_signature_def ) | Écrivez un tf.function et exportez-le en utilisant l'argument signatures dans tf.saved_model.save . |
Classification et régression : types spéciaux de signatures | Généré avectf.compat.v1.saved_model.classification_signature_def ,tf.compat.v1.saved_model.regression_signature_def ,et certaines exportations d'Estimator. | Ces deux types de signature ont été supprimés de TensorFlow 2. Si la bibliothèque de service requiert ces noms de méthode, tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater . |
Pour une explication plus détaillée du mappage, reportez-vous à la section Changements de TensorFlow 1 à TensorFlow 2 ci-dessous.
Installer
Les exemples ci-dessous montrent comment exporter et charger le même modèle factice TensorFlow (défini comme add_two
ci-dessous) dans un format SavedModel à l'aide des API TensorFlow 1 et TensorFlow 2. Commencez par configurer les importations et les fonctions utilitaires :
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import shutil
def remove_dir(path):
try:
shutil.rmtree(path)
except:
pass
def add_two(input):
return input + 2
TensorFlow 1 : Enregistrer et exporter un modèle enregistré
Dans TensorFlow 1, vous utilisez les tf.compat.v1.saved_model.Builder
, tf.compat.v1.saved_model.simple_save
et tf.estimator.Estimator.export_saved_model
pour créer, enregistrer et exporter le graphique et la session TensorFlow :
1. Enregistrez le graphique en tant que SavedModel avec SavedModelBuilder
remove_dir("saved-model-builder")
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
input = tf1.placeholder(tf.float32, shape=[])
output = add_two(input)
print("add two output: ", sess.run(output, {input: 3.}))
# Save with SavedModelBuilder
builder = tf1.saved_model.Builder('saved-model-builder')
sig_def = tf1.saved_model.predict_signature_def(
inputs={'input': input},
outputs={'output': output})
builder.add_meta_graph_and_variables(
sess, tags=["serve"], signature_def_map={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
})
builder.save()
add two output: 5.0 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:208: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:No assets to save. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: saved-model-builder/saved_model.pb
!saved_model_cli run --dir simple-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/bin/saved_model_cli", line 8, in <module> sys.exit(main()) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 1211, in main args.func(args) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 769, in run init_tpu=args.init_tpu, tf_debug=args.tf_debug) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 417, in run_saved_model_with_feed_dict tag_set) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_utils.py", line 117, in get_meta_graph_def saved_model = read_saved_model(saved_model_dir) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_utils.py", line 55, in read_saved_model raise IOError("SavedModel file does not exist at: %s" % saved_model_dir) OSError: SavedModel file does not exist at: simple-save
2. Construire un SavedModel pour servir
remove_dir("simple-save")
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
input = tf1.placeholder(tf.float32, shape=[])
output = add_two(input)
print("add_two output: ", sess.run(output, {input: 3.}))
tf1.saved_model.simple_save(
sess, 'simple-save',
inputs={'input': input},
outputs={'output': output})
add_two output: 5.0 WARNING:tensorflow:From /tmp/ipykernel_26511/250978412.py:12: simple_save (from tensorflow.python.saved_model.simple_save) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.simple_save. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: simple-save/saved_model.pb
!saved_model_cli run --dir simple-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Result for output key output: 12.0
3. Exporter le graphique d'inférence de l'estimateur en tant que modèle enregistré
Dans la définition de l'Estimator model_fn
(défini ci-dessous), vous pouvez définir des signatures dans votre modèle en retournant export_outputs
dans le tf.estimator.EstimatorSpec
. Il existe différents types de sorties :
-
tf.estimator.export.ClassificationOutput
-
tf.estimator.export.RegressionOutput
-
tf.estimator.export.PredictOutput
Ceux-ci produiront respectivement des types de signature de classification, de régression et de prédiction.
Lorsque l'estimateur est exporté avec tf.estimator.Estimator.export_saved_model
, ces signatures seront enregistrées avec le modèle.
def model_fn(features, labels, mode):
output = add_two(features['input'])
step = tf1.train.get_global_step()
return tf.estimator.EstimatorSpec(
mode,
predictions=output,
train_op=step.assign_add(1),
loss=tf.constant(0.),
export_outputs={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: \
tf.estimator.export.PredictOutput({'output': output})})
est = tf.estimator.Estimator(model_fn, 'estimator-checkpoints')
# Train for one step to create a checkpoint.
def train_fn():
return tf.data.Dataset.from_tensors({'input': 3.})
est.train(train_fn, steps=1)
# This utility function `build_raw_serving_input_receiver_fn` takes in raw
# tensor features and builds an "input serving receiver function", which
# creates placeholder inputs to the model.
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
{'input': tf.constant(3.)}) # Pass in a dummy input batch.
estimator_path = est.export_saved_model('exported-estimator', serving_input_fn)
# Estimator's export_saved_model creates a time stamped directory. Move this
# to a set path so it can be inspected with `saved_model_cli` in the cell below.
!rm -rf estimator-model
import shutil
shutil.move(estimator_path, 'estimator-model')
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': 'estimator-checkpoints', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:401: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into estimator-checkpoints/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.0, step = 1 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into estimator-checkpoints/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:Loss for final step: 0.0. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Signatures INCLUDED in export for Classify: None INFO:tensorflow:Signatures INCLUDED in export for Regress: None INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from estimator-checkpoints/model.ckpt-1 INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: exported-estimator/temp-1636162129/saved_model.pb 'estimator-model'
!saved_model_cli run --dir estimator-model --tag_set serve \
--signature_def serving_default --input_exprs input=[10]
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from estimator-model/variables/variables Result for output key output: [12.]
TensorFlow 2 : Enregistrer et exporter un modèle enregistré
Enregistrer et exporter un SavedModel défini avec tf.Module
Pour exporter votre modèle dans TensorFlow 2, vous devez définir un tf.Module
ou un tf.keras.Model
pour contenir toutes les variables et fonctions de votre modèle. Ensuite, vous pouvez appeler tf.saved_model.save
pour créer un SavedModel. Reportez-vous à la section Enregistrement d'un modèle personnalisé dans le guide Utilisation du format SavedModel pour en savoir plus.
class MyModel(tf.Module):
@tf.function
def __call__(self, input):
return add_two(input)
model = MyModel()
@tf.function
def serving_default(input):
return {'output': model(input)}
signature_function = serving_default.get_concrete_function(
tf.TensorSpec(shape=[], dtype=tf.float32))
tf.saved_model.save(
model, 'tf2-save', signatures={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_function})
INFO:tensorflow:Assets written to: tf2-save/assets 2021-11-06 01:28:53.105391: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
!saved_model_cli run --dir tf2-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from tf2-save/variables/variables Result for output key output: 12.0
Enregistrer et exporter un SavedModel défini avec Keras
Les API Keras pour l'enregistrement et l'exportation Mode.save
ou tf.keras.models.save_model
peuvent exporter un SavedModel à partir d'un tf.keras.Model
. Consultez les modèles Enregistrer et charger Keras pour plus de détails.
inp = tf.keras.Input(3)
out = add_two(inp)
model = tf.keras.Model(inputs=inp, outputs=out)
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def serving_default(input):
return {'output': model(input)}
model.save('keras-model', save_format='tf', signatures={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_default})
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (). INFO:tensorflow:Assets written to: keras-model/assets
!saved_model_cli run --dir keras-model --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from keras-model/variables/variables Result for output key output: 12.0
Chargement d'un modèle enregistré
Un SavedModel enregistré avec l'une des API ci-dessus peut être chargé à l'aide des API TensorFlow 1 ou TensorFlow.
Un SavedModel TensorFlow 1 peut généralement être utilisé pour l'inférence lorsqu'il est chargé dans TensorFlow 2, mais l'entraînement (génération de gradients) n'est possible que si le SavedModel contient des variables de ressource . Vous pouvez vérifier le dtype des variables—si la variable dtype contient "_ref", alors c'est une variable de référence.
Un SavedModel TensorFlow 2 peut être chargé et exécuté à partir de TensorFlow 1 tant que le SavedModel est enregistré avec des signatures.
Les sections ci-dessous contiennent des exemples de code montrant comment charger les SavedModels enregistrés dans les sections précédentes et appeler la signature exportée.
TensorFlow 1 : charger un modèle enregistré avec tf.saved_model.load
Dans TensorFlow 1, vous pouvez importer un SavedModel directement dans le graphique et la session actuels à l'aide tf.saved_model.load
. Vous pouvez appeler Session.run
sur les noms d'entrée et de sortie du tenseur :
def load_tf1(path, input):
print('Loading from', path)
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
meta_graph = tf1.saved_model.load(sess, ["serve"], path)
sig_def = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
input_name = sig_def.inputs['input'].name
output_name = sig_def.outputs['output'].name
print(' Output with input', input, ': ',
sess.run(output_name, feed_dict={input_name: input}))
load_tf1('saved-model-builder', 5.)
load_tf1('simple-save', 5.)
load_tf1('estimator-model', [5.]) # Estimator's input must be batched.
load_tf1('tf2-save', 5.)
load_tf1('keras-model', 5.)
Loading from saved-model-builder WARNING:tensorflow:From /tmp/ipykernel_26511/1548963983.py:5: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Output with input 5.0 : 7.0 Loading from simple-save INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Output with input 5.0 : 7.0 Loading from estimator-model INFO:tensorflow:Restoring parameters from estimator-model/variables/variables Output with input [5.0] : [7.] Loading from tf2-save INFO:tensorflow:Restoring parameters from tf2-save/variables/variables Output with input 5.0 : 7.0 Loading from keras-model INFO:tensorflow:Restoring parameters from keras-model/variables/variables Output with input 5.0 : 7.0
TensorFlow 2 : charger un modèle enregistré avec tf.saved_model
Dans TensorFlow 2, les objets sont chargés dans un objet Python qui stocke les variables et les fonctions. Ceci est compatible avec les modèles enregistrés à partir de TensorFlow 1.
Consultez la documentation de l'API tf.saved_model.load
et la section Chargement et utilisation d'un modèle personnalisé du guide Utilisation du format SavedModel pour plus de détails.
def load_tf2(path, input):
print('Loading from', path)
loaded = tf.saved_model.load(path)
out = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY](
tf.constant(input))['output']
print(' Output with input', input, ': ', out)
load_tf2('saved-model-builder', 5.)
load_tf2('simple-save', 5.)
load_tf2('estimator-model', [5.]) # Estimator's input must be batched.
load_tf2('tf2-save', 5.)
load_tf2('keras-model', 5.)
Loading from saved-model-builder INFO:tensorflow:Saver not created because there are no variables in the graph to restore Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from simple-save INFO:tensorflow:Saver not created because there are no variables in the graph to restore Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from estimator-model Output with input [5.0] : tf.Tensor([7.], shape=(1,), dtype=float32) Loading from tf2-save Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from keras-model Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32)
Les modèles enregistrés avec l'API TensorFlow 2 peuvent également accéder aux tf.function
s et aux variables associées au modèle (au lieu de celles exportées en tant que signatures). Par example:
loaded = tf.saved_model.load('tf2-save')
print('restored __call__:', loaded.__call__)
print('output with input 5.', loaded(5))
restored __call__: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7f30cc940990> output with input 5. tf.Tensor(7.0, shape=(), dtype=float32)
TensorFlow 2 : charger un modèle enregistré avec Keras
L'API de chargement Keras tf.keras.models.load_model
- vous permet de recharger un modèle enregistré dans un objet Keras Model. Notez que cela vous permet uniquement de charger des modèles enregistrés avec Keras ( Model.save
ou tf.keras.models.save_model
).
Les modèles enregistrés avec tf.saved_model.save
doivent être chargés avec tf.saved_model.load
. Vous pouvez charger un modèle Keras enregistré avec Model.save
en utilisant tf.saved_model.load
mais vous n'obtiendrez que le graphique TensorFlow. Reportez-vous à la documentation de l'API tf.keras.models.load_model
et au guide Enregistrer et charger les modèles Keras pour plus de détails.
loaded_model = tf.keras.models.load_model('keras-model')
loaded_model.predict_on_batch(tf.constant([1, 3, 4]))
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually. WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (3,). array([3., 5., 6.], dtype=float32)
GraphDef et MetaGraphDef
Il n'y a pas de moyen simple de charger un GraphDef
ou un MetaGraphDef
dans TF2. Cependant, vous pouvez convertir le code TF1 qui importe le graphique en une fonction concrete_function
TF2 à l'aide de v1.wrap_function
.
Tout d'abord, enregistrez un MetaGraphDef :
# Save a simple multiplication computation:
with tf.Graph().as_default() as g:
x = tf1.placeholder(tf.float32, shape=[], name='x')
v = tf.Variable(3.0, name='v')
y = tf.multiply(x, v, name='y')
with tf1.Session() as sess:
sess.run(v.initializer)
print(sess.run(y, feed_dict={x: 5}))
s = tf1.train.Saver()
s.export_meta_graph('multiply.pb', as_text=True)
s.save(sess, 'multiply_values.ckpt')
15.0
À l'aide des API TF1, vous pouvez utiliser tf1.train.import_meta_graph
pour importer le graphique et restaurer les valeurs :
with tf.Graph().as_default() as g:
meta = tf1.train.import_meta_graph('multiply.pb')
x = g.get_tensor_by_name('x:0')
y = g.get_tensor_by_name('y:0')
with tf1.Session() as sess:
meta.restore(sess, 'multiply_values.ckpt')
print(sess.run(y, feed_dict={x: 5}))
INFO:tensorflow:Restoring parameters from multiply_values.ckpt 15.0
Il n'y a pas d'API TF2 pour charger le graphe, mais vous pouvez toujours l'importer dans une fonction concrète exécutable en mode impatient :
def import_multiply():
# Any graph-building code is allowed here.
tf1.train.import_meta_graph('multiply.pb')
# Creates a tf.function with all the imported elements in the function graph.
wrapped_import = tf1.wrap_function(import_multiply, [])
import_graph = wrapped_import.graph
x = import_graph.get_tensor_by_name('x:0')
y = import_graph.get_tensor_by_name('y:0')
# Restore the variable values.
tf1.train.Saver(wrapped_import.variables).restore(
sess=None, save_path='multiply_values.ckpt')
# Create a concrete function by pruning the wrap_function (similar to sess.run).
multiply_fn = wrapped_import.prune(feeds=x, fetches=y)
# Run this function
multiply_fn(tf.constant(5.)) # inputs to concrete functions must be Tensors.
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone. INFO:tensorflow:Restoring parameters from multiply_values.ckpt <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
Changements de TensorFlow 1 à TensorFlow 2
Cette section répertorie les principaux termes d'enregistrement et de chargement de TensorFlow 1, leurs équivalents TensorFlow 2 et ce qui a changé.
Modèle enregistré
SavedModel est un format qui stocke un programme TensorFlow complet avec des paramètres et des calculs. Il contient les signatures utilisées par les plates-formes de service pour exécuter le modèle.
Le format de fichier lui-même n'a pas changé de manière significative, de sorte que SavedModels peut être chargé et servi à l'aide des API TensorFlow 1 ou TensorFlow 2.
Différences entre TensorFlow 1 et TensorFlow 2
Les cas d'utilisation de diffusion et d' inférence n'ont pas été mis à jour dans TensorFlow 2, à l'exception des modifications de l'API : l'amélioration a été introduite dans la possibilité de réutiliser et de composer des modèles chargés à partir de SavedModel.
Dans TensorFlow 2, le programme est représenté par des objets tels que tf.Variable
, tf.Module
ou des modèles Keras de niveau supérieur ( tf.keras.Model
) et des calques ( tf.keras.layers
). Il n'y a plus de variables globales qui ont des valeurs stockées dans une session, et le graphique existe maintenant dans différents tf.function
s. Par conséquent, lors d'une exportation de modèle, SavedModel enregistre chaque graphique de composant et de fonction séparément.
Lorsque vous écrivez un programme TensorFlow avec les API TensorFlow Python, vous devez créer un objet pour gérer les variables, les fonctions et les autres ressources. Généralement, cela est accompli en utilisant les API Keras, mais vous pouvez également créer l'objet en créant ou en sous- tf.Module
.
Les modèles Keras ( tf.keras.Model
) et tf.Module
automatiquement les variables et les fonctions qui leur sont attachées. SavedModel enregistre ces connexions entre les modules, les variables et les fonctions, afin qu'elles puissent être restaurées lors du chargement.
Signatures
Les signatures sont les points de terminaison d'un SavedModel - elles indiquent à l'utilisateur comment exécuter le modèle et quelles entrées sont nécessaires.
Dans TensorFlow 1, les signatures sont créées en répertoriant les tenseurs d'entrée et de sortie. Dans TensorFlow 2, les signatures sont générées en transmettant des fonctions concrètes . (En savoir plus sur les fonctions TensorFlow dans le guide Introduction aux graphes et tf.function .) En bref, une fonction concrète est générée à partir d'une tf.function
:
# Option 1: Specify an input signature.
@tf.function(input_signature=[...])
def fn(...):
...
return outputs
tf.saved_model.save(model, path, signatures={
'name': fn
})
# Option 2: Call `get_concrete_function`
@tf.function
def fn(...):
...
return outputs
tf.saved_model.save(model, path, signatures={
'name': fn.get_concrete_function(...)
})
Session.run
Dans TensorFlow 1, vous pouvez appeler Session.run
avec le graphique importé tant que vous connaissez déjà les noms des tenseurs. Cela vous permet de récupérer les valeurs de variables restaurées ou d'exécuter des parties du modèle qui n'ont pas été exportées dans les signatures.
Dans TensorFlow 2, vous pouvez accéder directement à une variable, telle qu'une matrice de pondérations ( kernel
) :
model = tf.Module()
model.dense_layer = tf.keras.layers.Dense(...)
tf.saved_model.save('my_saved_model')
loaded = tf.saved_model.load('my_saved_model')
loaded.dense_layer.kernel
ou appelez tf.function
s attaché à l'objet modèle : par exemple, loaded.__call__
.
Contrairement à TF1, il n'y a aucun moyen d'extraire des parties d'une fonction et d'accéder à des valeurs intermédiaires. Vous devez exporter toutes les fonctionnalités nécessaires dans l'objet enregistré.
Remarques sur la migration de TensorFlow Serving
SavedModel a été créé à l'origine pour fonctionner avec TensorFlow Serving . Cette plateforme propose différents types de requêtes de prédiction : classifier, régresser et prédire.
L'API TensorFlow 1 vous permet de créer ces types de signatures avec les utilitaires :
-
tf.compat.v1.saved_model.classification_signature_def
-
tf.compat.v1.saved_model.regression_signature_def
-
tf.compat.v1.saved_model.predict_signature_def
La classification ( classification_signature_def
) et la régression ( regression_signature_def
) restreignent les entrées et les sorties, de sorte que les entrées doivent être un tf.Example
et les sorties doivent être classes
, scores
ou prediction
. Pendant ce temps, la signature de prédiction ( predict_signature_def
) n'a aucune restriction.
Les SavedModels exportés avec l'API TensorFlow 2 sont compatibles avec TensorFlow Serving, mais ne contiendront que des signatures de prédiction. Les signatures de classification et de régression ont été supprimées.
Si vous avez besoin de l'utilisation des signatures de classification et de régression, vous pouvez modifier le SavedModel exporté à l'aide tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater
.
Prochaines étapes
Pour en savoir plus sur SavedModels dans TensorFlow 2, consultez les guides suivants :
Si vous utilisez TensorFlow Hub, ces guides peuvent vous être utiles :