分散ストラテジーを使ってモデルを保存して読み込む

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

概要

このチュートリアルでは、トレーニング中またはトレーニング後に tf.distribute.Strategy を使用して SavedModel 形式でモデルを保存して読み込む方法を説明します。Keras モデルの保存と読み込みには、高レベル(tf.keras.Model.savetf.keras.models.load_model)と低レベル(tf.saved_model.savetf.saved_model.load)の 2 種類の API があります。

SavedModel とシリアル化の全般的な内容については、SavedModel ガイドKeras モデルのシリアル化ガイドをお読みください。では、単純な例から始めましょう。

注意: TensorFlow モデルはコードであるため、信頼できないコードには注意する必要があります。詳細は、TensorFlow を安全に使用するをご覧ください。

依存関係をインポートします。

import tensorflow_datasets as tfds

import tensorflow as tf
2024-01-11 18:14:24.741176: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:14:24.741223: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:14:24.742716: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow Datasets と tf.data でデータを読み込んで準備し、tf.distribute.MirroredStrategy を使ってモデルを作成します。

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

tf.keras.Model.fit を使用してモデルをトレーニングします。

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
2024-01-11 18:14:31.467356: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704996877.849655   50970 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
235/235 [==============================] - ETA: 0s - loss: 0.3291 - sparse_categorical_accuracy: 0.9094INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
235/235 [==============================] - 8s 8ms/step - loss: 0.3291 - sparse_categorical_accuracy: 0.9094
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.1068 - sparse_categorical_accuracy: 0.9685
<keras.src.callbacks.History at 0x7f7e10386cd0>

モデルを保存して読み込む

作業に使用する単純なモデルを準備できたので、保存と読み込みに使用する API を見てみましょう。使用できる API には、以下の 2 種類があります。

Keras API

Keras API を使用したモデルの保存と読み込みの例を以下に示します。

keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)

tf.distribute.Strategy を使用せずにモデルを復元します。

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
235/235 [==============================] - 2s 4ms/step - loss: 0.0686 - sparse_categorical_accuracy: 0.9798
Epoch 2/2
235/235 [==============================] - 1s 4ms/step - loss: 0.0540 - sparse_categorical_accuracy: 0.9844
<keras.src.callbacks.History at 0x7f7f5048b670>

モデルを復元したら、Model.compile をもう一度呼び出さずにそのままトレーニングを続行できます。これは、保存前にすでにコンパイル済みであるためです。このモデルは、Keras zip アーカイブ形式で保存されており、.keras 拡張子で識別できます。詳細については、Keras の保存に関するガイドをご覧ください。

次に、tf.distribute.Strategy を使用してモデルを復元し、トレーニングします。

another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2024-01-11 18:14:45.633878: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2024-01-11 18:14:45.694588: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
15/235 [>.............................] - ETA: 2s - loss: 0.0826 - sparse_categorical_accuracy: 0.9776
2024-01-11 18:14:46.266915: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
2024-01-11 18:14:46.267137: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
2024-01-11 18:14:46.301591: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
235/235 [==============================] - 3s 11ms/step - loss: 0.0699 - sparse_categorical_accuracy: 0.9796
Epoch 2/2
235/235 [==============================] - 3s 11ms/step - loss: 0.0531 - sparse_categorical_accuracy: 0.9844

Model.fit 出力からわかるように、tf.distribute.Strategy を使って期待どおり読み込まれました。ここで使用されるストラテジーは、保存前と同じストラテジーである必要はありません。

tf.saved_model API

より低レベルの API を使用したモデルの保存方法は、Keras API を使う方法に似ています。

model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

読み込みは、tf.saved_model.load を使用して行えますが、これは低レベル API(したがって、より幅広いユースケースのある API)であるため、Keras モデルを返しません。代わりに、推論を行うために使用できる関数を含むオブジェクトを返します。以下に例を示します。

DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

読み込まれたオブジェクトには、それぞれにキーが関連付けられた複数の関数が含まれている可能性があります。"serving_default" キーは、保存された Keras モデルを使用した推論関数のデフォルトのキーです。この関数で推論するには、以下のようにします。

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[ 0.04390968,  0.30341768,  0.05374109, ..., -0.35343656,
         0.03065785, -0.00975093],
       [-0.04910231,  0.16482985,  0.06436244, ..., -0.27770516,
         0.02216907,  0.13293922],
       [-0.05661844,  0.2683993 , -0.06041192, ..., -0.26340052,
         0.02152548,  0.10264045],
       ...,
       [-0.12805948,  0.11079367, -0.10359426, ..., -0.26105058,
         0.0311166 ,  0.02954188],
       [-0.11231118,  0.22162321,  0.04027553, ..., -0.34616578,
         0.02095792,  0.01622906],
       [-0.07966347,  0.08217648, -0.14690818, ..., -0.21150741,
         0.03090278, -0.12792973]], dtype=float32)>}
2024-01-11 18:14:52.675006: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

