Save and load a model using tf.distribute.Strategy

Overview

It's common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

from __future__ import absolute_import, division, print_function, unicode_literals

try:
  # The "!" shell escape only works in IPython/Colab; it installs the TF 2.0 nightly preview.
  !pip install -q tf-nightly-gpu-2.0-preview
except Exception:
  pass
import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 11s 12ms/step - loss: 0.1895 - accuracy: 0.9457
Epoch 2/2
938/938 [==============================] - 7s 8ms/step - loss: 0.0649 - accuracy: 0.9809

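The 938 steps per epoch in the training log follow from the size of the MNIST training split (60,000 images) and the global batch size computed in get_data(). The sketch below reproduces that arithmetic in plain Python; `global_batch_size` and `steps_per_epoch` are hypothetical helpers written for illustration, not TensorFlow APIs, and the sketch assumes Dataset.batch() keeps the final partial batch (its default).

```python
import math

# Sizes from the tutorial: MNIST has 60,000 training images and get_data()
# uses 64 examples per replica.
NUM_TRAIN_EXAMPLES = 60000
BATCH_SIZE_PER_REPLICA = 64


def global_batch_size(num_replicas):
    """Mirror get_data(): the global batch grows with the replica count."""
    return BATCH_SIZE_PER_REPLICA * num_replicas


def steps_per_epoch(num_examples, batch_size):
    """Number of batches per epoch when the last partial batch is kept."""
    # A partial final batch still counts as one step, hence the ceiling.
    return math.ceil(num_examples / batch_size)


for replicas in (1, 2, 4):
    gbs = global_batch_size(replicas)
    print(replicas, gbs, steps_per_epoch(NUM_TRAIN_EXAMPLES, gbs))
```

With one replica this gives 938 steps, matching the log (so this run used a single replica); with two replicas it gives 469 steps and with four, 235, each step processing a proportionally larger global batch.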

Save and load the model

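Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available: the high-level Keras APIs (model.save() and tf.keras.models.load_model()) and the low-level tf.saved_model APIs (tf.saved_model.save() and tf.saved_model.load()).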

The Keras APIs

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() should be called outside of the strategy scope

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 10s 11ms/step - loss: 0.0468 - accuracy: 0.9854
Epoch 2/2
938/938 [==============================] - 9s 10ms/step - loss: 0.0320 - accuracy: 0.9906

After restoring the model, you can continue training on it without calling compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to the SavedModel format.
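Because model.save() writes a SavedModel directory rather than a single file, it can be handy to check that a path looks like one before loading it. The helper below is a hypothetical, standard-library-only heuristic: it assumes a SavedModel directory contains a `saved_model.pb` (or `saved_model.pbtxt`) graph file, and it does not parse or validate the protobuf itself.

```python
import os
import tempfile


def looks_like_saved_model(path):
    """Heuristic: does `path` look like a SavedModel directory?

    Only checks for the serialized graph file; the contents are not parsed.
    """
    if not os.path.isdir(path):
        return False
    names = set(os.listdir(path))
    return "saved_model.pb" in names or "saved_model.pbtxt" in names


# Demo on a throwaway directory that mimics a SavedModel layout.
demo_dir = tempfile.mkdtemp()
empty_check = looks_like_saved_model(demo_dir)  # no graph file yet
open(os.path.join(demo_dir, "saved_model.pb"), "wb").close()
os.mkdir(os.path.join(demo_dir, "variables"))
filled_check = looks_like_saved_model(demo_dir)
print(empty_check, filled_check)
```

In the tutorial you would call it as looks_like_saved_model(keras_model_path) after model.save().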

It is important to call the model.save() method only outside the scope of tf.distribute.Strategy. Calling it within the scope is not supported.

Now, load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 17s 18ms/step - loss: 0.0454 - accuracy: 0.9864
Epoch 2/2
938/938 [==============================] - 15s 16ms/step - loss: 0.0327 - accuracy: 0.9900

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.

The tf.saved_model APIs

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)

Loading can be done with tf.saved_model.load(). However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

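The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To run inference with this function (the model saved here is a fresh, untrained one from get_model(), so every predicted probability is close to 0.1):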

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: id=23971, shape=(64, 10), dtype=float32, numpy=
array([[0.10346866, 0.09814832, 0.10097326, 0.10860109, 0.09718737,
        0.08897626, 0.10873589, 0.08713679, 0.11888701, 0.08788531],
       [0.10382051, 0.09923346, 0.09153106, 0.10073772, 0.11399817,
        0.10367267, 0.11159892, 0.08408294, 0.10567528, 0.08564924],
       [0.0860683 , 0.0833439 , 0.10841124, 0.09725087, 0.11422459,
        0.09504855, 0.12179998, 0.09789883, 0.11483619, 0.08111753],
       [0.09941313, 0.09360484, 0.10998822, 0.09961234, 0.09676877,
        0.10579284, 0.09764036, 0.09880069, 0.10527612, 0.09310272],
       [0.09434031, 0.08653356, 0.11041423, 0.09923773, 0.100683  ,
        0.10154828, 0.11327942, 0.09163021, 0.11172436, 0.09060891],
       [0.10459457, 0.1062756 , 0.10589107, 0.09333517, 0.11420649,
        0.09591331, 0.10444249, 0.07514606, 0.11983742, 0.08035784],
       [0.09011389, 0.08499905, 0.10331811, 0.09727702, 0.11758666,
        0.10133488, 0.1252136 , 0.08713547, 0.11386759, 0.07915378],
       [0.10236397, 0.0893259 , 0.10680854, 0.10644416, 0.10444679,
        0.10033372, 0.10489715, 0.08380076, 0.11453015, 0.08704887],
       [0.10032947, 0.08363276, 0.11155015, 0.09546151, 0.09045216,
        0.10530999, 0.09962162, 0.1049034 , 0.1095127 , 0.09922627],
       [0.09604491, 0.09951681, 0.10141731, 0.09927326, 0.09821585,
        0.09225526, 0.10479179, 0.1026044 , 0.11287221, 0.09300816],
       [0.09835577, 0.09605151, 0.0969343 , 0.10511729, 0.10452414,
        0.09522843, 0.102549  , 0.09733117, 0.11260031, 0.09130809],
       [0.10363709, 0.08875845, 0.0971687 , 0.10388689, 0.10951288,
        0.09372345, 0.09966854, 0.09660895, 0.11106581, 0.09596929],
       [0.10253987, 0.1005487 , 0.09886398, 0.0959296 , 0.12276863,
        0.09967854, 0.11348157, 0.07610316, 0.11335331, 0.07673255],
       [0.09531612, 0.08293157, 0.1085951 , 0.11214884, 0.11267255,
        0.08390686, 0.11346839, 0.08213136, 0.12693633, 0.08189286],
       [0.09784577, 0.09043782, 0.10765995, 0.10589457, 0.10576151,
        0.09812345, 0.1022763 , 0.09766556, 0.10891743, 0.08541768],
       [0.09836663, 0.0956066 , 0.09372915, 0.09864435, 0.10803682,
        0.09337633, 0.10925217, 0.09124055, 0.12442075, 0.08732665],
       [0.09563009, 0.09587347, 0.09856581, 0.10026651, 0.10715912,
        0.09647533, 0.1104801 , 0.09847449, 0.1086933 , 0.08838181],
       [0.0927564 , 0.0878072 , 0.10179386, 0.10248121, 0.10363802,
        0.09150893, 0.11779571, 0.09414804, 0.11937328, 0.08869733],
       [0.09928082, 0.10291645, 0.10191446, 0.1038011 , 0.09991258,
        0.09750739, 0.10221776, 0.09721742, 0.10483133, 0.09040064],
       [0.09401325, 0.08501768, 0.10261724, 0.10452802, 0.10948581,
        0.09882095, 0.11370245, 0.09654196, 0.10376322, 0.09150937],
       [0.09879462, 0.0944241 , 0.0987532 , 0.10302619, 0.10345107,
        0.10714982, 0.10008726, 0.09951942, 0.10102852, 0.09376577],
       [0.10038041, 0.09397505, 0.10546975, 0.10778999, 0.09694247,
        0.1008195 , 0.10354031, 0.08085191, 0.11862423, 0.09160632],
       [0.09137096, 0.08472124, 0.10200507, 0.10545688, 0.11515714,
        0.09663229, 0.1151142 , 0.09717429, 0.10974383, 0.08262403],
       [0.09915587, 0.09616621, 0.09546374, 0.10454717, 0.10498554,
        0.10542189, 0.09955135, 0.09873347, 0.10257642, 0.09339834],
       [0.10190742, 0.09066901, 0.10286693, 0.10081208, 0.09404963,
        0.10061713, 0.10532601, 0.10141323, 0.10938738, 0.09295118],
       [0.08722662, 0.09220009, 0.10506706, 0.10774104, 0.10665391,
        0.09189102, 0.12218424, 0.09075772, 0.11364452, 0.08263379],
       [0.10126304, 0.09231077, 0.10558996, 0.10467736, 0.09878772,
        0.09894125, 0.0998457 , 0.09535211, 0.11192872, 0.09130341],
       [0.09153916, 0.08554792, 0.10534914, 0.11211953, 0.09906203,
        0.08445288, 0.11558711, 0.08930107, 0.12027376, 0.0967674 ],
       [0.09811854, 0.08584639, 0.12018801, 0.10556579, 0.09237039,
        0.10188867, 0.11382465, 0.08533856, 0.11210523, 0.08475371],
       [0.09963494, 0.09025937, 0.09705737, 0.10906896, 0.10578959,
        0.09470603, 0.10918968, 0.09673195, 0.10590146, 0.09166072],
       [0.10042295, 0.08928952, 0.09714521, 0.10439382, 0.10304204,
        0.10713115, 0.10581578, 0.10797798, 0.10102931, 0.0837523 ],
       [0.10250884, 0.09027012, 0.11323065, 0.09910946, 0.09581795,
        0.1041921 , 0.10375739, 0.0950167 , 0.10356285, 0.09253392],
       [0.09492254, 0.09104889, 0.09787201, 0.10392901, 0.11104202,
        0.10191151, 0.11694434, 0.09806208, 0.10368336, 0.08058424],
       [0.09570879, 0.09808693, 0.10059498, 0.09546126, 0.10593103,
        0.10418717, 0.09805436, 0.10377921, 0.10595299, 0.09224328],
       [0.09451246, 0.08945645, 0.11076438, 0.10362288, 0.11730896,
        0.09819566, 0.11507961, 0.09011741, 0.10245928, 0.0784829 ],
       [0.10133371, 0.09073492, 0.10146359, 0.11033185, 0.10479241,
        0.10856219, 0.10829584, 0.09569849, 0.10044938, 0.07833764],
       [0.08926281, 0.08991302, 0.10260765, 0.09982484, 0.10565517,
        0.09334648, 0.11699689, 0.09533685, 0.11997543, 0.08708089],
       [0.09704331, 0.10004682, 0.10441159, 0.09855799, 0.09446484,
        0.09947531, 0.10544829, 0.10129734, 0.1052985 , 0.09395598],
       [0.09684438, 0.0827558 , 0.11028913, 0.10251626, 0.1007083 ,
        0.1036883 , 0.10700043, 0.10187439, 0.10405426, 0.09026877],
       [0.10322517, 0.09957865, 0.1060021 , 0.11137328, 0.10719634,
        0.08707383, 0.10309171, 0.0781871 , 0.12115877, 0.08311309],
       [0.10013492, 0.08762013, 0.11150011, 0.10167935, 0.11142804,
        0.09226192, 0.11050308, 0.08894029, 0.11361992, 0.08231225],
       [0.10519367, 0.08927482, 0.11147374, 0.10377401, 0.11059026,
        0.09473526, 0.10771249, 0.07541236, 0.11464737, 0.08718599],
       [0.09099451, 0.09068823, 0.10858792, 0.09598251, 0.1155775 ,
        0.0925648 , 0.11430975, 0.09074868, 0.11264577, 0.08790037],
       [0.09718288, 0.08952714, 0.10844639, 0.10096531, 0.10303098,
        0.10404167, 0.11779907, 0.09356453, 0.10306749, 0.08237456],
       [0.10563843, 0.09636243, 0.09834308, 0.11278491, 0.0999744 ,
        0.09548616, 0.10330789, 0.08411493, 0.11448289, 0.08950488],
       [0.09583376, 0.09914181, 0.09573716, 0.09818669, 0.11077876,
        0.09936378, 0.10235663, 0.10066731, 0.10771012, 0.09022401],
       [0.09806755, 0.09639712, 0.10536952, 0.09963136, 0.10438179,
        0.09355605, 0.11569549, 0.09011205, 0.10902449, 0.08776458],
       [0.09463413, 0.08494024, 0.10401177, 0.10046989, 0.10830265,
        0.09928988, 0.10311138, 0.0958096 , 0.11203949, 0.09739097],
       [0.09722895, 0.09561229, 0.09854706, 0.09688969, 0.10467776,
        0.10473985, 0.10349254, 0.09999184, 0.10621898, 0.09260107],
       [0.08788903, 0.08491514, 0.11268945, 0.1183198 , 0.10632903,
        0.08981604, 0.11899668, 0.07622642, 0.12137191, 0.08344647],
       [0.09505428, 0.08619108, 0.11239511, 0.10939098, 0.09526334,
        0.08840484, 0.11757697, 0.09548493, 0.10809403, 0.09214443],
       [0.10484948, 0.08087845, 0.10810892, 0.10396265, 0.0945467 ,
        0.10732662, 0.09872332, 0.09874402, 0.10383707, 0.09902277],
       [0.10167606, 0.08799185, 0.10649195, 0.09645592, 0.09039402,
        0.10914844, 0.0963586 , 0.1069734 , 0.10400245, 0.1005073 ],
       [0.09246631, 0.08742926, 0.10730387, 0.10324221, 0.11251386,
        0.0882643 , 0.11868344, 0.08621413, 0.11619555, 0.08768709],
       [0.08872671, 0.07926777, 0.11125548, 0.09687474, 0.10811875,
        0.0915281 , 0.12041357, 0.0934279 , 0.12884356, 0.08154348],
       [0.09522267, 0.08930329, 0.09525562, 0.09706242, 0.12696648,
        0.08585808, 0.10983131, 0.08950891, 0.12006874, 0.09092251],
       [0.09627675, 0.09688797, 0.10596254, 0.09848142, 0.10003708,
        0.10214969, 0.10099179, 0.10044516, 0.10327348, 0.09549418],
       [0.09488672, 0.09927677, 0.10260586, 0.09797408, 0.11105204,
        0.09413116, 0.10274227, 0.09463804, 0.1016879 , 0.10100509],
       [0.08359265, 0.08350577, 0.11496595, 0.10058615, 0.10874845,
        0.09675391, 0.12959388, 0.09232371, 0.10971707, 0.08021242],
       [0.09676141, 0.0925442 , 0.10196703, 0.10463198, 0.11411104,
        0.08869814, 0.12117643, 0.08268759, 0.11293685, 0.08448534],
       [0.09520389, 0.0890058 , 0.10589833, 0.10243975, 0.10009909,
        0.08957859, 0.12266893, 0.08958936, 0.11756258, 0.08795375],
       [0.09484362, 0.08396075, 0.10584073, 0.1092435 , 0.11112304,
        0.09732556, 0.1183947 , 0.09026497, 0.10388077, 0.0851224 ],
       [0.08712243, 0.09426759, 0.09095844, 0.09610551, 0.12273245,
        0.09062532, 0.12355014, 0.08334675, 0.13011955, 0.08117187],
       [0.09488916, 0.09992828, 0.09538072, 0.10277639, 0.10406377,
        0.10078109, 0.10034841, 0.10029319, 0.10515746, 0.09638151]],
      dtype=float32)>}
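The signature returns a dictionary of output tensors keyed by output name ('dense_3' here, which appears to be the auto-generated name of the model's last Dense layer); each row holds softmax probabilities over the 10 digit classes. The sketch below turns such rows into predicted class indices in plain Python, using the first row printed above as sample data. `predicted_classes` is a hypothetical helper; with real tensors you would typically call `.numpy()` first, or use `tf.argmax` directly.

```python
def predicted_classes(outputs):
    """Return the argmax class index for each row of a single-output signature.

    `outputs` maps an output name to a list of probability rows. The key is
    not hard-coded because it depends on how the layers were named.
    """
    if len(outputs) != 1:
        raise ValueError("expected exactly one output, got %d" % len(outputs))
    (rows,) = outputs.values()
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


# First row of the 'dense_3' output printed above (untrained model).
sample = {"dense_3": [[0.10346866, 0.09814832, 0.10097326, 0.10860109, 0.09718737,
                       0.08897626, 0.10873589, 0.08713679, 0.11888701, 0.08788531]]}

print(predicted_classes(sample))             # [8]
print(round(sum(sample["dense_3"][0]), 5))   # softmax rows sum to about 1.0
```

Because this model was never trained, the predicted class is essentially arbitrary; after training, the largest probability should dominate each row.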

You can also load the model and run inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Run inference_func on each replica's share of every batch
  for batch in dist_predict_dataset:
    another_strategy.experimental_run_v2(inference_func, args=(batch,))
W0813 02:05:32.735954 140169908946688 mirrored_strategy.py:660] Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
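In the loop above, experimental_run_v2 runs inference_func on each replica's slice of the batch, and the per-replica results are discarded. The pure-Python sketch below models only the split-and-gather bookkeeping behind that idea: a global batch is divided into contiguous per-replica slices, each slice is processed, and the outputs are concatenated back in order. `shard`, `gather`, and `fake_inference` are hypothetical stand-ins, not TensorFlow APIs; the real strategy handles device placement and uneven final batches itself, and in TensorFlow you would typically unpack per-replica values with the strategy's experimental_local_results() method (check that it exists in your version).

```python
def shard(batch, num_replicas):
    """Split a global batch into contiguous per-replica slices.

    Slice sizes differ by at most one when the batch does not divide evenly.
    """
    base, extra = divmod(len(batch), num_replicas)
    shards, start = [], 0
    for i in range(num_replicas):
        size = base + (1 if i < extra else 0)
        shards.append(batch[start:start + size])
        start += size
    return shards


def gather(per_replica_results):
    """Concatenate per-replica outputs back into global batch order."""
    return [item for replica_out in per_replica_results for item in replica_out]


def fake_inference(images):
    """Stand-in for inference_func: maps each 'image' to a dummy prediction."""
    return [x * 10 for x in images]


global_batch = list(range(10))
per_replica = [fake_inference(s) for s in shard(global_batch, 4)]
print([len(s) for s in shard(global_batch, 4)])             # [3, 3, 2, 2]
print(gather(per_replica) == fake_inference(global_batch))  # True
```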

Calling the restored function is just a forward pass on the saved model (like predict()). What if you want to continue training the loaded model, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap the loaded SavedModel object in a Keras layer and apply it to x
  outputs = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, outputs)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 10s 11ms/step - loss: 0.2023 - accuracy: 0.9421
Epoch 2/2
938/938 [==============================] - 7s 8ms/step - loss: 0.0678 - accuracy: 0.9797

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use Keras's model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save() and load it as a non-Keras object with the low-level API, tf.saved_model.load():

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
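The recommendations in this section can be condensed into a small decision function. `choose_apis` is a hypothetical helper that returns API names as strings; it only encodes the rules stated in this tutorial (including the caveat about Keras models without well-defined inputs, described in the next section) and is not part of TensorFlow.

```python
def choose_apis(is_keras_model, want_keras_model_back, inputs_defined=True):
    """Return (save_api, load_api) as strings, following this tutorial's advice."""
    # Caveat: a Keras model without well-defined inputs (e.g. a fresh
    # subclassed model) must use the lower-level APIs for both steps.
    if is_keras_model and not inputs_defined:
        return "tf.saved_model.save", "tf.saved_model.load"

    # Saving: prefer Keras's model.save() for Keras models.
    save_api = "model.save" if is_keras_model else "tf.saved_model.save"

    # Loading: a Keras model can only come back if a Keras model was saved.
    if is_keras_model and want_keras_model_back:
        load_api = "tf.keras.models.load_model"
    else:
        load_api = "tf.saved_model.load"
    return save_api, load_api


# The mix-and-match example above: Keras save, low-level load.
print(choose_apis(is_keras_model=True, want_keras_model_back=False))
```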

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.

To check whether your model has well-defined inputs, check whether model.inputs is None. If it is not None, the Keras APIs will work. Input shapes are defined automatically when the model is used in .fit(), .evaluate(), or .predict(), or when the model is called directly (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
W0813 02:05:55.880124 140169908946688 save.py:161] Skipping full serialization of Keras model <__main__.SubclassedModel object at 0x7f7bb46f5e10>, because its inputs are not defined.
W0813 02:05:55.883413 140169908946688 save.py:168] Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f7bb46f5cf8>, because it is not built.