在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
概述
本指南提供了使用 TensorFlow 2 (TF2) 编写代码的最佳做法列表,此列表专为最近从 TensorFlow 1 (TF1) 切换过来的用户编写。有关将 TF1 代码迁移到 TF2 的更多信息,请参阅指南的迁移部分。
设置
为本指南中的示例导入 TensorFlow 和其他依赖项。
import tensorflow as tf
import tensorflow_datasets as tfds
2022-12-14 20:16:24.587741: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:16:24.587848: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:16:24.587858: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
惯用 TensorFlow 2 的建议
将代码重构为更小的模块
一种良好做法是将代码重构为根据需要调用的更小函数。为了获得最佳性能,您应当尝试在 tf.function
中装饰最大的计算块(请注意,由 tf.function
调用的嵌套 Python 函数不需要自己单独的装饰,除非您想为 tf.function
使用不同的 jit_compile
设置)。根据您的用例,这可能是多个训练步骤,甚至是整个训练循环。对于推断用例,它可能是单个模型前向传递。
调整某些 tf.keras.optimizer
的默认学习率
在 TF2 中,某些 Keras 优化器具有不同的学习率。如果您发现模型的收敛行为发生变化,请检查默认学习率。
optimizers.SGD
、optimizers.Adam
或 optimizers.RMSprop
没有任何变更。
以下优化器的默认学习率已更改:
optimizers.Adagrad
从0.01
更改为0.001
optimizers.Adadelta
从1.0
更改为0.001
optimizers.Adamax
从0.002
更改为0.001
optimizers.Nadam
从0.002
更改为0.001
使用 tf.Module
和 Keras 层管理变量
tf.Module
和 tf.keras.layers.Layer
提供了方便的 variables
和 trainable_variables
属性,它们以递归方式收集所有因变量。这样便可轻松在使用变量的地方对它们进行本地管理。
Keras 层/模型继承自 tf.train.Checkpointable
并与 @tf.function
集成,这样便有可能从 Keras 对象直接导出 SavedModel 或为其添加检查点。您不必使用 Keras的 Model.fit
API 来利用这些集成。
阅读 Keras 指南中有关迁移学习和微调的部分,了解如何使用 Keras 收集相关变量的子集。
结合 tf.data.Dataset
和 tf.function
TensorFlow Datasets 软件包 (tfds) 包含用于将预定义数据集作为 tf.data.Dataset
对象加载的的实用工具。对于此示例,您可以使用 tfds
加载 MNIST 数据集:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
然后,准备用于训练的数据:
- 重新缩放每个图像;
- 重排样本顺序。
- 收集图像和标签批次。
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
为了使样本简短,将数据集修剪为仅返回 5 个批次:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2022-12-14 20:16:30.248760: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
使用常规 Python 迭代来迭代适合装入内存的训练数据。除此之外,tf.data.Dataset
是从磁盘流式传输训练数据的最佳方式。数据集是可迭代对象(但不是迭代器),就像其他 Eager Execution 中的 Python 可迭代对象一样。您可以通过将代码封装在 tf.function
中来充分利用数据集异步预提取/流式传输功能,此代码将 Python 迭代替换为使用 AutoGraph 的等效计算图运算。
@tf.function
def train(model, dataset, optimizer):
for x, y in dataset:
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
prediction = model(x, training=True)
loss = loss_fn(prediction, y)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
如果您使用 Keras Model.fit
API,则不必担心数据集迭代。
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
使用 Keras 训练循环
如果您不需要对训练过程进行低级控制,建议使用 Keras 的内置 fit
、evaluate
和 predict
方法。无论实现方式(顺序、函数或子类化)如何,这些方法都能提供统一的接口来训练模型。
这些方法的优点包括:
- 接受 Numpy 数组、Python 生成器和
tf.data.Datasets
。 - 自动应用正则化和激活损失。
- 支持
tf.distribute
,无论硬件配置如何,训练代码都保持不变。 - 支持将任意可调用对象作为损失和指标。
- 支持
tf.keras.callbacks.TensorBoard
之类的回调以及自定义回调。 - 性能出色,可以自动使用 TensorFlow 计算图。
下面是使用 Dataset
训练模型的示例。要详细了解工作原理,请参阅教程。
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5 5/5 [==============================] - 7s 6ms/step - loss: 1.6298 - accuracy: 0.4750 Epoch 2/5 2022-12-14 20:16:37.438231: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 5ms/step - loss: 0.4652 - accuracy: 0.9031 Epoch 3/5 2022-12-14 20:16:37.745850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 4ms/step - loss: 0.2902 - accuracy: 0.9688 Epoch 4/5 2022-12-14 20:16:38.020249: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 5ms/step - loss: 0.2238 - accuracy: 0.9750 Epoch 5/5 2022-12-14 20:16:38.301951: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 5ms/step - loss: 0.1776 - accuracy: 0.9875 2022-12-14 20:16:38.608839: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 3ms/step - loss: 1.6134 - accuracy: 0.5906 Loss 1.6133979558944702, Accuracy 0.590624988079071 2022-12-14 20:16:38.952507: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
自定义训练并编写自己的循环
如果 Keras 模型适合您,但您需要更大的灵活性和对训练步骤或外层训练循环的控制,您可以实现自己的训练步骤甚至整个训练循环。如需了解详情,请参阅有关自定义 fit
的 Keras 指南。
此外 ,您还可以将许多内容作为 tf.keras.callbacks.Callback
实现。
这种方法具有前面提到的许多优点,但可以让您控制训练步骤甚至外层循环。
标准训练循环分为三个步骤:
- 迭代 Python 生成器或
tf.data.Dataset
来获得样本批次。 - 使用
tf.GradientTape
收集梯度。 - 使用
tf.keras.optimizers
之一将权重更新应用于模型的变量。
请记住:
- 始终在子类化层和模型的
call
方法上包含一个training
参数。 - 确保在
training
参数正确设置的情况下调用模型。 - 根据用法,在对一批数据运行模型之前,模型变量可能不存在。
- 您需要手动处理模型的正则化损失这类问题。
无需运行变量初始值设定项或添加手动控制依赖项。tf.function
会在创建时为您处理自动控制依赖项和变量初始化。
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
for inputs, labels in train_data:
train_step(inputs, labels)
print("Finished epoch", epoch)
2022-12-14 20:16:40.664234: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 0 2022-12-14 20:16:40.916298: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 1 2022-12-14 20:16:41.225330: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 2 2022-12-14 20:16:41.484274: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 3 Finished epoch 4 2022-12-14 20:16:41.802437: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
通过 Python 控制流充分利用 tf.function
tf.function
提供了一种将依赖于数据的控制流转换为计算图模式等效项(如 tf.cond
和 tf.while_loop
)的方法。
数据依赖控制流出现的一个常见位置是序列模型。tf.keras.layers.RNN
封装一个 RNN 单元,允许您以静态或动态方式展开递归。例如,您可以按照下文所述重新实现动态展开。
class DynamicRNN(tf.keras.Model):
def __init__(self, rnn_cell):
super(DynamicRNN, self).__init__(self)
self.cell = rnn_cell
@tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
def call(self, input_data):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
timesteps = tf.shape(input_data)[0]
batch_size = tf.shape(input_data)[1]
outputs = tf.TensorArray(tf.float32, timesteps)
state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
for i in tf.range(timesteps):
output, state = self.cell(input_data[i], state)
outputs = outputs.write(i, output)
return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
阅读 tf.function
指南以了解更多信息。
新型指标和损失
指标和损失均为对象,两者都在 Eager 模式下工作,且都位于 tf.function
中。
损失对象是可调用对象,并使用 (y_true
, y_pred
) 作为参数:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
使用指标收集和显示数据
您可以使用 tf.metrics
聚合数据,使用 tf.summary
记录摘要并使用上下文管理器将其重定向到编写器。摘要会直接发送到编写器,这意味着您必须在调用点提供 step
值。
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
tf.summary.scalar('loss', 0.1, step=42)
要在将数据记录为摘要之前对其进行聚合,请使用 tf.metrics
。指标是有状态的;它们积累值并在您调用 result
方法(例如 Mean.result
)时返回累积结果。可以使用 Model.reset_states
清除累积值。
def train(model, optimizer, dataset, log_freq=10):
avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
avg_loss.update_state(loss)
if tf.equal(optimizer.iterations % log_freq, 0):
tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
loss = loss_fn(model(test_x, training=False), test_y)
tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
train(model, optimizer, dataset)
with test_summary_writer.as_default():
test(model, test_x, test_y, optimizer.iterations)
通过将 TensorBoard 指向摘要日志目录来呈现生成的摘要:
tensorboard --logdir /tmp/summaries
使用 tf.summary
API 编写要在 TensorBoard 中呈现的摘要数据。有关更多信息,请阅读 tf.summary
指南。
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# Update the metrics
loss_metric.update_state(total_loss)
accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
# Reset the metrics
loss_metric.reset_states()
accuracy_metric.reset_states()
for inputs, labels in train_data:
train_step(inputs, labels)
# Get the metric results
mean_loss=loss_metric.result()
mean_accuracy = accuracy_metric.result()
print('Epoch: ', epoch)
print(' loss: {:.3f}'.format(mean_loss))
print(' accuracy: {:.3f}'.format(mean_accuracy))
2022-12-14 20:16:42.710276: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 0 loss: 0.158 accuracy: 0.991 2022-12-14 20:16:42.986717: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 1 loss: 0.134 accuracy: 0.997 2022-12-14 20:16:43.273487: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 2 loss: 0.113 accuracy: 0.997 2022-12-14 20:16:43.559555: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 3 loss: 0.099 accuracy: 1.000 Epoch: 4 loss: 0.091 accuracy: 1.000 2022-12-14 20:16:43.827528: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Keras 指标名称
Keras 模型以一致方式处理指标名称。当您在指标列表中传递字符串时,该确切字符串会用作指标的 name
。这些名称在 model.fit
返回的历史对象中可见,而在传递给 keras.callbacks
的日志中,它们被设置为您在指标列表中传递的字符串。
model.compile(
optimizer = tf.keras.optimizers.Adam(0.001),
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 2s 5ms/step - loss: 0.1055 - acc: 0.9937 - accuracy: 0.9937 - my_accuracy: 0.9937 2022-12-14 20:16:45.487647: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
调试
使用 Eager Execution 可以分步运行代码来检查形状、数据类型和值。某些 API(如 tf.function
、tf.keras
等)设计为使用计算图执行来提高性能和可移植性。调试时,使用 tf.config.run_functions_eagerly(True)
可以在此代码内使用 Eager Execution。
例如:
@tf.function
def f(x):
if x > 0:
import pdb
pdb.set_trace()
x = x + 1
return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
这也可以在 Keras 模型和其他支持 Eager Execution 的 API 中使用:
class CustomModel(tf.keras.models.Model):
@tf.function
def call(self, input_data):
if tf.reduce_mean(input_data) > 0:
return input_data
else:
import pdb
pdb.set_trace()
return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
注释:
tf.keras.Model
方法(例如fit
、evaluate
和predict
)作为计算图执行,并且tf.function
位于底层。使用
tf.keras.Model.compile
时,设置run_eagerly = True
以禁止Model
逻辑被封装在tf.function
中。使用
tf.data.experimental.enable_debug_mode
为tf.data
启用调试模式。阅读 API 文档,了解详细信息。
不要在您的对象中保留 tf.Tensors
这些张量对象可能会在 tf.function
或 Eager 上下文中创建,并且这些张量的行为有所不同。始终仅将 tf.Tensor
用于中间值。
要跟踪状态,请使用 tf.Variable
,因为它们始终可用于两种上下文。阅读 tf.Variable
指南以了解更多信息。