また、分散方法で読み込んで推論を実行することもできます。

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2024-01-11 18:14:52.889461: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
{'dense_3': PerReplica:{
  0: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 0.04390973,  0.30341768,  0.05374111, -0.02709487, -0.28792804,
        -0.19333729,  0.325674  , -0.3534366 ,  0.03065785, -0.00975094],
       [-0.04910226,  0.1648299 ,  0.06436247,  0.0941631 , -0.36874533,
        -0.20513675,  0.09101719, -0.27770516,  0.02216907,  0.1329391 ],
       [-0.05661836,  0.26839924, -0.0604119 ,  0.04293879, -0.37735796,
        -0.01866844,  0.23681116, -0.26340055,  0.02152544,  0.10264044],
       [ 0.10640895,  0.1561212 ,  0.06909597, -0.06987031, -0.1984469 ,
        -0.04289627,  0.18389529, -0.18640813,  0.06818488, -0.11756891],
       [ 0.01074794,  0.21058783, -0.13376951, -0.07198893, -0.34633294,
        -0.0951823 ,  0.09859037, -0.17102982,  0.00822654,  0.02078733],
       [ 0.02912218,  0.19898024, -0.2194208 , -0.09297523, -0.22816458,
        -0.14823863,  0.1251952 , -0.22406554,  0.07672149, -0.06627873],
       [-0.07373619,  0.1642075 ,  0.07785851,  0.0902054 , -0.38338202,
        -0.16163939,  0.15576416, -0.2677853 ,  0.03417471,  0.14622498],
       [-0.01103535,  0.13357913, -0.11604594, -0.1436121 , -0.11463816,
        -0.09246697,  0.02174798, -0.13517785,  0.11130458, -0.16181585],
       [-0.05913471,  0.25684866, -0.08185954,  0.11305049, -0.41555196,
        -0.11584461,  0.17271446, -0.28591448,  0.0171692 , -0.03395621],
       [ 0.08944317,  0.0613468 , -0.1547694 , -0.04008061, -0.12746565,
        -0.06679891,  0.08744669, -0.14509115, -0.00697993, -0.20970711],
       [-0.02094699,  0.1820338 , -0.00243831, -0.01445665, -0.30628368,
        -0.20580608,  0.17843154, -0.25234053, -0.02907382, -0.03199375],
       [ 0.01979519,  0.2740837 , -0.01698644, -0.18219724, -0.40358156,
        -0.21365486,  0.04891555, -0.29936785,  0.06018814, -0.07291923],
       [-0.10811754,  0.13060504, -0.03464153,  0.06798229, -0.43056038,
         0.01354938,  0.21008897, -0.31439742, -0.08188079,  0.11052423],
       [ 0.05240241,  0.17336613, -0.17786968, -0.08816151, -0.34909615,
        -0.11435185,  0.07548714, -0.27006733,  0.09544405, -0.20011938],
       [ 0.15914436,  0.24435948, -0.04742448, -0.02628867, -0.30243245,
        -0.13071185,  0.19322398, -0.34721792,  0.08353946, -0.16963294],
       [ 0.00991537,  0.2729301 , -0.10826392, -0.18254003, -0.303356  ,
        -0.2544545 ,  0.09932231, -0.23833567,  0.09032176,  0.1128719 ],
       [ 0.12592334,  0.10697531, -0.03352933, -0.18785489, -0.08577707,
        -0.20832878,  0.13447234, -0.06923658,  0.01704483, -0.13321611],
       [-0.05993862,  0.05663422, -0.05601699, -0.11318138, -0.32292652,
        -0.13415524,  0.07078642, -0.24608482, -0.03674737, -0.17912035],
       [-0.08632107,  0.1418068 , -0.13200448,  0.01969666, -0.25682348,
        -0.14198145,  0.10951944, -0.27877283,  0.02858731, -0.0374111 ],
       [ 0.11897041,  0.2594269 ,  0.03674663, -0.11091953, -0.30228472,
        -0.11023479,  0.1663589 , -0.376381  ,  0.06430978, -0.16984001],
       [-0.00289295,  0.04047507, -0.11863903, -0.05876009, -0.23271626,
        -0.13283929,  0.12130734, -0.13938479, -0.07576559,  0.01863652],
       [ 0.01375921, -0.00918115, -0.09295098, -0.06612265, -0.159753  ,
        -0.12676233,  0.0832831 , -0.20027536, -0.06356223, -0.1602694 ],
       [-0.11601318,  0.05592554, -0.00799035, -0.03840588, -0.31699634,
        -0.12557843,  0.00217704, -0.13452399, -0.04661082, -0.00160295],
       [-0.04387126,  0.02643408, -0.08738586,  0.0719263 , -0.27419144,
        -0.10074591, -0.00859936, -0.1349473 , -0.02253597,  0.09461489],
       [ 0.00437208,  0.09436843, -0.05663288, -0.00419031, -0.22780542,
        -0.1780376 ,  0.1348795 , -0.1746904 , -0.1173916 , -0.06692152],
       [ 0.03692336,  0.20061986,  0.07223521, -0.00423313, -0.36737412,
        -0.11080196,  0.1832721 , -0.31301066, -0.0066702 , -0.13494146],
       [ 0.05483724,  0.03380407, -0.04579096, -0.0625519 , -0.19104283,
        -0.12388766,  0.04404955, -0.1911788 , -0.02904271, -0.12793231],
       [-0.03292175,  0.0789917 , -0.20172037, -0.05561272, -0.17042139,
        -0.29631954,  0.15214083, -0.10865977, -0.04539655, -0.09254031],
       [ 0.05429688,  0.24875508, -0.25274026, -0.08207695, -0.250165  ,
        -0.07061552,  0.03925603, -0.1252473 ,  0.11538928,  0.02362049],
       [-0.01455817,  0.08509667, -0.02281824,  0.09500916, -0.20391591,
        -0.06684849,  0.23080888, -0.33867043,  0.0283998 ,  0.05314478],
       [-0.00075892,  0.15760201,  0.07448781, -0.02245309, -0.27052298,
        -0.02893338,  0.24336715, -0.29186094, -0.0194285 , -0.04287457],
       [-0.07933994,  0.1979461 ,  0.0075167 ,  0.02225138, -0.36191055,
        -0.01780383,  0.10470055, -0.2530624 ,  0.044648  , -0.06972083],
       [-0.00652711,  0.17066438, -0.03553644, -0.07859261, -0.4780447 ,
        -0.11430528,  0.17088383, -0.35727382,  0.01104823, -0.02751846],
       [ 0.01646123,  0.02834756, -0.15534458, -0.04221815, -0.17725925,
        -0.1818274 ,  0.0135299 , -0.20522466, -0.03495998, -0.04329636],
       [-0.17032023,  0.1409552 , -0.14403607, -0.03054993, -0.40012705,
        -0.09377775,  0.00878634, -0.12934831,  0.03538194,  0.02580342],
       [-0.1051839 ,  0.24074066, -0.08009566, -0.03220345, -0.39412522,
        -0.09787439,  0.25904325, -0.29801127,  0.029187  ,  0.14273585],
       [-0.02889195,  0.27769655,  0.14630178,  0.00704605, -0.42074952,
        -0.09446865,  0.22208358, -0.2960379 , -0.10665543,  0.00482226],
       [ 0.108211  ,  0.1608254 , -0.12860669, -0.15433891, -0.16320105,
        -0.29223463,  0.14248174, -0.1457768 ,  0.076525  , -0.10444155],
       [ 0.05384266,  0.02627722,  0.05595293, -0.05162864, -0.09512395,
         0.02082699,  0.14003542, -0.21606655, -0.02665446, -0.15942779],
       [ 0.08897467,  0.12156868,  0.0072432 , -0.03514582, -0.2993672 ,
        -0.16040912,  0.13929161, -0.25935876,  0.03195405, -0.17140916],
       [-0.05919135,  0.10460456, -0.0541361 , -0.05133465, -0.29787263,
        -0.01104981,  0.09286734, -0.18553476,  0.03485874,  0.02347505],
       [ 0.10552   ,  0.13782081, -0.01198097,  0.02319556, -0.23136751,
        -0.20539609,  0.30999058, -0.33470595, -0.05787981, -0.03471535],
       [ 0.01394254,  0.04610374, -0.1555309 ,  0.02138674, -0.15324359,
        -0.13208178,  0.13295804, -0.31510523, -0.04771263, -0.12409957],
       [ 0.08796158,  0.2096323 ,  0.03023741, -0.06917568, -0.18738158,
        -0.04232989,  0.24464765, -0.32475582, -0.0190682 , -0.12991205],
       [-0.02886308,  0.1945777 , -0.19298053, -0.09160493, -0.32743698,
        -0.11305106,  0.08422519, -0.17924672,  0.06171962,  0.02897899],
       [ 0.12523907,  0.19837332,  0.08683064,  0.08463505, -0.28432304,
        -0.16086362,  0.2658471 , -0.3375882 , -0.00523884, -0.11530196],
       [ 0.02348459,  0.06094605, -0.20900917,  0.08927577, -0.20602939,
        -0.09567806,  0.16529328, -0.22091651,  0.0310331 , -0.01366656],
       [-0.03577258,  0.16763109, -0.00228991, -0.03373875, -0.43727922,
        -0.23163418,  0.15620582, -0.30316573,  0.02358727,  0.02402775],
       [-0.01045046,  0.23059595, -0.10650987,  0.04214311, -0.35271657,
        -0.15367958,  0.18296245, -0.3670715 ,  0.00242524, -0.01590508],
       [ 0.11234806,  0.127929  , -0.1416234 ,  0.01856027, -0.1370542 ,
        -0.07529534,  0.15012704, -0.11457616,  0.05310739, -0.23884141],
       [-0.12933847,  0.18685465, -0.0281913 ,  0.03565511, -0.31908727,
        -0.02797761,  0.01718334, -0.21260841,  0.08538416,  0.10911464],
       [ 0.02774051,  0.14014125, -0.11116177,  0.06069104, -0.2972231 ,
        -0.08201168,  0.18919642, -0.2732424 , -0.07256175, -0.04372493],
       [-0.05285444,  0.09731005, -0.16486159,  0.14584875, -0.24468833,
        -0.11813442,  0.1891388 , -0.27349263,  0.02680733,  0.02208898],
       [ 0.01703629,  0.14432956, -0.07905609, -0.05359066, -0.2380548 ,
        -0.07374218,  0.18941419, -0.2671988 , -0.04146045, -0.01040378],
       [-0.01229454,  0.2238013 , -0.07782885, -0.05288902, -0.34084964,
        -0.07656305,  0.25703445, -0.38321817,  0.02506851,  0.0528542 ],
       [-0.09558831,  0.05508473, -0.20615673, -0.07739481, -0.35988048,
        -0.23042265,  0.05733413, -0.20256129, -0.04815403, -0.07377221],
       [-0.12882753,  0.13142166, -0.13448793,  0.07987095, -0.35866457,
        -0.03530126,  0.09230228, -0.17570223,  0.09357665,  0.06170238],
       [-0.01151971,  0.20554355, -0.04922439, -0.08847281, -0.3856749 ,
        -0.0984439 ,  0.13491108, -0.3456065 ,  0.08735792, -0.03551546],
       [ 0.14842089,  0.34935832,  0.09578269,  0.01213365, -0.28933874,
        -0.07850939,  0.19179463, -0.3372814 ,  0.12811774, -0.01675376],
       [-0.18260533,  0.04856708, -0.11119141, -0.07749753, -0.3635567 ,
        -0.23310214,  0.13667285, -0.14879341, -0.14041449,  0.07667582],
       [-0.17618516,  0.1462372 , -0.18028377,  0.06713124, -0.38335955,
        -0.03557798,  0.10410185, -0.17592236, -0.01435028,  0.15338154],
       [ 0.04055981,  0.24520999, -0.17271513, -0.14913839, -0.3042269 ,
        -0.12039592,  0.2689139 , -0.32811943, -0.00502907,  0.00336872],
       [ 0.03560546,  0.1754866 , -0.11660206, -0.03360491, -0.27283472,
        -0.18731177,  0.08661997, -0.09050013,  0.04968301, -0.07649884],
       [-0.02088141,  0.13798103, -0.18891963,  0.00372715, -0.16121043,
        -0.1467973 ,  0.05670452, -0.16789939,  0.09856127,  0.01790601]],
      dtype=float32)>,
  1: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-0.09833355,  0.15665436, -0.07166141,  0.01613978, -0.3543378 ,
        -0.08038074,  0.2748848 , -0.28265885, -0.04764143,  0.09244367],
       [-0.10978055,  0.0580076 , -0.16466689, -0.07054193, -0.37483492,
        -0.22544345, -0.01834334, -0.23438914, -0.05652276, -0.00428347],
       [ 0.06102126,  0.23393367, -0.12214032, -0.1585012 , -0.47405833,
        -0.18904763,  0.14417635, -0.31168947,  0.07823397, -0.17237163],
       [ 0.13491186, -0.09206396, -0.17387056, -0.02996402, -0.07794836,
        -0.10825613,  0.09467977, -0.12667504, -0.10682128, -0.2489583 ],
       [-0.03715318,  0.09756672, -0.13965347, -0.1194995 , -0.2671111 ,
        -0.21588892,  0.20645793, -0.16797486, -0.0715595 ,  0.10796719],
       [-0.02775428,  0.06620544, -0.07389922, -0.08922986, -0.24126638,
        -0.19932179,  0.09457048, -0.27572516,  0.00439627, -0.13424556],
       [-0.13065772,  0.04677206, -0.11908375, -0.03384325, -0.29424554,
        -0.08093098,  0.04305109, -0.12765141, -0.09444566,  0.03544688],
       [ 0.13585913,  0.06523642, -0.05614541, -0.17160474, -0.16298045,
        -0.24585429,  0.2457303 , -0.1337391 , -0.03644121, -0.06396304],
       [-0.00827954,  0.17156884, -0.08429323, -0.0831846 , -0.29328197,
        -0.13459298,  0.11706164, -0.19470227,  0.06820276, -0.13680328],
       [-0.08339091,  0.2388103 , -0.11374965,  0.10000415, -0.40983635,
        -0.10774978,  0.24996822, -0.33086625, -0.05999116,  0.12141761],
       [ 0.09205646, -0.01435722, -0.17090788, -0.05419794, -0.07356755,
        -0.03496896,  0.0213257 , -0.11919186,  0.00437786, -0.21572636],
       [ 0.02209954,  0.11327548, -0.08586521, -0.10829636, -0.1623912 ,
        -0.14285421,  0.08728215, -0.17823447,  0.10733861, -0.12263825],
       [-0.01373981,  0.22072753, -0.13233714,  0.04230879, -0.38872594,
        -0.08226185,  0.1575003 , -0.24028088,  0.07385449, -0.03566224],
       [ 0.03837798,  0.13289425,  0.04731251, -0.02162105, -0.20273682,
        -0.05832739,  0.18874544, -0.18923599, -0.00774726, -0.07874206],
       [-0.03777552, -0.00108435, -0.1691374 ,  0.07609938, -0.14852983,
        -0.0603545 ,  0.03725195, -0.07498097,  0.03552598,  0.09923848],
       [-0.00528672,  0.09970027, -0.18529828,  0.06165883, -0.18298753,
        -0.18982115,  0.13438234, -0.20037794,  0.04573667, -0.12467755],
       [ 0.02150192,  0.19947992,  0.07855526, -0.06914394, -0.3714067 ,
        -0.11594264,  0.09797562, -0.26935732, -0.02487007,  0.04328841],
       [-0.02913332,  0.06816912, -0.15824044, -0.09580886, -0.27435338,
        -0.27287874,  0.03502546, -0.24163257, -0.0320179 , -0.01755684],
       [-0.09645626,  0.26521719, -0.02974376, -0.11811897, -0.26733863,
        -0.11098027,  0.17992583, -0.17170358,  0.13684484,  0.11146528],
       [ 0.02255949,  0.17668228, -0.18393126,  0.03314231, -0.24813469,
        -0.15481868,  0.09138933, -0.20842537,  0.14024018, -0.09168213],
       [ 0.03577811,  0.20097674,  0.01054525,  0.0281911 , -0.37472528,
        -0.01357181,  0.06443338, -0.3450846 , -0.0717351 , -0.05323324],
       [-0.07593666,  0.01255987, -0.22572437, -0.00785737, -0.2546096 ,
        -0.07944219,  0.1108809 , -0.24050121,  0.04689084, -0.06361024],
       [-0.06727166,  0.12973833, -0.13008472,  0.05364395, -0.40278557,
        -0.05354337,  0.15421708, -0.19313724,  0.034807  ,  0.0324662 ],
       [ 0.01260601,  0.17873257, -0.04235746,  0.04865025, -0.3416106 ,
        -0.18992788,  0.20156004, -0.35417786, -0.02896209, -0.06149546],
       [-0.21761686,  0.04707823, -0.10519136, -0.02330001, -0.44696686,
        -0.06795169, -0.0549196 , -0.07158573,  0.02182202,  0.01991212],
       [-0.04528661,  0.11444949, -0.05191335, -0.00863688, -0.3279702 ,
        -0.08010425,  0.02780636, -0.18188381,  0.01212155,  0.02914584],
       [-0.07794154,  0.18604933, -0.08827695,  0.00928481, -0.43883204,
        -0.08850133,  0.17446704, -0.2918891 ,  0.07034804,  0.02960903],
       [-0.16039802,  0.06828902,  0.04284256,  0.0398021 , -0.3121784 ,
        -0.09335808,  0.03558116, -0.11442499, -0.02113784,  0.20230052],
       [ 0.08210264, -0.05593535, -0.23607862, -0.01597648, -0.14844231,
        -0.1827047 ,  0.0761399 , -0.16778287, -0.07931188, -0.19862217],
       [ 0.12508929, -0.06630521, -0.17870858, -0.00659649, -0.0866005 ,
        -0.1466447 ,  0.11984183, -0.11472597, -0.07381313, -0.25458795],
       [-0.16582993,  0.12019867, -0.12963267,  0.01781088, -0.3867995 ,
        -0.0929767 ,  0.05690374, -0.18877956, -0.09080505,  0.12845796],
       [-0.04166535,  0.11107228, -0.13034783,  0.11846082, -0.2223956 ,
        -0.04571397,  0.19896089, -0.21065699,  0.02911645,  0.0283865 ],
       [ 0.05492067,  0.14563766, -0.05812249, -0.11080064, -0.15607221,
        -0.14377254,  0.11191408, -0.03224962,  0.08461984, -0.10316228],
       [-0.10184978,  0.1774663 , -0.04064796, -0.00778899, -0.24548785,
        -0.04007315,  0.15378167, -0.16978785,  0.08297548,  0.15654306],
       [ 0.03246473, -0.00912588, -0.13997413,  0.01901177, -0.15169062,
        -0.1233996 ,  0.0399372 , -0.14586736, -0.04990026, -0.05761716],
       [-0.04208291,  0.06552696, -0.1491316 , -0.05165012, -0.30772424,
        -0.14849369,  0.05366952, -0.20017433, -0.00489577, -0.05179733],
       [ 0.07999136,  0.1268666 , -0.05649832, -0.11447802, -0.093622  ,
        -0.16343006,  0.18737765, -0.15467761,  0.02121065, -0.18551901],
       [-0.09349701,  0.15481542,  0.02054488, -0.01257675, -0.45885608,
        -0.07252534,  0.11548893, -0.25998053, -0.15728104,  0.13945304],
       [-0.12011181,  0.08322111, -0.13839091,  0.01295412, -0.31343067,
        -0.12020873,  0.10251326, -0.23772089, -0.07965665,  0.12370199],
       [ 0.04355327,  0.24670643, -0.10936423, -0.11540431, -0.47812563,
        -0.11112159,  0.171045  , -0.27362657,  0.05597766, -0.08047073],
       [-0.07981615,  0.1407882 , -0.05179737,  0.11515354, -0.39608997,
         0.12418469,  0.12287964, -0.35551202,  0.05395594,  0.106438  ],
       [ 0.0495141 ,  0.16530135, -0.07892615, -0.04727711, -0.17437093,
        -0.10992847,  0.11209463, -0.17067385,  0.15505184, -0.04079971],
       [ 0.0399215 ,  0.21293016, -0.0784945 , -0.05903067, -0.4351067 ,
        -0.0897992 ,  0.19701788, -0.29322943, -0.02085371, -0.08934651],
       [-0.03191719,  0.09556475, -0.02399754, -0.07302108, -0.31528878,
        -0.01236389,  0.07309026, -0.25867784, -0.02151406,  0.04411262],
       [ 0.1144869 , -0.07487808, -0.09897751, -0.03669246, -0.07122576,
        -0.06312343,  0.0934795 , -0.08661507, -0.06738903, -0.19981778],
       [ 0.05407554,  0.29886216, -0.04840349, -0.08312691, -0.40975356,
        -0.30462724,  0.18296465, -0.36522552, -0.01246376, -0.03964518],
       [ 0.0863948 ,  0.27042687,  0.07862431, -0.07175243, -0.23997524,
        -0.19032082,  0.24878645, -0.2724223 ,  0.04755342, -0.08436754],
       [ 0.01792005,  0.1303249 ,  0.00751538, -0.04509098, -0.24403217,
        -0.10693009,  0.12554878, -0.22423464, -0.01941013, -0.05743247],
       [ 0.10427958,  0.04477797, -0.21033043, -0.0367108 , -0.23080736,
        -0.14509909,  0.09560896, -0.15382816,  0.05470648, -0.18330926],
       [ 0.06824732, -0.00402825, -0.15888713, -0.07463223, -0.16558173,
        -0.08697934,  0.0755794 , -0.16385713, -0.01368479, -0.2683872 ],
       [-0.18796708,  0.08231901, -0.10495865, -0.02746309, -0.4590808 ,
        -0.09270471,  0.11427653, -0.18020634, -0.08563931,  0.114185  ],
       [-0.09791254,  0.03529916, -0.01679311,  0.04053622, -0.28840482,
        -0.05182896,  0.09699607, -0.0433149 , -0.04200018, -0.01914415],
       [-0.11433886,  0.13621204, -0.14640698,  0.09180373, -0.340676  ,
        -0.04544587,  0.16125873, -0.22271709, -0.01954299,  0.10093789],
       [-0.10057668,  0.16380787, -0.11834459, -0.00495677, -0.43168932,
        -0.06923562,  0.10913345, -0.17641947, -0.02874756,  0.06140285],
       [-0.06116604,  0.13700888, -0.03211374, -0.01969503, -0.37303066,
        -0.01124507,  0.18596011, -0.20456824, -0.04335499,  0.06383611],
       [-0.06899989,  0.13957006, -0.15314595,  0.00410667, -0.41311985,
        -0.06501595,  0.21060655, -0.31016892, -0.09425378,  0.05115268],
       [ 0.00838745,  0.1184913 , -0.14133814,  0.11310323, -0.19835803,
        -0.11010459,  0.20528942, -0.3756318 ,  0.00677975, -0.09190771],
       [-0.01702203,  0.23581095, -0.11833926, -0.05140661, -0.3833909 ,
        -0.2917281 ,  0.15348044, -0.2923426 , -0.01539594, -0.07770756],
       [ 0.00376625,  0.14644173, -0.22676727, -0.00695115, -0.25346804,
        -0.08446033,  0.15793926, -0.28682148,  0.06036796, -0.05657917],
       [-0.15886536,  0.22905974, -0.08670876,  0.02516738, -0.44418338,
        -0.05140979,  0.12132007, -0.24890812, -0.12002567,  0.17170556],
       [-0.00594457,  0.1702473 , -0.04271276,  0.03786185, -0.33728057,
        -0.10691148,  0.14885393, -0.2579561 ,  0.03387661, -0.12696798],
       [ 0.12660183,  0.0719447 , -0.0284185 , -0.01238126, -0.1789788 ,
        -0.14557405,  0.11981218, -0.27433586, -0.04192446, -0.15991354],
       [ 0.01714662,  0.07349592, -0.13597934,  0.07258661, -0.17497018,
        -0.1185775 ,  0.12789062, -0.21281385,  0.0407939 , -0.12549406],
       [-0.05560762,  0.09808804, -0.03752606,  0.03666614, -0.34828562,
        -0.08055064,  0.14122654, -0.35993794, -0.03219686, -0.07910113]],
      dtype=float32)>,
  2: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 0.1058839 ,  0.15189251, -0.06218533, -0.07259189, -0.21885905,
        -0.18747613,  0.19386706, -0.20872134, -0.06631209, -0.17281003],
       [-0.06724259,  0.16423771, -0.09042411,  0.06249507, -0.35925558,
        -0.08052117,  0.06635582, -0.21773912, -0.02575091, -0.02225357],
       [-0.12099646,  0.18426058, -0.14509366,  0.03074409, -0.43172374,
        -0.14523304,  0.11278548, -0.263721  ,  0.08033177,  0.11983414],
       [-0.02189111,  0.19432685, -0.06076544, -0.11341121, -0.31391418,
        -0.13985458,  0.01988082, -0.27448916, -0.03138182, -0.09098685],
       [ 0.05217094,  0.10092196,  0.04062135,  0.00210309, -0.2396484 ,
        -0.09674437,  0.11046519, -0.24062449, -0.05751724, -0.1498842 ],
       [ 0.05167808,  0.17948417,  0.02552646, -0.15288836, -0.2610386 ,
        -0.1887011 ,  0.2494685 , -0.25246763, -0.00140664, -0.05116289],
       [-0.06461833,  0.18794754,  0.01723582,  0.01763999, -0.27383786,
         0.01134103,  0.00840673, -0.19577272,  0.04369798,  0.02309319],
       [ 0.1343988 ,  0.02689996, -0.09034813, -0.14479995, -0.11877482,
        -0.20683356,  0.1408261 , -0.07485476, -0.00536738, -0.10480646],
       [ 0.14985448,  0.01248868, -0.18233272, -0.04805535, -0.12569216,
        -0.10754709,  0.10563765, -0.14238729, -0.04125558, -0.30692846],
       [ 0.0174889 ,  0.02731286, -0.09870998, -0.01109812, -0.22202691,
        -0.15064979,  0.0707394 , -0.16593488, -0.04917879, -0.09892345],
       [-0.02009851,  0.1603423 , -0.03657382, -0.04635969, -0.376585  ,
        -0.15620278,  0.09772854, -0.331561  , -0.08288915,  0.03283934],
       [-0.12873377,  0.18706101, -0.08856387,  0.01241396, -0.33643866,
        -0.04564854,  0.19829091, -0.14762393,  0.05415195,  0.09738437],
       [ 0.02919928,  0.2304655 , -0.118172  , -0.02557866, -0.20723924,
        -0.15870155,  0.10659584, -0.29492664,  0.08007928, -0.02028007],
       [ 0.00755258,  0.13715193, -0.04229313, -0.02631894, -0.3507014 ,
        -0.09775324,  0.08749582, -0.28658673, -0.06305373, -0.06657201],
       [ 0.01293204,  0.34690908,  0.04428834,  0.06181278, -0.39632797,
        -0.16290376,  0.3403181 , -0.3575364 , -0.04994334, -0.04745176],
       [-0.1721677 ,  0.15409604,  0.01526155,  0.00422953, -0.40304124,
        -0.20704678,  0.18548846, -0.1347191 , -0.1029956 ,  0.12649345],
       [ 0.00416036,  0.0911807 , -0.14212427, -0.02980888, -0.41507924,
        -0.08092883,  0.09573203, -0.3632136 , -0.05752493, -0.10265536],
       [ 0.08953582,  0.1465146 , -0.0943587 ,  0.01540706, -0.1591577 ,
        -0.10180154,  0.26632407, -0.14738992,  0.05417037, -0.19402057],
       [ 0.09650303,  0.1520698 , -0.12804723, -0.00549375, -0.13437587,
        -0.10345934,  0.11017533, -0.07745089,  0.07155803, -0.18119216],
       [-0.11120163,  0.05948977, -0.12460698,  0.07350107, -0.2501851 ,
        -0.10606505,  0.03863706, -0.19143325,  0.05141674,  0.05777346],
       [-0.06015262,  0.1133145 ,  0.02975767,  0.10413057, -0.4479565 ,
        -0.19103777,  0.16624743, -0.23277   , -0.00409902,  0.16262262],
       [-0.02920761,  0.05220562, -0.01984085, -0.09213866, -0.1952639 ,
        -0.2091724 ,  0.09540902, -0.25029248, -0.01369284, -0.07937889],
       [ 0.08655809,  0.04477531, -0.14414985, -0.13330539, -0.14485738,
        -0.25247657,  0.161552  , -0.10539306, -0.07818886, -0.22021733],
       [ 0.05988639,  0.09053953, -0.20782766,  0.0799824 , -0.36067218,
        -0.11564009,  0.19533472, -0.17254165, -0.03013015, -0.05901307],
       [ 0.08023848,  0.24398087, -0.03037388,  0.03198615, -0.257546  ,
        -0.0717455 ,  0.11799999, -0.2501925 ,  0.08931698,  0.01646613],
       [ 0.01829973,  0.14714168,  0.08777735, -0.04686411, -0.3246078 ,
        -0.11189084,  0.16163842, -0.27373928, -0.05865945, -0.12625104],
       [ 0.01551664,  0.15430057, -0.06801915,  0.04362161, -0.23502748,
        -0.10225095,  0.20541084, -0.3872087 , -0.03092569, -0.10081216],
       [-0.0215794 ,  0.2048584 , -0.1525181 , -0.07231435, -0.35670137,
        -0.20132679,  0.16621551, -0.23583391,  0.03286067,  0.00162398],
       [-0.16785663,  0.14515236, -0.17006704, -0.07688057, -0.33053195,
        -0.30842096,  0.08084383, -0.16129194, -0.01640936,  0.00512123],
       [-0.01012924,  0.18210989, -0.11657718,  0.01039248, -0.3604356 ,
        -0.06368212,  0.17732035, -0.32601178,  0.00564714,  0.079637  ],
       [-0.02081634,  0.1340083 , -0.04119257, -0.11712583, -0.24917075,
        -0.14567725,  0.11910035, -0.19569927, -0.0334375 , -0.02956842],
       [ 0.01040663, -0.03312722, -0.16251259, -0.06878473, -0.29412496,
        -0.12144031,  0.05180205, -0.18370074, -0.08275174, -0.17228474],
       [ 0.16995256,  0.15084882,  0.06516901, -0.12424685, -0.28156468,
        -0.12124009,  0.1919595 , -0.3235635 , -0.16550806, -0.2237745 ],
       [ 0.14201358,  0.29564446, -0.05666097, -0.11691651, -0.38729382,
        -0.12142777,  0.18535727, -0.4062221 , -0.02861272, -0.08876766],
       [-0.13636869,  0.14419225, -0.12272403, -0.12043885, -0.25843346,
        -0.2624993 ,  0.11443155, -0.10935706,  0.07051769, -0.07280853],
       [ 0.1506514 ,  0.231885  ,  0.05595114, -0.00167914, -0.20262057,
        -0.12446204,  0.20270723, -0.24610087,  0.13529469, -0.02456838],
       [ 0.05419209,  0.14951019, -0.07615753, -0.01649981, -0.21320246,
        -0.10147403,  0.21325038, -0.3668218 , -0.06149146, -0.04654743],
       [-0.106455  ,  0.21365353, -0.0212086 ,  0.00301608, -0.42738837,
        -0.03399226,  0.21012497, -0.27643734, -0.04460012,  0.09438403],
       [-0.20142697,  0.2490167 , -0.0610823 ,  0.03178958, -0.46005037,
        -0.11769256,  0.12650084, -0.11382679,  0.05542684,  0.20377554],
       [ 0.06785811,  0.20105124, -0.06821625, -0.07295625, -0.3634079 ,
        -0.14018053,  0.1070195 , -0.32426855, -0.02956013, -0.08126846],
       [-0.16207823,  0.10067027, -0.17504756,  0.05559444, -0.41034544,
        -0.10927348,  0.10540979, -0.1979583 , -0.03317957,  0.13901047],
       [-0.06511355,  0.12664205,  0.01013242, -0.1038757 , -0.25350553,
        -0.07925887, -0.01203869, -0.04564004,  0.0292751 ,  0.04884695],
       [ 0.03302544,  0.25231293,  0.02099521,  0.08474679, -0.34921703,
        -0.09708492,  0.11735824, -0.3118752 , -0.04610436, -0.00284932],
       [-0.05722202,  0.23294924,  0.06066301,  0.0403684 , -0.30209425,
        -0.11734634,  0.0649392 , -0.22963187, -0.00225438,  0.09428503],
       [ 0.03891737,  0.2118478 , -0.02974258, -0.0014833 , -0.41377792,
        -0.10676192,  0.09491015, -0.30824155, -0.05244756, -0.13116692],
       [ 0.0387698 ,  0.12415141, -0.03787622, -0.01751001, -0.23610467,
        -0.101076  ,  0.10127868, -0.21973278,  0.10053106, -0.11506974],
       [-0.16079637,  0.12975846, -0.05280735,  0.0171553 , -0.31617314,
        -0.05437503,  0.11548977, -0.22737938,  0.0040158 ,  0.1828981 ],
       [-0.00830905,  0.06583936,  0.05894554, -0.06423548, -0.25190187,
        -0.16529903,  0.11135682, -0.20058249, -0.12155165,  0.0516177 ],
       [ 0.11394778,  0.13319941, -0.14848323, -0.00475912, -0.14136326,
        -0.08358698,  0.11641169, -0.11600958,  0.04616823, -0.24377736],
       [-0.04899902,  0.02212415, -0.07927552,  0.03247839, -0.28332067,
        -0.05081623,  0.10070023, -0.18853599, -0.04962738, -0.07133716],
       [ 0.06972719,  0.16070195, -0.06629859,  0.02116043, -0.2982213 ,
        -0.06711708,  0.15437523, -0.31704813,  0.09814174,  0.02899931],
       [ 0.10892028,  0.06948914,  0.00259475, -0.14409684, -0.05081703,
        -0.10783089,  0.11664214, -0.1612539 ,  0.03825954, -0.0974022 ],
       [ 0.11645016,  0.09103441, -0.08829162, -0.03281923, -0.11439721,
        -0.04074144,  0.13266224, -0.16234764,  0.0108225 , -0.1980096 ],
       [ 0.04067405,  0.28472447,  0.06697901,  0.00244627, -0.44801572,
        -0.17432919,  0.14524695, -0.33887076, -0.04374426, -0.01839055],
       [-0.07228023,  0.0094237 , -0.0114438 ,  0.05621499, -0.1906197 ,
        -0.10239108,  0.1188737 , -0.06362449, -0.04171722, -0.04671538],
       [ 0.11314879,  0.14128837, -0.10640095, -0.00203198, -0.1638219 ,
        -0.14681938,  0.15794587, -0.13135739, -0.00294857, -0.2193684 ],
       [-0.189158  ,  0.14615512, -0.1951129 ,  0.01067956, -0.46317485,
        -0.12231541,  0.17576897, -0.13713211,  0.05281383,  0.05961443],
       [ 0.13483778,  0.15122853, -0.06812572,  0.02223442, -0.31159732,
        -0.14329562,  0.13942383, -0.3621065 ,  0.0194601 , -0.17312187],
       [-0.07476285,  0.056327  , -0.1509085 ,  0.12273326, -0.25021672,
        -0.09043095,  0.14668559, -0.15018652, -0.03156234,  0.02678441],
       [-0.00460077,  0.1331944 , -0.14329152, -0.05613213, -0.3445416 ,
        -0.16092777,  0.1215454 , -0.25038692,  0.04866181, -0.1493005 ],
       [ 0.01425789,  0.12178193,  0.04116632, -0.08268491, -0.24872722,
        -0.13918278,  0.05446562, -0.19651294,  0.02265346, -0.06737643],
       [-0.01926323,  0.1012551 , -0.05458277,  0.046477  , -0.2919128 ,
        -0.13312258,  0.08117634, -0.38422066, -0.01930849,  0.00814056],
       [ 0.11184652,  0.30679387,  0.09332582, -0.01897855, -0.2433966 ,
        -0.15644044,  0.19370434, -0.18590672,  0.05792751, -0.14004235],
       [-0.01329161,  0.13452995, -0.22033097,  0.07828833, -0.29261756,
        -0.07592015,  0.13703327, -0.16651678,  0.01105583, -0.0150648 ]],
      dtype=float32)>,
  3: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 2.51155347e-03,  1.45447791e-01,  6.41264021e-04,
         1.91463828e-02, -1.85839653e-01, -1.03563413e-01,
         6.61028773e-02, -1.90979049e-01,  6.06872290e-02,
        -1.41341135e-01],
       [-4.54837829e-02,  1.08597815e-01, -1.17722288e-01,
         1.42128780e-01, -2.34650105e-01, -1.13956973e-01,
         1.98587418e-01, -2.94802964e-01, -1.37891099e-02,
         4.00971621e-04],
       [-1.90911219e-02,  1.18707336e-01, -3.10591292e-02,
         4.01906706e-02, -3.41950208e-01, -2.94433124e-02,
         1.32617578e-01, -2.99225211e-01, -3.27008069e-02,
        -1.80488713e-02],
       [-4.05668318e-02,  1.91518039e-01, -2.78929994e-02,
         6.00643232e-02, -1.37896061e-01, -9.06022638e-02,
         1.69026867e-01, -3.06107998e-01,  7.56823719e-02,
         6.13778904e-02],
       [-5.93572706e-02,  1.76578969e-01, -1.66006207e-01,
         5.42817116e-02, -3.62876147e-01, -1.09547049e-01,
         1.69198722e-01, -2.94477433e-01,  1.92826986e-02,
         1.70004927e-03],
       [-1.40251517e-02,  1.31867789e-02, -3.96135375e-02,
        -9.33120027e-02, -2.40361691e-01, -1.00401580e-01,
         1.22912303e-02, -1.80318967e-01, -8.82146955e-02,
        -1.93105519e-01],
       [ 1.16194248e-01,  1.79252625e-02, -1.44289330e-01,
        -4.96181250e-02, -1.14380077e-01, -9.49703977e-02,
         9.81448591e-02, -1.50415093e-01, -3.23846750e-02,
        -2.56290972e-01],
       [-1.55512661e-01,  1.78086460e-01, -3.19233388e-02,
         2.66162679e-02, -4.40953791e-01,  8.28160048e-02,
         9.33465362e-02, -2.50108629e-01,  3.17218527e-02,
         1.32439479e-01],
       [-6.79045916e-02,  1.01774186e-01, -1.06924191e-01,
        -3.54708657e-02, -2.07453206e-01, -3.10046911e-01,
         1.38181359e-01, -2.24454358e-01, -5.50230630e-02,
        -5.98306395e-02],
       [-7.62961432e-02,  1.54165313e-01,  2.80987248e-02,
        -1.35838151e-01, -3.44427884e-01, -1.93413466e-01,
         7.58136064e-02, -1.11874819e-01,  1.76186338e-02,
         1.20480917e-02],
       [ 1.22831598e-01,  1.30915672e-01,  7.43995085e-02,
        -7.53342807e-02, -2.06910223e-01, -1.84539944e-01,
         2.61359870e-01, -2.82187372e-01, -6.73582032e-03,
         1.24582648e-03],
       [ 2.05421969e-02,  2.17275023e-01, -1.98052749e-01,
        -7.11035356e-03, -3.69731605e-01, -1.85717657e-01,
         4.95720804e-02, -1.15070224e-01,  9.22817066e-02,
        -8.54254588e-02],
       [ 6.88252151e-02,  1.44817978e-01, -1.33312225e-01,
        -5.94893545e-02, -2.51493931e-01, -1.20750599e-01,
         1.42469034e-01, -3.34877908e-01, -1.80790052e-02,
        -2.16673821e-01],
       [-1.82654113e-01,  1.76363677e-01, -1.04193613e-01,
        -1.50571056e-02, -5.33530712e-01, -6.55154735e-02,
         2.52551317e-01, -2.92810172e-01,  2.38716491e-02,
         1.04600534e-01],
       [ 2.09688932e-01,  6.80498332e-02,  4.03034799e-02,
        -6.70600384e-02, -9.22301486e-02, -3.91195267e-02,
         2.45474190e-01, -1.92359731e-01,  1.75092369e-03,
        -1.95264012e-01],
       [-2.23329328e-02,  1.45669162e-01, -1.52139455e-01,
         5.34053110e-02, -2.56344557e-01, -9.23609883e-02,
         1.19958594e-01, -2.34045699e-01,  3.79388183e-02,
        -7.73888920e-03],
       [-1.97878480e-02,  1.73301339e-01, -9.23557431e-02,
         8.06293264e-02, -3.48612636e-01, -4.22735885e-03,
         2.06312776e-01, -3.54612112e-01,  4.30748984e-02,
         2.62675956e-02],
       [ 2.90626474e-03,  1.41105443e-01, -9.56210792e-02,
         2.42822953e-02, -1.43771321e-01, -5.69432080e-02,
         1.31530389e-01, -1.60310119e-01,  8.27911496e-02,
        -3.81680019e-02],
       [ 9.71606374e-03,  9.51089263e-02, -1.06061004e-01,
        -3.78770605e-02, -4.01138932e-01, -1.43262267e-01,
         1.28680661e-01, -2.61636645e-01, -1.15659922e-01,
        -1.11245878e-01],
       [ 5.21949939e-02,  1.50751829e-01,  9.35174376e-02,
         3.17564085e-02, -3.50595593e-01, -1.38468161e-01,
         1.85263827e-01, -3.03996593e-01, -5.42629510e-03,
         7.91802928e-02],
       [ 1.37041256e-01,  1.69497833e-01, -8.63870382e-02,
         1.57636702e-02, -1.79991171e-01, -9.08993408e-02,
         1.56794131e-01, -1.47339672e-01,  2.86077783e-02,
        -2.16370448e-01],
       [-2.69663520e-03,  2.21826449e-01,  5.43841496e-02,
        -1.17137842e-01, -2.97215730e-01, -8.82113725e-02,
         1.05724975e-01, -2.20105618e-01,  4.52305302e-02,
         3.74987908e-03],
       [-9.54074338e-02,  9.03669074e-02, -1.19427562e-01,
        -3.28763574e-03, -2.99989522e-01, -4.01741937e-02,
         1.19693980e-01, -3.00709546e-01, -1.82122067e-02,
         6.16809502e-02],
       [-9.99720395e-02,  1.39634579e-01, -2.00303137e-01,
         6.65970296e-02, -3.51577580e-01, -9.84932706e-02,
         9.02529806e-02, -2.44839266e-01,  1.13900602e-02,
         6.42385334e-02],
       [ 1.44478194e-02,  1.83509141e-01, -6.67656586e-02,
        -4.61032502e-02, -4.40125525e-01, -7.73666874e-02,
         3.11786264e-01, -4.44831997e-01, -2.93702632e-02,
         4.22365218e-03],
       [ 1.32541731e-01,  1.82956174e-01, -1.04927093e-01,
         1.22469440e-02, -1.27488539e-01, -6.61835223e-02,
         1.12031206e-01, -9.84050632e-02,  9.86539871e-02,
        -2.05507338e-01],
       [-6.50153607e-02,  1.74549997e-01, -1.71862602e-01,
         6.14198335e-02, -4.02740180e-01, -4.20516059e-02,
         2.40440428e-01, -2.76678443e-01, -5.84294945e-02,
         1.08706802e-01],
       [ 7.48759061e-02,  2.10499942e-01,  7.57246763e-02,
        -3.62860188e-02, -3.67509723e-01, -7.15096444e-02,
         8.84700716e-02, -3.12674284e-01, -9.51475725e-02,
        -5.33458181e-02],
       [-7.26057142e-02,  9.15311202e-02, -4.78885621e-02,
        -1.29553005e-02, -2.53760219e-01, -1.62764773e-01,
         7.15883225e-02, -1.82869673e-01, -7.55081400e-02,
        -3.04650366e-02],
       [-1.08026592e-02,  1.12620339e-01,  1.36063978e-01,
         7.70925432e-02, -1.93178192e-01,  1.07710674e-01,
         1.66693091e-01, -2.04272971e-01, -4.38663363e-03,
         6.33232668e-02],
       [ 1.56896949e-01,  8.96635503e-02,  4.47561368e-02,
        -7.04529062e-02, -6.03420623e-02, -5.07256314e-02,
         1.84834778e-01, -2.29554027e-01,  2.17712931e-02,
        -1.52052462e-01],
       [-1.15724243e-01,  8.04387107e-02, -5.88226505e-02,
         8.92069191e-04, -2.94122517e-01, -1.87400073e-01,
         2.21029714e-01, -1.70096382e-01,  2.51032859e-02,
         1.31270349e-01],
       [ 1.54940831e-02,  1.65156215e-01,  4.22797874e-02,
        -3.61752585e-02, -2.75962710e-01, -1.24710202e-01,
         8.39878917e-02, -2.54829884e-01,  1.27165876e-02,
         3.31448913e-02],
       [ 2.08573253e-03,  1.30548522e-01,  1.03515625e-01,
        -5.84256947e-02, -2.24049121e-01, -8.96962434e-02,
         2.11826205e-01, -2.76590139e-01, -5.25184534e-02,
         5.65478951e-02],
       [ 1.21045813e-01,  3.13128859e-01,  9.14186686e-02,
        -5.13909534e-02, -2.90747494e-01, -1.41330570e-01,
         1.61255151e-01, -4.09655958e-01,  3.58102247e-02,
         2.64274161e-02],
       [ 1.41505450e-02,  8.61316472e-02, -1.45660698e-01,
         5.86940274e-02, -3.43586266e-01, -8.44506472e-02,
         1.70025736e-01, -2.38322318e-01, -4.34658676e-03,
        -2.87827849e-03],
       [ 7.29348212e-02,  1.93091661e-01, -1.56394571e-01,
         4.55366895e-02, -3.35400045e-01, -1.32662982e-01,
         7.06198215e-02, -1.37381166e-01,  8.84873420e-02,
        -1.14571638e-01],
       [-3.92792933e-02,  1.48667991e-01, -9.71383378e-02,
        -6.74230531e-02, -3.02699536e-01, -2.55341195e-02,
         1.27064362e-01, -2.09095284e-01,  1.01266101e-01,
         1.40049338e-01],
       [ 1.23310566e-01,  7.76651576e-02, -5.40870428e-03,
         4.73390855e-02, -2.56887197e-01,  1.34499185e-02,
         1.85226709e-01, -2.55586505e-01,  4.35908213e-02,
        -3.50753143e-02],
       [-8.64604563e-02,  2.31015921e-01, -5.34078926e-02,
        -5.69486693e-02, -4.90515888e-01, -1.08600266e-01,
         1.19341165e-01, -3.48818719e-01, -1.02463067e-02,
         5.82626238e-02],
       [ 2.33721081e-02,  8.51243734e-04, -1.55263841e-01,
        -4.45500016e-02, -2.12267876e-01, -1.53808981e-01,
         6.14735037e-02, -1.64876774e-01, -1.04674906e-01,
        -1.51574746e-01],
       [-1.39495209e-01,  1.52735859e-01, -1.02774106e-01,
         3.62434834e-02, -3.99802774e-01, -4.73229587e-02,
         2.26968199e-01, -2.60911047e-01, -4.54405621e-02,
         1.17595874e-01],
       [ 1.58262476e-02,  8.54684860e-02, -1.06508911e-01,
         1.02037542e-01, -3.03655088e-01, -7.41694570e-02,
         2.05974013e-01, -2.92493582e-01, -1.05211154e-01,
         5.25907800e-03],
       [-8.26321095e-02,  1.80413112e-01, -1.01069041e-01,
        -3.26163694e-02, -4.50304151e-01, -1.06844909e-01,
         5.95432222e-02, -2.05993459e-01,  8.68424997e-02,
         9.42370202e-03],
       [ 4.29255553e-02,  1.63784251e-01, -2.18889564e-02,
         6.25932515e-02, -1.61837101e-01, -2.08140716e-01,
         2.36108944e-01, -2.10580200e-01, -8.19268376e-02,
        -8.67268592e-02],
       [ 1.06723763e-01,  1.73626572e-01, -7.36150444e-02,
        -3.47835794e-02, -2.27102175e-01, -2.34541237e-01,
         1.46386385e-01, -2.80957967e-01, -4.40967083e-03,
        -5.37140332e-02],
       [ 2.75464915e-02,  1.41459256e-01, -3.87002230e-02,
         2.73079798e-02, -3.48300725e-01, -1.21690288e-01,
         2.27318466e-01, -3.45179290e-01,  2.37728134e-02,
        -4.03352082e-02],
       [-1.76691502e-01,  1.37643158e-01, -5.67638800e-02,
        -5.56118786e-04, -3.88640523e-01, -1.25400782e-01,
         2.01528460e-01, -2.48365819e-01, -1.57399729e-01,
         1.01713941e-01],
       [-1.59693748e-01,  2.09404677e-01, -1.04986534e-01,
        -1.48359641e-01, -2.98504502e-01, -2.92371094e-01,
         5.03033698e-02, -1.31363750e-01,  1.03016965e-01,
         2.24144422e-02],
       [ 5.60386404e-02,  1.35745630e-01,  3.90878096e-02,
         7.47692585e-02, -2.68035203e-01, -7.92922825e-02,
         2.57956684e-01, -3.78247499e-01, -6.12850487e-03,
        -1.08967513e-01],
       [-3.56525183e-03,  1.20683648e-01,  7.16836527e-02,
        -4.75444421e-02, -3.20046216e-01, -1.89872205e-01,
         1.36546373e-01, -2.40315855e-01, -8.23161006e-02,
         9.24457610e-02],
       [ 8.29515085e-02, -1.42475218e-03, -1.37439638e-01,
        -7.92527385e-03, -9.10671651e-02, -1.04378976e-01,
         9.12457854e-02, -1.41347721e-01, -5.19168749e-02,
        -2.03986824e-01],
       [ 2.84444056e-02,  1.29459754e-01, -1.53349116e-02,
        -8.56753960e-02, -3.15394014e-01, -1.11945629e-01,
         6.06864467e-02, -3.40235710e-01, -3.83457541e-02,
        -9.85145792e-02],
       [ 1.28373414e-01, -4.11517769e-02, -1.60181478e-01,
        -7.12700635e-02, -1.03279904e-01, -7.15017915e-02,
         7.57745504e-02, -1.26075238e-01, -4.52508181e-02,
        -2.89755970e-01],
       [-6.51197582e-02,  2.04959273e-01, -1.04472905e-01,
         3.50410938e-02, -2.49693528e-01, -1.15549982e-01,
         3.61202434e-02, -1.73137397e-01, -1.81247666e-02,
        -3.57590988e-03],
       [-5.83130457e-02,  1.40513688e-01, -1.33702904e-01,
         2.82108374e-02, -4.03776646e-01, -9.68370736e-02,
         1.41329780e-01, -1.68025374e-01,  3.46826874e-02,
         2.28778720e-02],
       [-3.75167355e-02,  8.63219798e-02, -1.85671210e-01,
         2.99670994e-02, -2.33271241e-01, -6.28029704e-02,
         1.62515551e-01, -1.93014055e-01,  6.13552369e-02,
        -1.79618783e-02],
       [ 3.00599448e-02,  1.02386631e-01,  8.53579044e-02,
         1.64787471e-03, -1.77172825e-01, -1.43747658e-01,
         5.66788465e-02, -1.46428332e-01, -1.47961304e-02,
        -4.46532555e-02],
       [-1.10645309e-01,  1.65324107e-01,  4.55538183e-03,
         2.99948044e-02, -3.46696526e-01,  1.56221874e-02,
         3.43421698e-02, -1.53912872e-01, -1.63781494e-02,
         5.61783984e-02],
       [-5.74554652e-02,  1.73873678e-01,  2.39674598e-02,
         5.73134795e-02, -3.39531004e-01, -2.35163961e-02,
         8.80995244e-02, -1.96591303e-01,  4.05878499e-02,
         3.93691286e-02],
       [-6.33653328e-02,  1.92787185e-01, -1.58524215e-01,
        -2.15648673e-03, -3.60980183e-01, -1.75777331e-01,
         1.52091727e-01, -2.15706527e-01,  8.36684853e-02,
         3.28490287e-02],
       [-1.28059506e-01,  1.10793650e-01, -1.03594311e-01,
         1.62330829e-02, -3.45485270e-01, -1.49970520e-02,
         1.38308808e-01, -2.61050552e-01,  3.11166197e-02,
         2.95418501e-02],
       [-1.12311147e-01,  2.21623197e-01,  4.02755402e-02,
         1.26273036e-01, -3.69358778e-01, -3.77379954e-02,
         1.73841923e-01, -3.46165776e-01,  2.09579170e-02,
         1.62290595e-02],
       [-7.96634629e-02,  8.21764618e-02, -1.46908149e-01,
        -1.15463018e-01, -2.95011848e-01, -2.38702789e-01,
         4.32551615e-02, -2.11507469e-01,  3.09027694e-02,
        -1.27929717e-01]], dtype=float32)>
} }
2024-01-11 18:14:53.554141: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

