Save and load a model using a distribution strategy


Overview

This tutorial demonstrates how you can save and load models in the SavedModel format with tf.distribute.Strategy during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (tf.keras.Model.save and tf.keras.models.load_model) and low-level (tf.saved_model.save and tf.saved_model.load).

To learn about SavedModel and serialization in general, please read the saved model guide and the Keras model serialization guide. Let's start with a simple example.

Caution: TensorFlow models are code, and it is important to be careful with untrusted code. Refer to Using TensorFlow securely for details.

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf
2023-11-07 23:17:51.828961: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 23:17:51.829033: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 23:17:51.830761: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

Load and prepare the data with TensorFlow Datasets and tf.data, and create the model using tf.distribute.MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

Train the model with tf.keras.Model.fit:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
2023-11-07 23:17:58.325808: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699399085.009935  522218 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
235/235 [==============================] - 9s 9ms/step - loss: 0.3166 - sparse_categorical_accuracy: 0.9101
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.0970 - sparse_categorical_accuracy: 0.9722
<keras.src.callbacks.History at 0x7f7a0c273700>

Save and load the model

Now that you have a simple model to work with, let's explore the saving/loading APIs. There are two kinds of APIs available:

Keras API

Here is an example of saving and loading a model with the Keras API:

keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
235/235 [==============================] - 2s 5ms/step - loss: 0.0644 - sparse_categorical_accuracy: 0.9817
Epoch 2/2
235/235 [==============================] - 1s 4ms/step - loss: 0.0506 - sparse_categorical_accuracy: 0.9851
<keras.src.callbacks.History at 0x7f7a6b052490>

After restoring the model, you can continue training on it, and you don't even need to call Model.compile again, since it was already compiled before saving. The model is saved in the Keras zip archive format, indicated by the .keras extension. For more information, refer to the Keras saving guide.
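
As a quick check (a minimal sketch reusing the eval_dataset created earlier), you can evaluate the restored model directly. Because the compiled loss and metrics were restored along with the weights, Model.evaluate works without recompiling:

# Sketch: the restored model keeps its compiled state, so it can be
# evaluated right away on the existing eval_dataset.
eval_loss, eval_acc = restored_keras_model.evaluate(eval_dataset)
print('Restored model accuracy: {:.4f}'.format(eval_acc))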

Now, restore the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2023-11-07 23:18:13.264179: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2023-11-07 23:18:13.328327: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
15/235 [>.............................] - ETA: 2s - loss: 0.0740 - sparse_categorical_accuracy: 0.9812
2023-11-07 23:18:13.954467: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
2023-11-07 23:18:13.988420: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
2023-11-07 23:18:14.002096: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
235/235 [==============================] - 3s 12ms/step - loss: 0.0652 - sparse_categorical_accuracy: 0.9811
Epoch 2/2
235/235 [==============================] - 3s 11ms/step - loss: 0.0510 - sparse_categorical_accuracy: 0.9851

As the Model.fit output shows, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same as the strategy used before saving.
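
The same saved file can equally be restored under yet another strategy. The sketch below (with hypothetical names yet_another_strategy and restored_again) loads it under tf.distribute.MirroredStrategy and evaluates it there:

# Sketch: restore the same file under a different strategy and evaluate it.
yet_another_strategy = tf.distribute.MirroredStrategy()
with yet_another_strategy.scope():
  restored_again = tf.keras.models.load_model(keras_model_path)
  restored_again.evaluate(eval_dataset)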

tf.saved_model API

Saving the model with the lower-level API is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
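
Before loading it back, you can optionally peek at what tf.saved_model.save wrote to disk. A SavedModel directory generally contains a saved_model.pb protobuf plus variables/ and assets/ subdirectories (a small sketch):

import os

# List the files produced by tf.saved_model.save in the export directory.
print(sorted(os.listdir(saved_model_path)))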

You can load the model with tf.saved_model.load. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
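
To see which entries are available, you can inspect the loaded object and its signatures mapping (a minimal sketch):

# The low-level API returns a generic object rather than a tf.keras.Model;
# its signatures attribute maps keys to concrete inference functions.
print(type(loaded))
print(list(loaded.signatures.keys()))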

The loaded object may contain multiple functions, each associated with a key. The "serving_default" key is the default key for the inference function with a saved Keras model. To do inference with this function, run the following code:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[ 0.1147381 , -0.40301684,  0.04525661, ..., -0.00361927,
        -0.16612433, -0.0153662 ],
       [ 0.23993656, -0.2951342 ,  0.11567082, ..., -0.018293  ,
        -0.25642908, -0.12756145],
       [ 0.12482558, -0.1498732 , -0.02880868, ...,  0.10482813,
        -0.04954109,  0.03868026],
       ...,
       [ 0.24344432, -0.15270862,  0.07865744, ...,  0.07004395,
        -0.13115454, -0.06379548],
       [ 0.18582651, -0.06589739,  0.0035552 , ...,  0.05252723,
        -0.15672804, -0.01999891],
       [ 0.10272254, -0.19479191,  0.0195781 , ..., -0.00804898,
        -0.19664931, -0.12797697]], dtype=float32)>}
2023-11-07 23:18:20.522652: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
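
The signature returns logits keyed by the output layer's name ('dense_3' in the output above). To turn them into predicted digit classes, you can take the argmax over the last axis (a small follow-up sketch):

# Convert the logits from the last batch into predicted class indices.
logits = inference_func(batch)['dense_3']
predicted_classes = tf.argmax(logits, axis=-1)
print(predicted_classes[:10])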

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2023-11-07 23:18:20.756679: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
{'dense_3': PerReplica:{
  0: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 0.1147381 , -0.40301684,  0.04525661, ..., -0.00361927,
        -0.16612433, -0.0153662 ],
       ...,
       [ 0.17844509, -0.13521364,  0.02820639, ...,  0.04379535,
        -0.14809836,  0.02721381]], dtype=float32)>,
  1: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=array([...], dtype=float32)>,
  2: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=array([...], dtype=float32)>,
  3: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=array([...], dtype=float32)>
}}
        -5.79333529e-02, -1.61447689e-01,  1.31010031e-02,
        -2.76737921e-02,  1.25910565e-01, -7.01321363e-02,
         3.53521109e-02],
       [ 1.19876057e-01, -1.13361612e-01, -2.57261433e-02,
        -4.72843572e-02, -4.17717062e-02, -4.30146120e-02,
        -5.17283268e-02,  1.30386308e-01,  4.24293503e-02,
         8.93710554e-02],
       [ 1.94169819e-01, -2.65223354e-01,  3.82882431e-02,
         1.03008702e-01, -6.13676086e-02,  1.07859813e-01,
         6.72024637e-02,  6.80919513e-02, -2.10222095e-01,
        -6.00702912e-02],
       [ 1.22785389e-01, -2.32885510e-01,  5.83024025e-02,
         4.63435128e-02, -9.07529220e-02,  2.89085135e-02,
         1.10610694e-01, -7.84037262e-02, -1.43721595e-01,
        -1.06894046e-01],
       [ 1.53462917e-01, -2.34329149e-01,  2.19478011e-02,
         5.20666614e-02, -6.18710145e-02,  2.15046108e-03,
         4.65009995e-02,  7.72917271e-03, -2.06512794e-01,
        -1.44317150e-02],
       [ 8.44687074e-02, -1.63276523e-01,  8.91670585e-03,
        -1.00204740e-02, -1.17402092e-01, -4.77838889e-02,
         3.61629464e-02, -2.33641900e-02, -6.58508167e-02,
         8.33485425e-02],
       [ 2.75007486e-01, -3.16218793e-01,  6.56070560e-03,
         1.27888322e-02, -1.52386218e-01,  1.14409402e-01,
         3.27666402e-02,  1.20601952e-01, -1.16330400e-01,
        -5.78423366e-02],
       [ 5.35397902e-02, -2.26179808e-01,  5.67623600e-02,
        -4.03686017e-02, -1.11849219e-01,  2.52255201e-02,
         2.42093116e-01, -2.27400400e-02, -7.09357709e-02,
         1.58765018e-02],
       [ 1.75934523e-01, -2.08054245e-01,  8.09602365e-02,
         5.75007647e-02, -1.16466999e-01, -4.53156643e-02,
         9.55019444e-02,  3.91122252e-02, -2.05660030e-01,
        -2.10969180e-01],
       [ 2.63072968e-01, -3.34022671e-01,  7.75949284e-02,
        -2.34512314e-02, -1.88964546e-01,  1.54963434e-02,
         1.01806656e-01,  1.00506425e-01, -7.98694119e-02,
        -6.06725216e-02],
       [ 1.06077157e-02, -9.93775576e-02,  2.10284032e-02,
         3.52887884e-02, -3.98235321e-02, -5.13123535e-03,
         6.64520115e-02,  2.50905119e-02, -2.88086385e-03,
         2.87667308e-02],
       [ 1.60640046e-01, -2.37863019e-01,  1.27255619e-02,
         3.80289704e-02, -7.33542293e-02, -9.28787142e-03,
         6.79433346e-02, -6.27164468e-02, -1.42385989e-01,
        -8.91289040e-02],
       [-1.52041204e-03, -7.26825595e-02, -9.42894071e-03,
         8.61139968e-03, -1.44720245e-02,  5.08274063e-02,
         4.08250876e-02, -1.67283006e-02, -4.39867675e-02,
         3.02529763e-02],
       [ 7.84117430e-02, -8.68821368e-02,  3.77271473e-02,
        -2.46688798e-02, -1.19762987e-01, -1.83716081e-02,
         6.53194189e-02,  1.59102783e-01,  2.84661017e-02,
         1.69514082e-02],
       [ 1.49873719e-01, -2.33182698e-01,  1.29174203e-01,
         3.85063589e-02, -9.77312028e-02, -2.16911286e-02,
        -1.02661110e-01,  1.20439932e-01, -5.27799651e-02,
        -1.43973231e-02],
       [ 1.66275203e-01, -6.17752448e-02, -5.89453727e-02,
         3.24374810e-02, -5.51749542e-02,  3.03901769e-02,
         1.16260141e-01,  8.42447653e-02, -6.38790280e-02,
        -1.03692207e-02],
       [ 1.36118680e-01, -1.16787389e-01,  3.80006954e-02,
         8.70505273e-02,  5.10789454e-04, -2.82304585e-02,
         2.45984010e-02,  7.17180595e-02, -1.32799834e-01,
        -8.55114982e-02],
       [ 1.91627622e-01, -2.63935387e-01,  4.45132554e-02,
         5.81513271e-02, -9.87120122e-02,  3.95036563e-02,
         1.74511611e-01,  1.01226643e-01, -1.32471725e-01,
        -9.00025815e-02],
       [ 1.25715703e-01, -1.65888295e-01, -2.93779373e-02,
         1.18103676e-01,  1.20232143e-02,  3.79601903e-02,
         1.09801933e-01,  9.84949842e-02, -2.29242846e-01,
        -1.04284897e-01],
       [ 8.91549885e-02, -2.70614028e-01,  6.29819557e-02,
        -7.89459795e-03, -1.02348357e-01,  5.10921739e-02,
        -5.17442077e-03, -7.32309371e-03, -1.17085509e-01,
         1.55961560e-02],
       [ 2.43444324e-01, -1.52708650e-01,  7.86574334e-02,
         4.44588959e-02, -2.38025673e-02, -3.11100893e-02,
        -1.83127522e-02,  7.00439587e-02, -1.31154567e-01,
        -6.37955219e-02],
       [ 1.85826480e-01, -6.58973902e-02,  3.55517119e-03,
         3.14786285e-03, -6.16544299e-02, -5.05571067e-03,
         9.06574503e-02,  5.25272116e-02, -1.56728029e-01,
        -1.99988559e-02],
       [ 1.02722555e-01, -1.94791898e-01,  1.95781067e-02,
         9.08551365e-03, -4.91546243e-02,  8.17024708e-03,
         1.35994062e-01, -8.04902613e-03, -1.96649328e-01,
        -1.27977014e-01]], dtype=float32)>
} }
2023-11-07 23:18:21.480277: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

调用已恢复的函数只是基于已保存模型的前向传递 (tf.keras.Model.predict)。如果您想继续训练加载的函数,或者将加载的函数嵌入到更大的模型中,应如何操作?通常的做法是将此加载对象封装到 Keras 层以实现此目的。幸运的是,TF Hub 为此提供了 hub.KerasLayer,如下所示:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded into a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2023-11-07 23:18:22.447273: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
235/235 [==============================] - 5s 7ms/step - loss: 0.3242 - sparse_categorical_accuracy: 0.9105
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.0989 - sparse_categorical_accuracy: 0.9723

在上面的示例中,TensorFlow Hub 的 hub.KerasLayer 可将从 tf.saved_model.load 加载回的结果封装到可用于构建其他模型的 Keras 层。这对于迁移学习非常实用。
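
在迁移学习中,常见做法是冻结加载进来的部分,只训练新增的层。下面是对上文 build_model 的一个假设性改写(仅作示意,其中的 new_head 层为假设的新输出头):

def build_frozen_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # trainable=False 会冻结加载回来的权重,仅将其用作特征提取器
  features = hub.KerasLayer(loaded, trainable=False)(x)
  # 假设性的新输出头;实际迁移学习中,被冻结的部分通常是特征提取器而非完整分类器
  outputs = tf.keras.layers.Dense(10, name='new_head')(features)
  return tf.keras.Model(x, outputs)

with another_strategy.scope():
  frozen_model = build_frozen_model(loaded)
  frozen_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                       optimizer=tf.keras.optimizers.Adam(),
                       metrics=[tf.metrics.SparseCategoricalAccuracy()])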

我应使用哪种 API?

对于保存,如果您使用的是 Keras 模型,请使用 Keras Model.save API,除非您需要低级 API 允许的额外控制。如果您保存的不是 Keras 模型,那么您只能选择使用较低级的 API tf.saved_model.save。

对于加载,您的 API 选择取决于您要从加载 API 中获得什么。如果您无法(或不想)获取 Keras 模型,请使用 tf.saved_model.load。否则,请使用 tf.keras.models.load_model。请注意,只有保存 Keras 模型后,才能恢复 Keras 模型。

这两套 API 可以搭配使用。例如,您可以使用 Model.save 保存 Keras 模型,再用低级 API tf.saved_model.load 将其加载为非 Keras 对象。

model = get_model()

# Saving the model using Keras `Model.save`
model.save(saved_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

从本地设备保存/加载

当训练在远程设备上进行(例如使用 Cloud TPU)而需要通过本地 I/O 设备保存和加载模型时,必须在 tf.saved_model.SaveOptions 和 tf.saved_model.LoadOptions 中使用 experimental_io_device 选项,将 I/O 设备设置为 localhost。例如:

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
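
上面的示例使用的是 Keras 的保存/加载 API;低级 API tf.saved_model.save 和 tf.saved_model.load 同样接受这些选项。下面是一个最小示意(假设沿用上文的 model、saved_model_path 和 another_strategy):

# 用低级 API 保存时传入相同的 experimental_io_device 选项
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
tf.saved_model.save(model, saved_model_path, options=save_options)

# 用低级 API 加载时同样传入该选项
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path, options=load_options)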

警告

一种特殊情况是:以某些方式(例如通过子类化)创建的 Keras 模型,在训练之前就进行保存。例如:

class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  my_model.save(saved_model_path)
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f797645bf10>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f797645bf10>, because it is not built.
ValueError:  Model <__main__.SubclassedModel object at 0x7f797645bf10> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.

SavedModel 会保存跟踪 tf.function 时生成的 tf.types.experimental.ConcreteFunction 对象(请参阅计算图和 tf.function 简介指南中的“函数何时执行跟踪?”部分了解详情)。如果您收到像上面这样的 ValueError,那是因为 Model.save 无法找到或创建已跟踪的 ConcreteFunction。

小心:您不应在一个 ConcreteFunction 都没有的情况下保存模型,因为如果这样做,低级 API 将生成一个没有 ConcreteFunction 签名的 SavedModel(详细了解 SavedModel 格式)。例如:

tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f797645bf10>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f797645bf10>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.src.layers.core.dense.Dense object at 0x7f79765fdfa0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.src.layers.core.dense.Dense object at 0x7f79765fdfa0>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
_SignatureMap({})
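
反之,如果确实需要用低级 API 保存这样一个尚未被跟踪过的模型,可以显式提供带 input_signature 的 tf.function 作为签名。下面是一个最小示意(serving_fn、保存路径及输入形状均为假设):

another_subclassed_model = SubclassedModel()

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 5], dtype=tf.float32)])
def serving_fn(inputs):
  # 跟踪该函数时会调用模型,从而生成一个 ConcreteFunction
  return {SubclassedModel.output_name: another_subclassed_model(inputs)}

tf.saved_model.save(another_subclassed_model, '/tmp/tf_save_with_signature',
                    signatures={'serving_default': serving_fn})
print(tf.saved_model.load('/tmp/tf_save_with_signature').signatures)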

一般而言,模型的前向传递(call 方法)会在模型第一次被调用时自动完成跟踪,通常是通过 Keras 的 Model.fit 方法。如果设置了输入形状(例如将第一层设为 tf.keras.layers.InputLayer,或为其他层类型传入 input_shape 关键字参数),Keras 序贯/函数式 API 也可以生成 ConcreteFunction。
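
例如,下面这个假设性的小例子展示了通过 input_shape 构建的序贯模型在训练前即可保存:

# 指定 input_shape 后,模型在构建时就能生成前向传递的 ConcreteFunction
seq_model = tf.keras.Sequential([
    tf.keras.layers.Dense(5, input_shape=(5,), name='output_layer')
])
seq_model.save('/tmp/seq_model_save')  # 训练之前即可保存(路径为假设)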

要验证您的模型是否有已跟踪的 ConcreteFunction,请检查 Model.save_spec 是否为 None:

print(my_model.save_spec() is None)
True

使用 tf.keras.Model.fit 训练模型后,可以看到 save_spec 已被定义,模型保存将正常工作:

BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)
my_model.save(saved_model_path)
Epoch 1/2
7/7 [==============================] - 1s 3ms/step - loss: 12.9327
Epoch 2/2
7/7 [==============================] - 0s 2ms/step - loss: 12.4291
False
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797651b7c0>, 140159649852656), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797651b7c0>, 140159649852656), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797646f370>, 140159650883888), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797646f370>, 140159650883888), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797651b7c0>, 140159649852656), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797651b7c0>, 140159649852656), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797646f370>, 140159650883888), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797646f370>, 140159650883888), {}).
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Assets written to: /tmp/tf_save/assets