Save and load a model using a distribution strategy

Overview

This tutorial demonstrates how to save and load models in the SavedModel format with tf.distribute.Strategy during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (tf.keras.Model.save and tf.keras.models.load_model) and low-level (tf.saved_model.save and tf.saved_model.load).

For general information about SavedModel and serialization, please refer to the saved model guide and the Keras model serialization guide. Let's start with a simple example.

Caution: TensorFlow models are code, and it is important to be careful with untrusted code. Learn more in Using TensorFlow securely.

Import required packages:

import tensorflow_datasets as tfds

import tensorflow as tf
2022-12-15 02:03:06.609082: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:03:06.609187: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:03:06.609198: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Load and prepare the data with TensorFlow Datasets and tf.data, and create the model using tf.distribute.MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

Train the model with tf.keras.Model.fit:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2022-12-15 02:03:13.997281: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:549] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
235/235 [==============================] - 12s 8ms/step - loss: 0.3160 - sparse_categorical_accuracy: 0.9121
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.0934 - sparse_categorical_accuracy: 0.9727
<keras.callbacks.History at 0x7f264012cfd0>

Save and load the model

Now that you have a simple model to work with, let's explore the saving/loading APIs. There are two kinds of APIs available:

Keras API

Here is an example of saving and loading a model with the Keras API:

keras_model_path = '/tmp/keras_save'
model.save(keras_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
235/235 [==============================] - 2s 3ms/step - loss: 0.0653 - sparse_categorical_accuracy: 0.9810
Epoch 2/2
235/235 [==============================] - 1s 3ms/step - loss: 0.0497 - sparse_categorical_accuracy: 0.9851
<keras.callbacks.History at 0x7f2614191a00>

After restoring the model, you can continue training it without having to call Model.compile again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, refer to the guide to SavedModel format.
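
For example, because the restored model is already compiled, you could evaluate it right away without calling Model.compile again (a minimal sketch that reuses the eval_dataset returned by get_data() above):

# The restored model still carries the loss and metrics it was compiled with before saving.
eval_loss, eval_acc = restored_keras_model.evaluate(eval_dataset)
print('Restored model accuracy: {:.4f}'.format(eval_acc))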

Now, restore the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2022-12-15 02:03:32.471904: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:549] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-12-15 02:03:32.534336: W tensorflow/core/framework/dataset.cc:769] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
235/235 [==============================] - 4s 13ms/step - loss: 0.0663 - sparse_categorical_accuracy: 0.9803
Epoch 2/2
235/235 [==============================] - 3s 13ms/step - loss: 0.0487 - sparse_categorical_accuracy: 0.9853

As the Model.fit output shows, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same as the one used before saving.
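
For instance, the same Keras SavedModel could just as well be loaded back under a MirroredStrategy scope (a minimal sketch; restored_again is just an illustrative name):

# Loading under a strategy different from the one used when saving works the same way.
with tf.distribute.MirroredStrategy().scope():
  restored_again = tf.keras.models.load_model(keras_model_path)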

tf.saved_model API

Saving the model with the lower-level API is similar to doing so with the Keras API:

model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

You can load the model with tf.saved_model.load. However, since it is a lower-level API (and therefore has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
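
If you are unsure which signatures a SavedModel exposes, you can inspect the keys of loaded.signatures (a small sketch; for this model the list should include 'serving_default'):

# List the signature keys available on the loaded object.
print(list(loaded.signatures.keys()))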

The loaded object may contain multiple functions, each associated with a key. The "serving_default" key is the default key for the inference function of a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
2022-12-15 02:03:40.597935: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[-0.07287998, -0.06065657,  0.03483895, ..., -0.18132585,
        -0.08910748,  0.06052908],
       [-0.31163222, -0.02226205, -0.13526598, ..., -0.04082797,
        -0.23755129,  0.04015619],
       [-0.26187348, -0.09974411, -0.18940257, ..., -0.17564824,
        -0.11286059, -0.00997314],
       ...,
       [-0.24347621, -0.09062086, -0.08512177, ..., -0.0498026 ,
        -0.02609731, -0.0913346 ],
       [-0.12913392, -0.0360139 , -0.04862705, ...,  0.04664   ,
        -0.02082852, -0.01882938],
       [-0.12633127, -0.14244772, -0.08439697, ..., -0.18966636,
        -0.21226291, -0.08730254]], dtype=float32)>}

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2022-12-15 02:03:40.864456: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:549] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
{'dense_3': PerReplica:{
  0: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-7.28799775e-02, -6.06565699e-02,  3.48389521e-02,
         1.01685286e-01,  1.32346496e-01,  3.62287462e-02,
        -4.74997982e-02, -1.81325853e-01, -8.91074836e-02,
         6.05290756e-02],
       [-3.11632216e-01, -2.22620517e-02, -1.35265976e-01,
         2.21318334e-01,  1.89438641e-01,  2.96817273e-02,
        -1.30453795e-01, -4.08279747e-02, -2.37551287e-01,
         4.01561856e-02],
       [-2.61873484e-01, -9.97441113e-02, -1.89402565e-01,
         3.95955741e-02, -7.23269284e-02, -9.36339349e-02,
        -1.40353218e-02, -1.75648242e-01, -1.12860590e-01,
        -9.97313857e-03],
       [ 7.93537498e-03, -8.39788094e-02, -5.49430102e-02,
         5.06623089e-02, -8.49889666e-02, -6.26260862e-02,
        -1.38619915e-01, -1.57361925e-01, -6.83496967e-02,
        -4.66652215e-02],
       [-1.02909938e-01, -2.37993933e-02, -1.04971468e-01,
         4.83149216e-02,  1.14822879e-01, -5.87940440e-02,
        -1.37954786e-01, -6.18818626e-02, -1.89428061e-01,
        -4.02544215e-02],
       [-1.79917410e-01, -1.91135854e-01, -3.26386169e-02,
         1.00716278e-01,  2.36715883e-01, -5.11703491e-02,
        -1.15932301e-01, -1.63346350e-01, -1.08930893e-01,
        -1.59819424e-02],
       [-2.23248005e-01, -1.07156508e-01, -1.56170949e-01,
         2.09210753e-01,  2.30618030e-01,  4.42695320e-02,
        -7.93789476e-02,  6.18112087e-03, -2.80879021e-01,
         3.52547914e-02],
       [-1.66931152e-01, -5.24550006e-02, -1.39176026e-02,
         1.81497350e-01,  7.80940354e-02, -3.89323756e-02,
        -1.06163725e-01, -1.52070493e-01, -5.14130965e-02,
        -5.70558012e-03],
       [-2.14375734e-01, -1.00615025e-01, -5.67342192e-02,
         2.86581397e-01,  1.90791816e-01,  5.83283305e-02,
        -5.73150776e-02, -1.12479657e-01, -1.91715017e-01,
         1.06393799e-01],
       [-8.23477358e-02,  1.97841227e-02, -7.04994239e-03,
         4.14061472e-02, -6.09277114e-02,  2.74188817e-03,
        -1.26273513e-01, -1.23661444e-01,  1.32542774e-02,
        -7.26202130e-02],
       [-1.20762050e-01, -1.25235796e-01, -1.10947207e-01,
         1.59376606e-01,  1.12468407e-01,  2.80564874e-02,
        -5.00258692e-02, -5.37324846e-02, -1.23057850e-01,
         4.76358086e-03],
       [-1.23137876e-01, -1.30127490e-01,  1.55361220e-02,
         3.30984667e-02,  1.98005736e-01,  1.39378458e-02,
        -9.16113257e-02, -1.57005906e-01, -2.35327348e-01,
        -2.09969282e-02],
       [-1.67611763e-01, -9.82501507e-02, -1.11254290e-01,
        -1.74203143e-02,  8.70675594e-03, -5.14681041e-02,
         2.80853733e-03, -1.12006947e-01, -1.90147415e-01,
         3.16145048e-02],
       [-7.57670403e-02, -1.68783948e-01,  5.70745356e-02,
         8.46423954e-03,  4.79865223e-02, -5.61940446e-02,
        -9.11350548e-03, -8.33343565e-02,  3.85395586e-02,
        -4.47189920e-02],
       [ 3.98642831e-02, -1.68964177e-01,  2.34704465e-03,
         8.85255933e-02,  3.34273912e-02, -8.45933557e-02,
        -1.02341652e-01, -1.77880928e-01, -8.69040787e-02,
        -6.89195544e-02],
       [-8.19052011e-02, -1.38401985e-04, -1.96416005e-02,
         3.40183526e-02,  6.29564151e-02,  2.68965885e-02,
        -8.98694545e-02, -2.06203222e-01, -1.66078120e-01,
        -3.65025401e-02],
       [-9.61717218e-03, -1.54140770e-01, -3.60958166e-02,
         6.36731610e-02,  6.29597902e-02, -1.41386196e-01,
        -6.57406449e-02, -1.47993594e-01, -6.41547143e-03,
        -1.13182023e-01],
       [-1.13841832e-01, -8.26231986e-02, -8.23264197e-02,
         1.14308886e-01,  9.28876474e-02, -1.47088230e-01,
        -5.06824814e-02, -1.51430473e-01, -1.95156753e-01,
        -1.11864254e-01],
       [-1.46024257e-01, -1.62329197e-01, -5.31365648e-02,
         5.37539646e-02, -1.25871524e-02,  2.54438221e-02,
         3.39765660e-02, -1.52085036e-01,  4.29450050e-02,
        -7.55875334e-02],
       [-6.77345693e-03, -1.47028357e-01,  3.87683138e-02,
         8.03087652e-03,  2.25953609e-02, -1.80521369e-01,
        -1.05571032e-01, -1.97115004e-01,  2.36714631e-03,
        -7.06604421e-02],
       [-1.36034355e-01, -1.27116233e-01, -3.08149569e-02,
         1.42252713e-01,  9.05230045e-02, -2.37616599e-02,
         6.10212982e-02,  3.56803834e-03, -1.16188146e-01,
         2.34036706e-03],
       [-1.47848710e-01, -1.09710269e-01, -5.98610975e-02,
         8.27923268e-02,  6.82088137e-02, -7.21749067e-02,
         1.95025653e-02, -1.33840680e-01, -2.63085887e-02,
        -3.75156254e-02],
       [-1.98249251e-01,  4.15598974e-02, -1.83626860e-01,
         2.63186693e-01,  1.28984734e-01, -6.18314743e-03,
        -9.33448151e-02, -2.75875032e-02, -2.89361358e-01,
         2.90494934e-02],
       [-1.68448567e-01, -7.70408213e-02, -8.50498229e-02,
         1.18215643e-01,  2.00099796e-01, -4.61177528e-03,
        -5.47068566e-02, -4.25368398e-02, -1.29340470e-01,
         6.70121610e-03],
       [-1.09302014e-01,  2.13234872e-03,  1.28173493e-02,
        -3.39557603e-02,  5.92609271e-02,  9.22352076e-03,
        -1.26339003e-01, -1.23679519e-01, -1.04768775e-01,
        -4.50533628e-03],
       [-1.66910559e-01, -1.22267604e-02, -5.61841838e-02,
         1.53836221e-01,  9.32691768e-02, -3.45964134e-02,
        -8.44600052e-02, -3.49652767e-03, -1.91222817e-01,
        -6.34718090e-02],
       [-1.49316788e-02, -6.06073327e-02,  2.32502371e-02,
         1.53633535e-01,  1.13307789e-01, -1.18827701e-01,
        -4.82188575e-02, -5.57667464e-02, -5.64857125e-02,
        -9.46970657e-02],
       [-1.61244735e-01, -2.02564284e-01, -6.43717274e-02,
         1.11552835e-01,  2.06215024e-01,  2.68670470e-02,
        -4.10928950e-03, -1.21487029e-01, -1.08253866e-01,
        -4.21646833e-02],
       [-1.31115407e-01,  1.97452158e-02, -6.87445998e-02,
         9.73142609e-02,  9.02923569e-03,  5.40745258e-03,
        -9.39824432e-02, -3.04169133e-02, -1.52549520e-02,
        -7.94284791e-02],
       [-2.03944385e-01, -1.68591321e-01, -8.50997940e-02,
        -6.92034513e-02,  7.45109469e-02, -1.28197253e-01,
        -3.30277458e-02, -2.30982184e-01, -8.90743434e-02,
        -6.75272495e-02],
       [ 1.25199780e-02, -3.66478004e-02, -1.85598433e-02,
        -4.35349345e-03,  1.32734060e-01, -1.44086212e-01,
        -1.02089845e-01, -1.78967878e-01, -1.67951390e-01,
        -7.92149305e-02],
       [-2.12898344e-01,  7.23424032e-02, -6.64692521e-02,
         2.56968647e-01, -1.51842833e-02, -6.54250383e-04,
        -3.91354449e-02,  6.75468147e-03, -1.68333322e-01,
        -2.39662640e-02],
       [-2.07378954e-01, -1.43291295e-01, -6.11460358e-02,
         4.63999063e-02,  1.93370908e-01,  1.90266818e-02,
        -8.18494111e-02, -1.65581301e-01, -8.13359767e-02,
         4.84905131e-02],
       [-4.75545451e-02, -5.11822999e-02, -4.79907319e-02,
         5.94373755e-02,  6.13151528e-02, -2.43304446e-02,
         7.29072839e-02, -1.31035775e-01, -7.87011907e-03,
        -6.39809519e-02],
       [-1.53587535e-01, -1.51007861e-01, -4.03159559e-02,
         2.15081543e-01,  1.41464621e-01,  1.74854100e-02,
         7.37385750e-02,  2.05000788e-02, -1.26210690e-01,
         7.38826022e-03],
       [-1.83326572e-01,  4.01895493e-03, -1.80756673e-01,
        -1.07586905e-02,  3.95277739e-02, -3.02815065e-02,
        -2.12222129e-01, -1.42814025e-01, -1.96930677e-01,
         6.57971427e-02],
       [-2.24769160e-01, -6.76333159e-02, -1.26510993e-01,
         9.13576633e-02,  9.57667828e-04, -2.81486735e-02,
        -6.56390190e-03,  1.44888163e-02, -7.65987337e-02,
        -7.20929727e-03],
       [-4.21147048e-02, -1.02604777e-01, -3.79882865e-02,
         3.92164513e-02,  6.84537292e-02, -7.65606910e-02,
        -1.11784890e-01, -1.63284570e-01,  2.00181752e-02,
        -1.11831129e-01],
       [-5.09862229e-02, -1.21048838e-01, -5.04833199e-02,
         3.05969976e-02, -3.14625949e-02, -1.24416485e-01,
        -4.67367396e-02, -7.86877126e-02,  9.32597518e-02,
        -1.24367654e-01],
       [-6.24672994e-02, -3.11086699e-02, -1.03235692e-01,
         3.96189839e-02,  7.39871040e-02, -1.12175494e-02,
        -1.56951457e-01, -7.08221793e-02, -1.76286846e-01,
         2.88603529e-02],
       [-2.63252974e-01, -8.48379731e-02, -8.98276567e-02,
         1.19735137e-01,  5.74256182e-02,  1.38422698e-02,
         2.56380662e-02, -5.05944416e-02, -1.10762388e-01,
         3.55094820e-02],
       [-1.85789764e-01, -1.28740594e-01, -3.87063213e-02,
        -1.67864114e-02,  6.59885257e-02,  1.87616721e-02,
        -7.30160773e-02, -8.81641805e-02, -1.50436163e-02,
        -5.92254885e-02],
       [-2.00531811e-01, -6.78859204e-02, -1.05498314e-01,
        -7.03853369e-02,  4.23895717e-02, -6.24774843e-02,
         3.19530070e-03, -1.47509933e-01,  1.98302418e-03,
        -7.07108229e-02],
       [-1.43352106e-01, -4.31504399e-02, -9.60315093e-02,
        -9.56158563e-02, -3.98100391e-02, -1.42124772e-01,
        -1.39046520e-01, -1.59666866e-01, -6.47188053e-02,
        -7.55325705e-03],
       [-9.60299522e-02, -8.74441862e-02, -6.49379268e-02,
        -2.64349207e-02,  7.37841800e-02, -1.11844733e-01,
        -6.06652871e-02, -9.56512243e-02, -1.25965044e-01,
         3.10850516e-02],
       [-1.45892203e-01,  3.94490585e-02, -6.04322478e-02,
        -6.49462640e-03,  4.34122235e-03,  8.12214985e-02,
        -8.87766108e-02, -6.66431412e-02, -6.28292859e-02,
         3.22138220e-02],
       [-2.40669951e-01, -6.78462461e-02, -7.11867958e-03,
        -3.86667699e-02,  2.77847648e-02,  1.07206106e-02,
        -1.09458834e-01, -1.63020983e-01, -6.74830377e-02,
         1.07421257e-01],
       [-1.48102567e-01, -1.05975874e-01, -6.85550421e-02,
         1.06810734e-01,  1.39986396e-01,  5.61371595e-02,
        -1.05396390e-01, -1.20584115e-01, -2.72857726e-01,
         6.31251261e-02],
       [-9.49827731e-02, -8.87574926e-02,  1.41024925e-02,
         2.49370635e-02,  7.25466534e-02, -1.10524900e-01,
        -1.13479562e-01, -2.17339844e-01,  3.91067713e-02,
        -4.77101877e-02],
       [-1.59671620e-01,  1.84059665e-02, -3.42258625e-02,
        -1.06910616e-02, -1.56807914e-01, -2.78270319e-02,
        -9.17102024e-02, -9.66491401e-02,  1.01061091e-01,
        -7.13486895e-02],
       [-2.45026276e-01, -4.11175564e-02, -7.93516636e-02,
         4.07964662e-02,  9.09752026e-02,  3.22110802e-02,
        -6.14773892e-02, -1.02943733e-01, -9.67963040e-02,
         5.73014468e-02],
       [-1.92442864e-01, -9.37693715e-02, -1.09221675e-01,
        -4.74449545e-02,  3.21590006e-02, -7.22244978e-02,
        -3.82006764e-02, -1.42180204e-01, -6.05306700e-02,
         1.93654895e-02],
       [-1.54133484e-01, -1.19779378e-01, -2.52300799e-02,
         6.08154722e-02,  8.36429596e-02, -1.02861151e-02,
        -3.88667248e-02, -1.28530249e-01, -7.39606470e-02,
         1.77168231e-02],
       [-1.19787559e-01, -1.06743135e-01, -9.61441547e-02,
         4.51400131e-02,  5.66263124e-02, -1.11290671e-01,
        -2.97779702e-02, -1.75447673e-01, -1.00173011e-01,
         2.32300907e-02],
       [-1.84117928e-01, -1.78663209e-01, -5.68316728e-02,
         1.12353712e-02,  1.47012115e-01, -9.08793062e-02,
        -7.41108209e-02, -2.99196601e-01, -1.13636717e-01,
        -7.69316852e-02],
       [-2.33878016e-01, -1.37468144e-01, -4.61244062e-02,
         8.30687433e-02,  2.13391811e-01,  2.18477696e-02,
        -2.27422714e-02, -1.40307158e-01, -1.66960433e-01,
        -9.78016108e-03],
       [-2.18359545e-01, -6.20837659e-02, -6.29559085e-02,
         9.58924890e-02,  5.88997975e-02,  5.65329194e-02,
         4.75174226e-02, -7.25000128e-02, -6.57093227e-02,
         2.96130218e-03],
       [-2.42846087e-01, -3.47198844e-02, -4.42623347e-03,
         6.73993751e-02,  7.88057745e-02, -6.60947263e-02,
        -1.71849802e-01, -1.67886093e-01,  1.71777010e-02,
        -4.46490273e-02],
       [-3.20103541e-02, -5.53519689e-02, -6.31359592e-02,
         5.97493649e-02,  6.61840886e-02, -5.47183566e-02,
        -1.16811715e-01, -9.87290293e-02, -1.78369418e-01,
        -1.32702857e-01],
       [-3.39658380e-01, -2.49794573e-02, -1.60468996e-01,
         1.00754082e-01,  8.86355191e-02,  7.92684406e-02,
        -7.74804652e-02, -4.39344719e-02, -1.84107274e-01,
        -1.99005902e-02],
       [-2.33549118e-01, -1.27293974e-01, -1.05664469e-01,
         7.08326697e-02,  1.19887814e-01, -1.46917105e-02,
         3.87302302e-02, -1.59724206e-01, -1.48974791e-01,
        -5.33915311e-03],
       [-2.06045389e-01, -1.79777175e-01, -1.43966228e-01,
         6.48718625e-02,  1.61388725e-01, -8.17139521e-02,
        -3.77683826e-02, -2.91507781e-01, -1.75432652e-01,
        -2.68044993e-02],
       [ 5.40779531e-03, -1.42092913e-01,  4.49729338e-03,
         1.11285530e-01,  7.64341354e-02, -3.01092304e-02,
        -1.36518300e-01, -1.06433481e-02, -4.65661287e-04,
        -1.25339910e-01],
       [-1.54111922e-01, -8.27784911e-02,  1.83406174e-02,
         8.68886337e-02,  7.98676535e-02,  6.14105873e-02,
        -3.64392065e-02, -1.02400064e-01, -8.72582346e-02,
        -5.01237139e-02]], dtype=float32)>,
  1: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.56213963e-01,  3.29039544e-02, -1.73344061e-01,
         8.71157348e-02,  5.96746504e-02,  1.21168345e-02,
        -1.32805228e-01, -1.91299468e-01, -1.14446975e-01,
         9.01912898e-02],
       [-2.13857964e-01, -6.11672848e-02, -2.19858028e-02,
         8.07693601e-02,  1.37339950e-01,  5.97042888e-02,
        -2.69669332e-02, -6.79526776e-02, -1.83542415e-01,
        -4.49461862e-03],
       [-1.49993405e-01, -1.15949810e-01,  3.86238843e-03,
         4.38240319e-02,  1.83764964e-01, -1.09400302e-02,
        -5.51194549e-02, -2.40220666e-01, -1.84593365e-01,
        -1.17457360e-01],
       [-9.50783119e-03, -7.13789761e-02, -4.15810756e-03,
         7.96664655e-02,  2.31158361e-02, -4.82877679e-02,
        -4.59153578e-02, -4.64774668e-04, -1.10733718e-01,
        -4.70234007e-02],
       [-1.70799807e-01, -1.45563021e-01, -1.39521912e-01,
         9.88771692e-02,  1.04169130e-01,  4.49064374e-03,
        -2.30264738e-02, -1.26273930e-01, -1.33978397e-01,
        -5.78817725e-02],
       [-1.04722142e-01, -1.04440503e-01, -7.93045387e-03,
         1.35758385e-01,  1.84308916e-01, -1.11518323e-01,
         2.05845106e-02, -1.22875653e-01, -1.25620097e-01,
        -9.42208841e-02],
       [-1.25254035e-01, -7.66401589e-02, -8.79659504e-02,
         1.24705024e-01,  4.55645472e-02, -2.00709999e-02,
         1.06258243e-02, -7.18511790e-02, -1.22882910e-01,
        -3.78146358e-02],
       [-8.97380486e-02, -2.20139071e-01, -7.58394897e-02,
         5.24612442e-02,  1.26651555e-01, -1.32006675e-01,
        -2.17293799e-02, -1.77936643e-01,  5.10824323e-02,
        -1.45276934e-01],
       [-1.33673221e-01, -1.37605786e-01,  3.09484452e-03,
         1.47445887e-01,  1.00005984e-01, -2.62805745e-02,
        -7.61101097e-02, -1.15595311e-01, -8.00357014e-02,
         5.06501868e-02],
       [-2.63926446e-01, -1.48905277e-01, -9.02045369e-02,
         4.66917455e-02,  1.08707778e-01, -1.74734443e-02,
        -1.04283944e-01, -1.27100378e-01, -8.96556154e-02,
         7.90984035e-02],
       [-1.84497088e-02, -4.01905701e-02,  1.55509934e-02,
         6.59157336e-02, -2.78288163e-02, -2.63091028e-02,
        -1.21007353e-01, -8.39568079e-02, -1.99581236e-02,
        -3.77809294e-02],
       [-1.23488471e-01, -8.71539712e-02, -2.95720361e-02,
         1.13185883e-01,  9.56483036e-02, -1.05258018e-01,
        -5.02655618e-02, -1.84914261e-01, -1.75868183e-01,
        -5.89791089e-02],
       [-1.76476881e-01, -5.38873672e-03, -8.27259421e-02,
         1.04588225e-01, -1.03357583e-02, -4.34330106e-03,
        -1.63971484e-01, -7.60647804e-02, -1.03228837e-01,
         4.91180867e-02],
       [-7.99017251e-02, -1.40042394e-01,  1.60744339e-02,
         6.98403940e-02,  1.95315033e-02, -1.43020004e-02,
        -1.24998301e-01, -1.58754677e-01,  6.63474202e-05,
         2.22029760e-02],
       [-2.56872833e-01,  1.10013187e-02, -3.84579152e-02,
         8.77166837e-02, -3.55684012e-02,  9.75668803e-02,
        -4.70308363e-02, -1.35927871e-01, -1.42623708e-02,
         8.62094238e-02],
       [-1.01544991e-01, -1.33569360e-01,  4.65596542e-02,
         1.35969371e-03, -7.25878775e-03, -7.96420872e-03,
         3.16601247e-02, -1.66682392e-01,  5.30194342e-02,
        -2.53286883e-02],
       [-1.46417692e-01, -1.13038324e-01,  1.19495615e-02,
         1.32304728e-01,  2.25361347e-01,  3.52344662e-02,
        -8.93247575e-02, -7.56622478e-02, -2.08971471e-01,
         3.71385366e-02],
       [-1.53379172e-01, -1.29094884e-01, -6.51434287e-02,
         1.07400335e-01,  1.82568312e-01, -5.68994135e-03,
        -5.35240397e-02, -1.38104081e-01, -1.51515663e-01,
         1.25226974e-02],
       [-1.69412971e-01,  5.31662256e-02, -1.20803803e-01,
         2.40203738e-03, -1.24794245e-01,  9.29506496e-03,
        -5.10612428e-02, -1.86063826e-01, -9.17630643e-02,
         6.28065169e-02],
       [-1.64114594e-01, -1.57413796e-01, -1.45171769e-03,
         1.12220123e-02,  6.55660629e-02, -7.94592276e-02,
        -1.22122653e-01, -2.24073827e-02,  3.33268940e-03,
        -4.78817821e-02],
       [-5.85423037e-02, -1.16742648e-01, -1.43228576e-01,
         1.13036603e-01,  9.43292752e-02, -6.05906919e-02,
         2.13046446e-02, -2.97889411e-02, -1.55270189e-01,
        -4.31704819e-02],
       [-1.99865192e-01, -3.78460810e-02,  1.32810958e-02,
         1.02941088e-01,  9.22629982e-02,  6.72834516e-02,
         1.16682798e-02, -9.40978378e-02,  1.47406757e-03,
         2.56398693e-04],
       [-2.19724074e-01, -5.92395589e-02, -9.03382897e-02,
        -1.66359916e-02, -3.05038355e-02, -2.78870389e-03,
        -1.64798319e-01, -6.42254800e-02,  1.78713053e-02,
        -5.11906967e-02],
       [-2.81071305e-01, -4.90775555e-02, -1.02798715e-01,
        -2.35506743e-02,  9.00105760e-02, -6.38327301e-02,
        -8.92827138e-02, -1.38486207e-01,  2.39488482e-03,
         6.83499798e-02],
       [-1.21666402e-01, -1.42379344e-01, -7.16106743e-02,
         1.99491426e-01,  7.26406574e-02, -1.34077072e-02,
         1.04167759e-01, -1.06956407e-01, -2.01135978e-01,
        -2.37018019e-02],
       [-1.80502832e-01, -8.52232277e-02, -4.23105359e-02,
         7.29889721e-02,  1.51262432e-01,  3.40324640e-02,
        -7.30158389e-02, -9.62212309e-02, -1.92034543e-01,
         7.85146654e-02],
       [-6.60478398e-02, -1.76595509e-01,  1.18160695e-02,
        -3.80624980e-02,  1.12145424e-01, -5.63758798e-02,
        -5.12192063e-02, -1.45052642e-01, -1.30713552e-01,
         5.60368486e-02],
       [-1.53617993e-01, -4.56308275e-02, -1.31020725e-01,
         4.72098291e-02,  1.18008569e-01,  9.46269333e-02,
        -4.61464375e-02,  4.93782163e-02, -1.24801219e-01,
         3.89134698e-02],
       [-7.06513971e-02, -1.12574793e-01, -4.64573354e-02,
         8.46010298e-02,  4.20339145e-02, -1.57729983e-02,
         4.84506115e-02, -1.33923620e-01, -1.72406942e-01,
        -2.01730747e-02],
       [-3.31903249e-03, -1.20539382e-01, -1.25618801e-02,
         1.04497932e-01,  5.70757315e-02, -5.74234873e-02,
        -1.75712742e-02, -5.30101210e-02, -1.45998701e-01,
        -3.99483070e-02],
       [-1.34980455e-01, -1.05482563e-01, -1.01522252e-01,
         9.08782929e-02,  4.81974110e-02, -2.51885355e-02,
         3.21632028e-02, -7.40098730e-02, -1.59075215e-01,
        -8.20701122e-02],
       [-1.66419059e-01, -1.63766384e-01, -8.90725106e-02,
        -1.98126137e-02, -3.34109962e-02, -5.46650887e-02,
        -2.85158902e-02, -1.16954237e-01, -1.89556926e-02,
         3.37420031e-03],
       [ 1.34949945e-03, -7.61629939e-02,  2.75072642e-03,
         1.60345361e-02, -5.80317713e-03, -1.20513387e-01,
        -8.28137398e-02, -5.51925302e-02,  3.67344469e-02,
        -1.49966441e-02],
       [-3.22049618e-01, -5.89181781e-02, -8.51856992e-02,
         8.16482455e-02,  1.17035463e-01,  9.45892930e-03,
        -4.75103185e-02, -7.12895617e-02, -1.81054205e-01,
         5.23709767e-02],
       [-4.05580550e-02, -8.78048986e-02, -6.46664873e-02,
         6.78064525e-02, -2.54313499e-02, -5.69313243e-02,
        -6.80021942e-04, -1.98988244e-02, -5.25148511e-02,
         5.14682382e-04],
       [-2.16121987e-01, -1.35834232e-01, -6.86128661e-02,
         1.48657233e-01,  2.12983787e-01, -2.53096744e-02,
        -1.85489915e-02, -1.33787021e-01, -1.10178262e-01,
         2.38232687e-03],
       [-1.24680765e-01, -1.08659461e-01,  5.93773834e-02,
         6.29900247e-02,  7.60843158e-02, -5.80172166e-02,
        -9.53314602e-02, -1.08028397e-01,  2.94854492e-03,
         2.31515616e-03],
       [-2.70582497e-01, -1.02869466e-01, -1.28852755e-01,
         9.74292755e-02,  1.93260431e-01,  8.66692364e-02,
        -5.54818623e-02, -9.57939252e-02, -2.20372245e-01,
         6.93553537e-02],
       [-2.21064895e-01, -1.50706947e-01, -1.27117127e-01,
        -8.14852864e-03, -2.61776000e-02,  3.46084610e-02,
         9.35978219e-02, -1.75071985e-01, -1.17223009e-01,
        -5.43006100e-02],
       [-5.40738106e-02, -8.00209269e-02,  4.37876061e-02,
         5.25648147e-02,  4.93699387e-02, -3.74476537e-02,
        -1.13279745e-01, -2.06659973e-01, -6.22199103e-02,
        -7.21887797e-02],
       [-2.48945504e-01,  4.34431806e-02, -8.67137462e-02,
         1.67143285e-01,  4.17809822e-02,  2.78116763e-03,
         2.86157243e-02, -5.86117804e-02, -2.41669267e-02,
        -6.15257807e-02],
       [-1.71857834e-01, -5.59055507e-02, -1.41226023e-01,
        -7.00492337e-02, -1.28162801e-01, -1.78599685e-01,
        -1.30959943e-01, -9.36377421e-02,  9.47269797e-03,
         1.03660002e-02],
       [-1.81474701e-01, -2.77827829e-02, -5.69939315e-02,
         6.96943253e-02,  1.10883370e-01, -1.97249651e-02,
        -1.91436112e-01, -2.36060172e-01, -1.54423982e-01,
         5.92458248e-02],
       [-1.96146965e-01,  7.44698942e-03, -9.98937786e-02,
         7.24339113e-02,  6.20480701e-02,  6.59219921e-04,
         2.35383790e-02, -1.15859509e-01, -5.01298755e-02,
        -3.31900232e-02],
       [-3.05510350e-02, -1.28295906e-02, -7.02592870e-03,
         7.73006678e-02, -6.88145682e-03,  8.79874453e-03,
        -8.21803361e-02, -3.76408473e-02,  1.98783576e-02,
        -4.46537696e-02],
       [-1.17785141e-01, -2.06617370e-01, -5.35709038e-02,
         1.01064309e-01,  1.36472613e-01, -1.58235282e-02,
        -1.31892562e-01, -2.58945078e-01, -4.47198302e-02,
        -2.33962014e-02],
       [ 8.88829678e-03, -6.34488463e-02,  1.00380838e-01,
        -1.30051374e-03,  8.36084783e-02, -1.00404337e-01,
        -1.26781732e-01, -1.20743535e-01, -1.38598025e-01,
        -3.08082402e-02],
       [-1.01954192e-01, -8.35505351e-02, -1.19815610e-01,
         6.97904825e-03,  8.41830373e-02, -9.56156850e-02,
        -9.50505435e-02, -4.08826023e-02, -5.09108454e-02,
        -3.21487933e-02],
       [-1.77569717e-01, -5.44275716e-02, -5.90075590e-02,
         3.86309475e-02, -1.61728002e-02, -1.05978541e-01,
        -3.90377715e-02, -4.26430479e-02,  6.01563305e-02,
        -8.24524611e-02],
       [-7.67904744e-02,  4.43191268e-02,  1.48990620e-02,
         8.83519724e-02,  3.21051925e-02,  5.72694838e-03,
        -1.22055180e-01, -1.18737876e-01, -9.97096077e-02,
        -7.85067827e-02],
       [-1.72709376e-01, -1.29958093e-01, -8.99621099e-02,
         1.02254771e-01,  1.95903048e-01,  5.35909832e-03,
        -3.30776349e-02, -1.38253301e-01, -2.38460809e-01,
         1.47309750e-02],
       [-1.46660671e-01, -2.22492069e-02, -1.56837106e-01,
         2.52253920e-01, -8.82822275e-02,  3.32846045e-02,
         3.24276574e-02, -2.16675252e-02, -2.26410300e-01,
        -4.10300493e-02],
       [-2.10500330e-01, -8.92959982e-02, -1.22224376e-01,
        -1.25490874e-02,  5.19202054e-02,  7.35470653e-03,
        -3.53565551e-02, -9.05182287e-02, -5.70299551e-02,
        -1.20432116e-02],
       [-1.01016678e-01, -9.34747979e-02, -6.93750530e-02,
         2.03890651e-02,  1.21070728e-01, -7.42723122e-02,
        -9.43783000e-02, -1.06761046e-01, -1.91974908e-01,
         1.33374259e-02],
       [-1.69483840e-01, -1.23641536e-01, -9.86215919e-02,
         1.42501354e-01,  1.80596262e-01,  3.20677906e-02,
        -9.76325199e-03, -9.69750583e-02, -1.79029703e-01,
         4.45185080e-02],
       [-2.30026960e-01, -1.55068934e-01, -1.55420765e-01,
         1.54814124e-03,  8.05147886e-02, -1.29992127e-01,
         4.67123613e-02, -2.17713714e-01, -1.29212692e-01,
         1.68785155e-02],
       [-1.90404192e-01, -2.25298673e-01, -5.81250116e-02,
        -2.38417387e-02,  1.36959791e-01, -1.46796688e-01,
        -2.06375122e-03, -2.38482222e-01, -3.27058807e-02,
        -5.97536638e-02],
       [-1.51395544e-01, -2.15365887e-01, -6.56528398e-02,
         2.51160562e-01,  3.24645728e-01,  2.83460468e-02,
        -7.92784840e-02, -7.58692473e-02, -2.47377872e-01,
        -1.85067803e-02],
       [-2.01602086e-01, -1.78586751e-01, -3.14563289e-02,
        -6.12829477e-02,  1.20253459e-01, -8.14617053e-02,
        -4.47988510e-02, -6.68551624e-02,  7.92008638e-03,
        -5.97691536e-02],
       [-1.62953034e-01, -8.43949765e-02, -1.66129827e-01,
         1.87980026e-01,  1.20851621e-02, -3.30129340e-02,
        -2.16185823e-02, -1.03120923e-01, -1.19601294e-01,
        -1.40753575e-02],
       [-1.80888206e-01,  5.00197262e-02, -2.87422277e-02,
         8.39663818e-02,  4.73293290e-02, -1.73420832e-02,
        -1.85923949e-01, -9.96253490e-02, -7.63225406e-02,
        -2.93213874e-02],
       [ 2.82085463e-02, -1.27728015e-01,  6.06552064e-02,
         8.80483389e-02,  1.20364837e-01, -1.60190180e-01,
        -5.67752719e-02, -1.93156391e-01, -5.47255315e-02,
        -9.92913246e-02],
       [ 7.88181648e-03, -9.06187072e-02,  5.42960688e-03,
         1.72855183e-02, -1.57072693e-02, -1.34019971e-01,
        -9.60277766e-02, -1.40550882e-01, -9.20628831e-02,
        -2.40895897e-03],
       [-1.75741985e-01,  2.60501951e-02, -2.65618004e-02,
         1.12590082e-01,  7.32031539e-02,  1.13229901e-02,
        -1.07761726e-01, -6.90757334e-02, -1.55004039e-01,
         2.88570225e-02]], dtype=float32)>,
  2: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-1.18622065e-01, -4.80383635e-02, -7.80763626e-02,
        -5.65892532e-02,  5.65803796e-02, -4.60475534e-02,
        -1.47938013e-01, -1.39093891e-01,  1.05602145e-02,
        -3.03755403e-02],
       [-2.22777084e-01,  2.82550156e-02, -8.23304877e-02,
         7.27979094e-02, -4.05894965e-03, -1.39985010e-02,
        -1.21795088e-01, -2.84186080e-02, -7.95347244e-02,
         3.17907333e-02],
       [-2.31982499e-01, -1.13635600e-01, -6.29137456e-02,
         2.50927508e-02,  1.60225600e-01, -3.47373560e-02,
        -1.91207439e-01, -1.18490927e-01, -2.58446038e-01,
         4.98191789e-02],
       [-1.99452117e-01, -1.26542449e-01, -5.31540513e-02,
         1.15276724e-01,  1.50020584e-01, -4.68152165e-02,
        -9.11516100e-02, -9.15623307e-02, -1.36065409e-01,
         3.56047526e-02],
       [-1.27488852e-01, -5.79035208e-02, -1.09562017e-01,
         1.74391672e-01,  8.17514956e-02, -8.03902298e-02,
        -1.00652844e-01, -7.43787512e-02, -7.21562281e-02,
        -9.31031108e-02],
       [-9.11276639e-02, -1.30131692e-02, -3.48834991e-02,
         3.91736031e-02,  1.04049250e-01,  8.94533098e-03,
         1.42995007e-02, -8.27608109e-02, -1.68300703e-01,
         5.86905926e-02],
       [-1.37468532e-01, -9.41935927e-03, -3.64061221e-02,
         9.07413736e-02,  2.93911435e-02,  2.06236541e-02,
         1.06558166e-01, -5.67114055e-02, -3.10667977e-02,
        -4.51412275e-02],
       [-5.12214527e-02, -6.10571690e-02, -5.43052256e-02,
         5.93981780e-02,  6.77483231e-02, -6.92736208e-02,
        -2.51794532e-02, -1.16102234e-01, -2.40157098e-02,
        -8.25837627e-02],
       [-5.13551161e-02, -7.28777982e-03,  4.10806015e-02,
         9.43285823e-02,  6.30164593e-02, -2.36448608e-02,
        -1.03774831e-01, -7.82925636e-02, -8.73049349e-02,
        -7.23443329e-02],
       [-1.15348957e-01, -1.03502363e-01, -1.64012402e-01,
         1.11891828e-01,  1.94986910e-03, -6.92155510e-02,
         8.53531063e-04, -1.03657715e-01, -1.35936543e-01,
        -5.49774095e-02],
       [-2.00437084e-01, -8.43933225e-02, -1.49807930e-01,
         1.22252554e-01,  1.77868322e-01,  1.06553733e-02,
        -5.68887256e-02, -1.43495083e-01, -1.56670794e-01,
        -3.02786920e-02],
       [-1.72011822e-01, -1.24456048e-01, -9.95005891e-02,
        -5.10253049e-02, -1.08759589e-01, -1.82087868e-02,
         6.44969493e-02, -1.01956531e-01, -4.01061922e-02,
         2.99463458e-02],
       [-1.82381138e-01, -8.84093195e-02, -2.41792910e-02,
         8.70520920e-02,  1.56285048e-01,  4.79183346e-02,
        -1.04085192e-01, -1.21092916e-01, -1.02471031e-01,
         6.23085313e-02],
       [-1.55630529e-01, -1.09569505e-01, -1.01884559e-01,
        -3.38301063e-03,  5.25906458e-02, -6.25553578e-02,
        -5.27070239e-02, -1.57067090e-01, -2.87416503e-02,
         7.68488348e-02],
       [-1.88608259e-01, -8.06754231e-02,  4.69120778e-03,
         6.98600560e-02,  1.00554742e-01,  9.89464819e-02,
        -3.84185091e-02,  4.64840829e-02, -1.15921900e-01,
         4.29807827e-02],
       [-2.29003042e-01, -3.68302464e-02, -2.37280518e-01,
         2.07991526e-02,  3.17515135e-02, -4.42533270e-02,
        -9.63043123e-02, -1.83197528e-01, -2.15664297e-01,
         3.69398966e-02],
       [-2.71373749e-01, -2.29283422e-02, -6.76829144e-02,
         1.62800610e-01,  1.77176848e-01,  7.14185834e-03,
        -5.13582751e-02, -1.24970540e-01, -4.90280688e-02,
        -2.32797526e-02],
       [-1.32671371e-01, -1.05925947e-01,  1.37970224e-02,
        -3.25781107e-03, -1.54116303e-02, -1.49612546e-01,
        -9.20461565e-02, -1.28995076e-01,  6.23100474e-02,
        -4.84038368e-02],
       [-1.77421004e-01,  1.20637566e-03, -7.40719438e-02,
         7.47624785e-03, -1.79810658e-01, -3.50079797e-02,
        -1.17765449e-01, -3.19697484e-02,  9.32934433e-02,
        -5.85267693e-02],
       [-3.12250495e-01, -1.77390873e-03, -1.27836898e-01,
         1.01998165e-01, -1.10829771e-02,  8.53704885e-02,
        -1.12587877e-01, -2.00294912e-01, -1.69150800e-01,
         1.23497970e-01],
       [-3.02309334e-01, -9.59729403e-02, -1.66549951e-01,
         1.69012368e-01,  1.55977756e-01,  9.56503749e-02,
        -6.70976192e-02, -6.32242411e-02, -2.28162110e-01,
         8.53204578e-02],
       [-1.15855612e-01, -1.72991693e-01, -3.85832824e-02,
         1.24845549e-01,  2.33356401e-01, -1.41636744e-01,
        -1.27913952e-02, -1.11601941e-01, -1.02567822e-01,
        -9.03108865e-02],
       [ 1.82036385e-02, -1.93940863e-01,  6.95734471e-02,
         3.61038297e-02,  1.02290824e-01, -7.31098130e-02,
        -8.13976154e-02, -9.43982601e-02, -9.04485583e-03,
        -9.97118652e-02],
       [-2.84229934e-01, -1.14863910e-01, -1.12033233e-01,
         2.11337656e-02,  3.13820317e-02,  1.62747875e-02,
        -2.48021781e-02, -1.96348578e-01, -5.32811284e-02,
        -2.50945091e-02],
       [-1.75024554e-01,  7.07824379e-02, -1.21888414e-01,
         1.93332285e-02, -1.03795715e-01, -2.43023336e-02,
        -1.12834647e-01, -9.55849513e-02, -4.19348627e-02,
        -1.27387196e-02],
       [-2.08064884e-01,  8.48269463e-03, -1.03213869e-01,
         1.55607879e-01,  4.58230674e-02,  6.14446402e-02,
        -9.38809663e-02,  1.08700588e-01, -2.51485646e-01,
        -2.56617889e-02],
       [-1.00302458e-01, -1.61795676e-01, -8.71188715e-02,
        -2.14308351e-02,  5.29092625e-02, -1.30962819e-01,
        -1.38738379e-03, -1.45790309e-01, -2.87716761e-02,
        -1.15036502e-01],
       [ 1.70535184e-02, -9.02777687e-02, -4.85073291e-02,
         4.25172150e-02,  1.11016929e-01, -6.14262186e-02,
        -8.69335532e-02, -2.02684641e-01, -1.18313126e-01,
        -8.45843554e-03],
       [-2.11484566e-01, -1.84206426e-01, -6.17977679e-02,
         1.93131089e-01,  2.22305477e-01,  7.55271316e-02,
         2.70785391e-03, -5.73645830e-02, -1.90291941e-01,
         1.37506947e-02],
       [-1.54488787e-01, -8.49604756e-02,  4.00311872e-03,
         2.15890780e-02,  1.30104303e-01, -5.66638112e-02,
        -1.74196199e-01, -1.20592259e-01, -1.81575328e-01,
        -2.15945989e-02],
       [-1.13820657e-01, -1.43980190e-01, -4.07709181e-03,
         7.57762119e-02,  2.12188214e-01, -2.82883644e-02,
        -3.65009271e-02, -9.06620547e-02, -1.15396664e-01,
        -3.93195450e-03],
       [-1.89802587e-01, -3.11169103e-02, -6.96812272e-02,
         8.55413303e-02,  1.31849259e-01,  3.56973112e-02,
        -6.38814121e-02, -5.01507819e-02, -8.23300108e-02,
        -4.75421473e-02],
       [-1.42080382e-01, -1.26603737e-01, -1.37690693e-01,
         3.48311365e-02,  1.34378403e-01, -5.53520471e-02,
        -5.54120988e-02, -1.17055811e-01, -1.12406179e-01,
        -3.78189981e-03],
       [-2.14066803e-02, -1.67680278e-01,  6.09776825e-02,
         5.87410182e-02,  2.53670871e-01, -1.04848817e-01,
        -1.12973109e-01, -2.27827251e-01, -1.72088683e-01,
        -8.71021152e-02],
       [-2.08294600e-01, -1.22842968e-01, -3.96946445e-03,
         2.50235498e-02,  4.90248203e-06, -7.23325312e-02,
         3.18043828e-02, -1.22997969e-01, -9.40374881e-02,
        -3.58204171e-02],
       [-1.08865075e-01, -1.83587633e-02, -3.86281759e-02,
         4.71176878e-02,  1.32990122e-01,  2.90070102e-02,
        -1.74620092e-01, -1.55035168e-01, -1.27626136e-01,
        -1.08029723e-01],
       [-1.92430407e-01, -1.43428802e-01, -1.61632031e-01,
        -2.64168978e-02,  6.47923052e-02, -1.05036378e-01,
         1.31582320e-02, -1.51326209e-01,  1.62273869e-02,
        -7.95314014e-02],
       [-1.46146014e-01, -1.43946409e-01, -2.02701569e-01,
        -2.02349648e-02,  7.07512572e-02, -5.21270335e-02,
         2.80647911e-02, -1.78616136e-01, -1.65089011e-01,
        -4.18910980e-02],
       [-2.18131930e-01,  2.99300849e-02, -1.15146317e-01,
         2.07010388e-01,  3.67674604e-02, -1.75569654e-02,
        -6.71223179e-03, -1.40600979e-01, -1.08277187e-01,
        -6.51790202e-02],
       [-1.13692321e-01, -1.26904696e-01, -8.00543875e-02,
         9.49823037e-02,  1.10704221e-01, -4.09431309e-02,
        -2.81966254e-02, -9.69650894e-02, -1.16475344e-01,
         4.18098867e-02],
       [-2.12048978e-01, -8.93960595e-02, -4.96708192e-02,
         1.10165894e-01,  8.15800726e-02,  1.13921992e-01,
         3.49859148e-02, -1.03857458e-01, -4.97306138e-02,
        -4.60503958e-02],
       [-1.76234990e-01, -7.70872682e-02, -6.18268102e-02,
         1.52689919e-01,  1.94699034e-01,  5.39069399e-02,
        -4.96895090e-02, -3.31558585e-02, -1.85139030e-01,
         7.87408799e-02],
       [-2.38845885e-01, -5.12747169e-02, -1.36963889e-01,
         9.96940657e-02,  6.54694661e-02,  3.57060879e-02,
        -7.02928603e-02, -8.71425197e-02, -1.58412352e-01,
         5.93720078e-02],
       [-1.58926964e-01,  6.26134425e-02, -1.29999816e-01,
         1.28382802e-01,  1.32834196e-01, -3.59639674e-02,
        -2.32275501e-01, -4.31256592e-02, -3.14703882e-01,
        -1.46043569e-01],
       [-1.96496427e-01, -4.94864807e-02, -1.22942358e-01,
         1.20347358e-01,  1.33029252e-01,  4.28225845e-02,
        -1.60483509e-01,  4.79465723e-03, -1.61326826e-01,
         2.86684260e-02],
       [-5.49462512e-02, -5.38495220e-02, -8.56895149e-02,
         6.73860610e-02, -2.37850286e-02, -3.77055667e-02,
        -1.47368640e-01, -3.10028791e-02, -6.09867871e-02,
         2.25309990e-02],
       [-2.29001462e-01, -1.15906537e-01, -9.69577432e-02,
        -8.48331898e-02, -6.59356862e-02,  8.02898258e-02,
         1.58300266e-01, -1.41157225e-01, -6.84659481e-02,
        -6.24797419e-02],
       [-2.27232039e-01, -1.41192287e-01, -1.66422382e-01,
         1.41542733e-01,  1.79099292e-01,  7.30920732e-02,
         5.12477048e-02, -2.19374895e-03, -1.74755171e-01,
         5.67732751e-03],
       [-1.11173958e-01, -2.32646465e-02, -1.73542351e-02,
        -1.11154914e-02, -9.74261165e-02, -1.28799416e-02,
        -1.24198586e-01, -4.74506393e-02,  5.55145591e-02,
        -9.86203253e-02],
       [-5.34499809e-02, -1.07035644e-01, -5.87719604e-02,
         7.92650878e-03,  8.19761157e-02, -8.98833647e-02,
        -6.50129616e-02, -2.32616365e-02, -6.52160794e-02,
        -2.63372734e-02],
       [-1.51043698e-01, -8.00651461e-02, -1.01332128e-01,
         7.20168129e-02,  3.09408456e-02, -3.33223343e-02,
         6.06775284e-03, -2.29082108e-01,  2.69115716e-03,
        -6.04840592e-02],
       [-6.18841089e-02, -8.92749727e-02, -4.41954359e-02,
        -6.95271045e-03, -6.51970804e-02, -1.45737797e-01,
        -3.45894396e-02, -1.11928955e-01,  1.18997246e-02,
        -1.12951092e-01],
       [-1.26971349e-01, -2.19437853e-02, -5.61639816e-02,
        -3.81761640e-02, -1.27885267e-01, -3.91721912e-02,
        -8.56075808e-02, -5.15636317e-02,  7.06897527e-02,
        -9.55859199e-02],
       [-2.24867851e-01, -3.28013003e-02, -6.02710769e-02,
         8.64380300e-02,  1.47626460e-01,  1.58059463e-01,
         1.27491765e-02,  1.54137611e-03, -1.21143982e-01,
        -6.59556389e-02],
       [-1.34460121e-01, -4.31514680e-02, -1.11357309e-01,
         2.30784371e-01,  2.78172866e-02, -7.04825521e-02,
        -5.16957939e-02,  1.68063641e-02, -1.00909561e-01,
        -5.26818037e-02],
       [-7.23162889e-02, -1.87913701e-02, -1.86011717e-02,
         7.93044269e-03, -1.40181303e-01, -4.87775207e-02,
        -8.96310508e-02, -9.82604846e-02,  1.10361427e-01,
        -1.19583145e-01],
       [-2.08006114e-01, -1.35998905e-01, -7.12261200e-02,
         9.55476984e-02,  1.10977314e-01,  3.86161208e-02,
        -3.89417224e-02, -2.01288596e-01, -1.55291021e-01,
         1.35656431e-01],
       [-1.65449306e-01, -2.72232294e-03, -2.19225399e-02,
         6.01486862e-02,  2.65012942e-02, -1.18527263e-02,
        -2.01358795e-01, -8.38685185e-02,  5.60516864e-03,
         2.89868824e-02],
       [-2.25441739e-01, -9.65888649e-02, -1.12556458e-01,
         1.13298908e-01,  9.96981561e-02,  2.21842527e-03,
        -2.75023729e-02, -1.21232875e-01, -1.84147239e-01,
         7.08261728e-02],
       [-2.48636827e-02, -1.95411325e-01,  7.19704255e-02,
        -1.54576451e-02,  8.29630941e-02, -1.01169080e-01,
        -1.05838925e-02, -5.05419150e-02, -1.87411904e-03,
        -1.69678256e-02],
       [-1.07642174e-01, -7.42681921e-02, -1.11681670e-02,
         4.31303531e-02,  1.15110792e-01, -1.10990740e-01,
        -8.58963057e-02, -1.15901396e-01, -7.08180517e-02,
        -4.99593839e-02],
       [-2.51329958e-01, -3.11789215e-02, -7.49738142e-02,
         7.77494758e-02,  1.30425021e-01, -4.77944538e-02,
        -1.56344920e-02, -9.52227265e-02, -2.12230906e-02,
        -3.75425778e-02],
       [-2.69328281e-02, -6.26106635e-02, -5.45325316e-02,
         1.06040694e-01,  9.10804272e-02, -1.29021369e-02,
        -1.13914937e-01,  4.24201041e-02, -5.70038594e-02,
        -7.49924183e-02],
       [-1.55278623e-01, -1.84793204e-01, -1.28665328e-01,
        -1.59393400e-02,  7.25794435e-02, -3.72538641e-02,
         1.77359134e-02, -1.76151663e-01, -5.50818369e-02,
        -1.32631063e-02]], dtype=float32)>,
  3: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-0.1233949 ,  0.0306532 ,  0.01506481,  0.14869131, -0.00242151,
        -0.01294589, -0.02127065, -0.05192772, -0.09956914, -0.10299885],
       [-0.1798876 , -0.1960856 , -0.05181348,  0.02666742,  0.17055541,
        -0.03525529, -0.02422158, -0.18629645, -0.05319609, -0.03938816],
       [-0.11730058, -0.10043657, -0.11140861, -0.05790957, -0.00410666,
        -0.12318806, -0.04031731, -0.10785541, -0.1313473 ,  0.04424933],
       [-0.17092255, -0.11166693, -0.04049611,  0.07196078,  0.01569334,
         0.03302138, -0.02389625, -0.06720973, -0.0026828 , -0.0447191 ],
       [-0.13086951, -0.13084534, -0.02817666, -0.00141026,  0.0780483 ,
        -0.06985046, -0.10983647, -0.12160238, -0.18786824,  0.0870638 ],
       [-0.05607726, -0.05799097, -0.00473941,  0.18969995,  0.03421883,
        -0.0412782 , -0.06431054, -0.06569731, -0.14644182, -0.02758818],
       [-0.05323289,  0.00337919,  0.02187643,  0.08584414,  0.02806523,
        -0.00749452, -0.09806205, -0.06961535, -0.06048461, -0.08008826],
       [-0.2853573 ,  0.08507048, -0.1819371 ,  0.08309032, -0.04780732,
        -0.0339517 , -0.12391321, -0.13554102, -0.24620405,  0.12064482],
       [-0.1081121 ,  0.01817158,  0.02254607,  0.0007066 ,  0.06363986,
         0.00354922, -0.05468165, -0.1655403 , -0.09062124, -0.00828907],
       [-0.17373204, -0.06299572, -0.06730817,  0.05633632,  0.01303217,
         0.08219384, -0.12604564, -0.01601628, -0.04890175, -0.01032439],
       [-0.0730572 , -0.22056088, -0.07132262,  0.08563757,  0.21419477,
        ...,
      dtype=float32)>
} }
2022-12-15 02:03:41.411701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Calling the restored function is just a forward pass on the saved model (tf.keras.Model.predict). What if you want to continue training the loaded function? Or what if you need to embed the loaded function into a bigger model? A common practice is to wrap the loaded object into a Keras layer. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap the loaded object into a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
2022-12-15 02:03:42.057240: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:549] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
235/235 [==============================] - 7s 7ms/step - loss: 0.3255 - sparse_categorical_accuracy: 0.9090
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.0960 - sparse_categorical_accuracy: 0.9729

In the example above, TensorFlow Hub's hub.KerasLayer wraps the result reloaded from tf.saved_model.load into a Keras layer that is then used to build another model. This is very useful for transfer learning.
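
For instance, a common transfer-learning pattern is to freeze the reloaded weights and train only a new head on top. The sketch below is a minimal illustration, reusing saved_model_path and train_dataset from above; the build_transfer_model helper and the extra Dense head are hypothetical names added only for this example.

import tensorflow_hub as hub

def build_transfer_model(loaded):
  # Freeze the reloaded weights and train only the new head added on top.
  inputs = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  features = hub.KerasLayer(loaded, trainable=False)(inputs)  # frozen backbone
  outputs = tf.keras.layers.Dense(10, name='new_head')(features)  # new trainable head
  return tf.keras.Model(inputs, outputs)

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  transfer_model = build_transfer_model(loaded)
  transfer_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                         optimizer=tf.keras.optimizers.Adam(),
                         metrics=[tf.metrics.SparseCategoricalAccuracy()])
  transfer_model.fit(train_dataset, epochs=1)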

Which API should you use?

For saving, if you are working with a Keras model, use the Model.save API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, tf.saved_model.save, is your only choice.

For loading, your API choice depends on what you want to get from the model-loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load. Otherwise, use tf.keras.models.load_model. Note that you can get a Keras model back only if you saved a Keras model.
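
As a minimal sketch of the non-Keras case, the example below saves a plain tf.Module with tf.saved_model.save and reloads it with tf.saved_model.load. The Adder class and the /tmp/module_save path are hypothetical and used only for illustration.

class Adder(tf.Module):
  """A non-Keras object, so only the low-level APIs apply."""

  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  def add_one(self, x):
    return x + 1.0

module_path = '/tmp/module_save'  # hypothetical path for this sketch
tf.saved_model.save(Adder(), module_path)

reloaded = tf.saved_model.load(module_path)
print(reloaded.add_one(tf.constant(2.0)))  # tf.Tensor(3.0, shape=(), dtype=float32)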

It is possible to mix and match the APIs. You can save a Keras model with Model.save, and load a non-Keras model with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras `Model.save`
model.save(keras_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

Saving/Loading from a local device

When saving and loading from a local I/O device while training on remote devices (for example, when using a Cloud TPU), you must use the option experimental_io_device in tf.saved_model.SaveOptions and tf.saved_model.LoadOptions to set the I/O device to localhost. For example:

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

Caveats

A special case is when you create Keras models in certain ways, and then save them before training. For example:

class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  my_model.save(keras_model_path)
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f255805b340>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f255805b340>, because it is not built.
ValueError:  Model <__main__.SubclassedModel object at 0x7f255805b340> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.

A SavedModel saves the tf.types.experimental.ConcreteFunction objects generated when you trace a tf.function (check When is a Function tracing? in the Introduction to graphs and tf.function guide to learn more). If you get a ValueError like this one, it is because Model.save was not able to find or create a traced ConcreteFunction.
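
One way around this, sketched below with the SubclassedModel defined above, is to call the model once on sample data (or call Model.build with an input shape) so that a ConcreteFunction is traced before saving. The sample input shape and save path are illustrative, and a separate instance is used so that my_model below stays unbuilt.

built_model = SubclassedModel()
# Calling the model once on sample data traces `call` and defines the input shape.
built_model(tf.ones((1, 5), dtype=tf.float32))  # illustrative input shape
built_model.save('/tmp/built_subclassed_save')  # hypothetical path; saving now succeeds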

Caution: You shouldn't save a model without at least one ConcreteFunction, since the low-level API will otherwise generate a SavedModel with no ConcreteFunction signatures (learn more about the SavedModel format). For example:

tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f255805b340>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f255805b340>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f2570072d60>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f2570072d60>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
_SignatureMap({})

Usually the model's forward pass (the call method) is traced automatically when the model is called for the first time, often via the Keras Model.fit method. A ConcreteFunction can also be generated by the Keras Sequential and Functional APIs if you set the input shape, for example, by making the first layer either a tf.keras.layers.InputLayer or another layer type and passing it the input_shape keyword argument.
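
For instance, the minimal sketch below builds a Sequential model whose input shape is declared up front, so it can be saved before any training; the layer sizes and save path are illustrative.

sequential_model = tf.keras.Sequential([
    # Declaring the input shape up front lets Keras trace the forward pass
    # without the model ever being called on data by the user.
    tf.keras.layers.InputLayer(input_shape=(5,)),
    tf.keras.layers.Dense(5, name='output_layer'),
])
sequential_model.save('/tmp/sequential_save')  # hypothetical path for this sketch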

To verify whether your model has a traced ConcreteFunction, check whether Model.save_spec is None:

print(my_model.save_spec() is None)
True

Use tf.keras.Model.fit to train the model, and verify that save_spec gets defined and that saving the model works:

BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)
my_model.save(keras_model_path)
Epoch 1/2
7/7 [==============================] - 1s 3ms/step - loss: 11.9682
Epoch 2/2
7/7 [==============================] - 0s 2ms/step - loss: 11.5175
False
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f25341793a0>, 139801521600336), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f25341793a0>, 139801521600336), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f25701bfac0>, 139798770240176), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f25701bfac0>, 139798770240176), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f25341793a0>, 139801521600336), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f25341793a0>, 139801521600336), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f25701bfac0>, 139798770240176), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f25701bfac0>, 139798770240176), {}).
WARNING:absl:Found untraced functions such as _update_step_xla while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Assets written to: /tmp/keras_save/assets