復元された関数の呼び出しは、保存されたモデル(tf.keras.Model.predict)に対するフォワードパスです。読み込まれた関数をトレーニングし続ける場合はどうでしょうか。または読み込まれた関数をより大きなモデルに埋め込むには?一般的には、この読み込まれたオブジェクトを Keras レイヤーにラップして行うことができます。幸いにも、TF Hub には、以下に示すとおり、この目的に使用できる hub.KerasLayer が用意されています。

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2024-01-11 18:14:54.395157: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
235/235 [==============================] - 4s 7ms/step - loss: 0.3436 - sparse_categorical_accuracy: 0.9023
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.1126 - sparse_categorical_accuracy: 0.9682

上記の例では、hub.KerasLayertf.saved_model.load() から読み込まれた結果を、別のモデルの構築に使用できる Keras レイヤーにラップしています。転移学習を行う際に非常に便利な手法です。

どの API を使用すべきですか?

保存においては、Keras モデルを使用している場合は、低レベル API が実現できる追加の制御が必要でない限り、Keras の Model.save API を使用します。保存しているものが Keras モデルでない場合は、低レベル API の tf.saved_model.save しか使用できません。

読み込みにおいては、使用する API はモデルの読み込みから得ようとしているものによって異なります。Keras モデルを使用できない場合(または使用したくない場合)は、tf.saved_model.load を使用し、使用できる場合は tf.keras.models.load_model を使用します。Keras モデルを保存した場合にのみ、Keras モデルを読み込めることに注意してください。

API を混在させることも可能です。model.save で Keras モデルを保存し、低レベルの tf.saved_model.load API を使用して、非 Keras モデルを読み込むことができます。

model = get_model()

# Saving the model using Keras `Model.save`
model.save(saved_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

ローカルデバイスからの読み込みまたは保存

ローカル I/O デバイスから読み込みと保存を行い、リモートデバイスでトレーニングする場合(Cloud TPU を使用する場合など)、tf.saved_model.SaveOptionstf.saved_model.LoadOptionsexperimental_io_device を使用して、I/O デバイスを localhost に設定する必要があります。以下に例を示します。

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

警告

Keras モデルを特定の方法で作成してから、トレーニングする前に保存するという、以下のような特別なケースがあります。

class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  my_model.save(saved_model_path)
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f7f09dc29a0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f7f09dc29a0>, because it is not built.
ValueError:  Model <__main__.SubclassedModel object at 0x7f7f09dc29a0> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.

SavedModel は tf.function をトレースする際に生成される tf.types.experimental.ConcreteFunction オブジェクトを保存します(詳細は、グラフと tf.function の基本ガイドの関数はいつトレースしますか? をご覧ください)。このような ValueError が発生した場合、Model.save がトレースされた ConcreteFunction を見つけられなかったか作成できなかったことが原因です。

注意: 少なくとも 1 つの ConcreteFunction がない場合にモデルを保存しないことをお勧めします。そうでない場合、低レベル API は、ConcreteFunction シグネチャのない状態で SavedModel を生成してしまうためです(SavedModel 形式については、こちらをご覧ください)。以下に例を示します。

tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f7f09dc29a0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f7f09dc29a0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.src.layers.core.dense.Dense object at 0x7f7f09c0f250>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.src.layers.core.dense.Dense object at 0x7f7f09c0f250>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
_SignatureMap({})

通常、モデルのフォワードパス(call メソッド)は、モデルが Keras の Model.fit メソッドを通じて初めて呼び出されたときに、自動的にトレースされます。また、最初のレイヤーを tf.keras.layers.InputLayer などにして、input_shape キーワード引数に渡すことで入力形状を設定している場合、Keras の Sequential API と Functional API によって ConcreteFunction が生成されることもあります。

モデルにトレース済みの ConcreteFunction が存在するかを確認するには、Model.save_specNone になっていることを確認します。

print(my_model.save_spec() is None)
True

tf.keras.Model.fit を使ってモデルをトレーニングし、save_spec が定義され、モデルの保存が機能するかを確認しましょう。

BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)
my_model.save(saved_model_path)
Epoch 1/2
7/7 [==============================] - 1s 2ms/step - loss: 4.7761
Epoch 2/2
7/7 [==============================] - 0s 2ms/step - loss: 4.4889
False
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}).
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets