Setup
import tensorflow as tf
from tensorflow import keras
The Layer class: the combination of state (weights) and some computation

One of the central abstractions in Keras is the Layer class. A layer encapsulates both a state (the layer's "weights") and a transformation from inputs to outputs (a "call", the layer's forward pass).

Here's a densely-connected layer. It has a state: the variables w and b.
class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"),
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
You use a layer by calling it on some tensor input(s), much like a Python function.
x = tf.ones((2, 2))
linear_layer = Linear(4, 2)
y = linear_layer(x)
print(y)
tf.Tensor(
[[-0.00204108 -0.04129961 -0.01273937 -0.1288575 ]
 [-0.00204108 -0.04129961 -0.01273937 -0.1288575 ]], shape=(2, 4), dtype=float32)
Note that the weights w and b are automatically tracked by the layer upon being set as layer attributes:
assert linear_layer.weights == [linear_layer.w, linear_layer.b]
Note you also have access to a quicker shortcut for adding weights to a layer: the add_weight() method:
class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
x = tf.ones((2, 2))
linear_layer = Linear(4, 2)
y = linear_layer(x)
print(y)
tf.Tensor(
[[-0.11710902 0.10924062 0.02172409 0.02696215]
 [-0.11710902 0.10924062 0.02172409 0.02696215]], shape=(2, 4), dtype=float32)
Layers can have non-trainable weights

Besides trainable weights, you can add non-trainable weights to a layer as well. Such weights are not meant to be taken into account during backpropagation when you are training the layer.

Here's how to add and use a non-trainable weight:
class ComputeSum(keras.layers.Layer):
    def __init__(self, input_dim):
        super(ComputeSum, self).__init__()
        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)

    def call(self, inputs):
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return self.total
x = tf.ones((2, 2))
my_sum = ComputeSum(2)
y = my_sum(x)
print(y.numpy())
y = my_sum(x)
print(y.numpy())
[2. 2.]
[4. 4.]
It's part of layer.weights, but it gets categorized as a non-trainable weight:
print("weights:", len(my_sum.weights))
print("non-trainable weights:", len(my_sum.non_trainable_weights))
# It's not included in the trainable weights:
print("trainable_weights:", my_sum.trainable_weights)
weights: 1
non-trainable weights: 1
trainable_weights: []
Best practice: deferring weight creation until the shape of the inputs is known

Our Linear layer above took an input_dim argument that was used to compute the shape of the weights w and b in __init__():
class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
In many cases, you may not know in advance the size of your inputs, and you would like to lazily create weights when that value becomes known, some time after instantiating the layer.

In the Keras API, we recommend creating layer weights in the build(self, input_shape) method of your layer. Like this:
class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b
The __call__() method of your layer will automatically run build() the first time it is called. You now have a layer that's lazy and thus easier to use:
# At instantiation, we don't know on what inputs this is going to get called
linear_layer = Linear(32)
# The layer's weights are created dynamically the first time the layer is called
y = linear_layer(x)
Implementing build() separately, as shown above, nicely separates creating weights only once from using weights in every call. However, for some advanced custom layers, it can become impractical to separate state creation and computation. Layer implementers are allowed to defer weight creation to the first __call__(), but need to take care that later calls use the same weights. In addition, since __call__() is likely to be executed for the first time inside a tf.function, any variable creation that takes place in __call__() should be wrapped in a tf.init_scope.
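As an illustration of that pattern (a minimal sketch, not part of the original guide; LazyLinear and its lazy-creation check are made up for illustration), creating weights inside call() guarded by tf.init_scope() could look like this:

class LazyLinear(keras.layers.Layer):
    def __init__(self, units=32):
        super(LazyLinear, self).__init__()
        self.units = units
        self.w = None
        self.b = None

    def call(self, inputs):
        # Create the weights on the first call only. `tf.init_scope()` lifts
        # the variable creation out of any surrounding `tf.function` trace,
        # so later calls reuse the same variables.
        if self.w is None:
            with tf.init_scope():
                self.w = self.add_weight(
                    shape=(inputs.shape[-1], self.units),
                    initializer="random_normal",
                    trainable=True,
                )
                self.b = self.add_weight(
                    shape=(self.units,), initializer="zeros", trainable=True
                )
        return tf.matmul(inputs, self.w) + self.b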
Layers are recursively composable

If you assign a Layer instance as an attribute of another Layer, the outer layer will start tracking the weights created by the inner layer.

We recommend creating such sublayers in the __init__() method and leaving it to the first __call__() to trigger building their weights.
class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super(MLPBlock, self).__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        return self.linear_3(x)
mlp = MLPBlock()
y = mlp(tf.ones(shape=(3, 64))) # The first call to the `mlp` will create the weights
print("weights:", len(mlp.weights))
print("trainable weights:", len(mlp.trainable_weights))
weights: 6
trainable weights: 6
The add_loss() method

When writing the call() method of a layer, you can create loss tensors that you will want to use later, when writing your training loop. This is doable by calling self.add_loss(value):
# A layer that creates an activity regularization loss
class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super(ActivityRegularizationLayer, self).__init__()
        self.rate = rate

    def call(self, inputs):
        self.add_loss(self.rate * tf.reduce_sum(inputs))
        return inputs
These losses (including those created by any inner layer) can be retrieved via layer.losses. This property is reset at the start of every __call__() to the top-level layer, so that layer.losses always contains the loss values created during the last forward pass.
class OuterLayer(keras.layers.Layer):
    def __init__(self):
        super(OuterLayer, self).__init__()
        self.activity_reg = ActivityRegularizationLayer(1e-2)

    def call(self, inputs):
        return self.activity_reg(inputs)
layer = OuterLayer()
assert len(layer.losses) == 0 # No losses yet since the layer has never been called
_ = layer(tf.zeros((1, 1)))
assert len(layer.losses) == 1 # We created one loss value
# `layer.losses` gets reset at the start of each __call__
_ = layer(tf.zeros((1, 1)))
assert len(layer.losses) == 1 # This is the loss created during the call above
In addition, the loss property also contains regularization losses created for the weights of any inner layer:
class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    def __init__(self):
        super(OuterLayerWithKernelRegularizer, self).__init__()
        self.dense = keras.layers.Dense(
            32, kernel_regularizer=tf.keras.regularizers.l2(1e-3)
        )

    def call(self, inputs):
        return self.dense(inputs)
layer = OuterLayerWithKernelRegularizer()
_ = layer(tf.zeros((1, 1)))
# This is `1e-3 * sum(layer.dense.kernel ** 2)`,
# created by the `kernel_regularizer` above.
print(layer.losses)
[<tf.Tensor: shape=(), dtype=float32, numpy=0.0021994286>]
These losses are meant to be taken into account when writing training loops, like this:
# Instantiate an optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Iterate over the batches of a dataset.
for x_batch_train, y_batch_train in train_dataset:
    with tf.GradientTape() as tape:
        logits = model(x_batch_train)  # Logits for this minibatch
        # Loss value for this minibatch
        loss_value = loss_fn(y_batch_train, logits)
        # Add extra losses created during this forward pass:
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
For a detailed guide about writing training loops, see the guide to writing a training loop from scratch.
These losses also work seamlessly with fit() (they get automatically summed and added to the main loss, if any):
import numpy as np
inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)
# If there is a loss passed in `compile`, the regularization
# losses get added to it
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# It's also possible not to pass any loss in `compile`,
# since the model already has a loss to minimize, via the `add_loss`
# call during the forward pass!
model.compile(optimizer="adam")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
1/1 [==============================] - 0s 178ms/step - loss: 0.2993
1/1 [==============================] - 0s 33ms/step - loss: 0.0257
<keras.callbacks.History at 0x7fd7c594fe80>
The add_metric() method

Similarly to add_loss(), layers also have an add_metric() method for tracking the moving average of a quantity during training.

Consider the following "logistic endpoint" layer. It takes as inputs predictions and targets, computes a loss which it tracks via add_loss(), and computes an accuracy scalar which it tracks via add_metric().
class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(LogisticEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()

    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)

        # Log accuracy as a metric and add it
        # to the layer using `self.add_metric()`.
        acc = self.accuracy_fn(targets, logits, sample_weights)
        self.add_metric(acc, name="accuracy")

        # Return the inference-time prediction tensor (for `.predict()`).
        return tf.nn.softmax(logits)
Metrics tracked in this way are accessible via layer.metrics:
layer = LogisticEndpoint()
targets = tf.ones((2, 2))
logits = tf.ones((2, 2))
y = layer(targets, logits)
print("layer.metrics:", layer.metrics)
print("current accuracy value:", float(layer.metrics[0].result()))
layer.metrics: [<keras.metrics.metrics.BinaryAccuracy object at 0x7fd630047430>]
current accuracy value: 1.0
Just like for add_loss(), these metrics are also tracked by fit():
inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
logits = keras.layers.Dense(10)(inputs)
# Pass the arguments in the order declared by `LogisticEndpoint.call()`:
# targets first, then logits.
predictions = LogisticEndpoint(name="predictions")(targets, logits)

model = keras.Model(inputs=[inputs, targets], outputs=predictions)
model.compile(optimizer="adam")

data = {
    "inputs": np.random.random((3, 3)),
    "targets": np.random.random((3, 10)),
}
model.fit(data)
1/1 [==============================] - 1s 594ms/step - loss: 0.9536 - binary_accuracy: 0.0000e+00
<keras.callbacks.History at 0x7fd7c59a8940>
You can optionally enable serialization on your layers

If you need your custom layers to be serializable as part of a Functional model, you can optionally implement a get_config() method:
class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        return {"units": self.units}
# Now you can recreate the layer from its config:
layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
{'units': 64}
Note that the __init__() method of the base Layer class takes some keyword arguments, in particular a name and a dtype. It's good practice to pass these arguments to the parent class in __init__() and to include them in the layer config:
class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super(Linear, self).get_config()
        config.update({"units": self.units})
        return config
layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
{'name': 'linear_8', 'trainable': True, 'dtype': 'float32', 'units': 64}
If you need more flexibility when deserializing the layer from its config, you can also override the from_config() class method. This is the base implementation of from_config():
def from_config(cls, config):
    return cls(**config)
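For instance (a hedged sketch, not taken from this guide; the LinearWithInitializer layer and its kernel_initializer argument are made up for illustration), a layer whose constructor takes an initializer object can rebuild that object in from_config() before calling the constructor:

class LinearWithInitializer(keras.layers.Layer):
    def __init__(self, units=32, kernel_initializer="random_normal", **kwargs):
        super(LinearWithInitializer, self).__init__(**kwargs)
        self.units = units
        # Accept a string, a config dict, or an initializer object.
        self.kernel_initializer = keras.initializers.get(kernel_initializer)

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=self.kernel_initializer,
            trainable=True,
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w)

    def get_config(self):
        config = super(LinearWithInitializer, self).get_config()
        config.update(
            {
                "units": self.units,
                # Store the initializer as a plain, JSON-friendly dict.
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Turn the serialized initializer back into an object before
        # passing the config to __init__().
        config = dict(config)
        config["kernel_initializer"] = keras.initializers.deserialize(
            config["kernel_initializer"]
        )
        return cls(**config)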
To learn more about serialization and saving, see the complete guide to saving and serializing models.
Privileged training argument in the call() method

Some layers, in particular the BatchNormalization layer and the Dropout layer, have different behaviors during training and inference. For such layers, it is standard practice to expose a training (boolean) argument in the call() method.

By exposing this argument in call(), you enable the built-in training and evaluation loops (e.g. fit()) to correctly use the layer in training and inference.
class CustomDropout(keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super(CustomDropout, self).__init__(**kwargs)
        self.rate = rate

    def call(self, inputs, training=None):
        if training:
            return tf.nn.dropout(inputs, rate=self.rate)
        return inputs
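For example (a small usage sketch, not part of the original guide), the training flag controls whether dropout is applied when you call the layer directly:

dropout_layer = CustomDropout(0.5)
x = tf.ones((2, 4))

# Training mode: roughly half of the entries are zeroed out
# (and the remaining ones are scaled up to compensate).
print(dropout_layer(x, training=True))

# Inference mode (also the default here, since `training=None` is falsy):
# the inputs pass through unchanged.
print(dropout_layer(x, training=False))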
Privileged mask argument in the call() method

The other privileged argument supported by call() is the mask argument.
You will find it in all Keras RNN layers. A mask is a boolean tensor (one boolean value per input timestep) used to skip certain input timesteps when processing timeseries data.

Keras will automatically pass the correct mask argument to __call__() for layers that support it, when a mask is generated by a prior layer. Mask-generating layers are the Embedding layer configured with mask_zero=True, and the Masking layer.
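For illustration (a minimal sketch under the usual mask convention, not part of the original guide; MaskedMean is a made-up name), a layer that consumes a mask simply accepts it in call() and uses it to ignore the masked-out timesteps:

class MaskedMean(keras.layers.Layer):
    """Averages over the time dimension, ignoring masked-out timesteps."""

    def call(self, inputs, mask=None):
        # inputs: (batch, timesteps, features); mask: (batch, timesteps) or None.
        if mask is None:
            return tf.reduce_mean(inputs, axis=1)
        mask = tf.cast(mask, inputs.dtype)[:, :, tf.newaxis]
        total = tf.reduce_sum(inputs * mask, axis=1)
        count = tf.maximum(tf.reduce_sum(mask, axis=1), 1.0)
        return total / count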
To learn more about masking and how to write masking-enabled layers, please check out the guide on understanding padding and masking.
The Model class

In general, you will use the Layer class to define inner computation blocks, and will use the Model class to define the outer model, the object you will train.

For instance, in a ResNet50 model, you would have several ResNet blocks subclassing Layer, and a single Model encompassing the entire ResNet50 network.

The Model class has the same API as Layer, with the following differences:
- It exposes built-in training, evaluation, and prediction loops (model.fit(), model.evaluate(), model.predict()).
- It exposes the list of its inner layers, via the model.layers property.
- It exposes saving and serialization APIs (save(), save_weights()…)
Effectively, the Layer class corresponds to what we refer to in the literature as a "layer" (as in "convolution layer" or "recurrent layer") or as a "block" (as in "ResNet block" or "Inception block").

Meanwhile, the Model class corresponds to what is referred to in the literature as a "model" (as in "deep learning model") or as a "network" (as in "deep neural network").

So if you're wondering "should I use the Layer class or the Model class?", ask yourself: will I need to call fit() on it? Will I need to call save() on it? If so, go with Model. If not (either because your class is just a block in a bigger system, or because you are writing the training and saving code yourself), use Layer.
For instance, applying the layer-composition pattern shown above, we could build a ResNet-style Model that we can train with fit() and save with save_weights():
class ResNet(tf.keras.Model):
    def __init__(self, num_classes=1000):
        super(ResNet, self).__init__()
        # `ResNetBlock` is assumed to be a custom Layer defined elsewhere.
        self.block_1 = ResNetBlock()
        self.block_2 = ResNetBlock()
        self.global_pool = keras.layers.GlobalAveragePooling2D()
        self.classifier = keras.layers.Dense(num_classes)

    def call(self, inputs):
        x = self.block_1(inputs)
        x = self.block_2(x)
        x = self.global_pool(x)
        return self.classifier(x)


resnet = ResNet()
dataset = ...
resnet.fit(dataset, epochs=10)
resnet.save(filepath)
Putting it all together: an end-to-end example

Here's what you've learned so far:

- A Layer encapsulates a state (created in __init__() or build()) and some computation (defined in call()).
- Layers can be recursively nested to create new, bigger computation blocks.
- Layers can create and track losses (typically regularization losses) as well as metrics, via add_loss() and add_metric().
- The outer container, the thing you want to train, is a Model. A Model is just like a Layer, but with added training and serialization utilities.
Let's put all of these things together into an end-to-end example: we're going to implement a Variational AutoEncoder (VAE) and train it on MNIST digits.

Our VAE will be a subclass of Model, built as a nested composition of layers that subclass Layer. It will feature a regularization loss (KL divergence).
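As a reminder of where that term comes from (standard VAE math, not spelled out in the original guide): for a diagonal Gaussian posterior q(z|x) with mean mu and log-variance z_log_var = log(sigma^2), and a standard normal prior p(z), the KL divergence has the closed form

KL(q || p) = -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2)

which is exactly the expression assigned to kl_loss in the code below, except that the code averages over the batch and latent dimensions (tf.reduce_mean) instead of summing.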
from tensorflow.keras import layers
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z
class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)
class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed
Let's write a simple training loop on MNIST:
original_dim = 784
vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

loss_metric = tf.keras.metrics.Mean()

(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2

# Iterate over epochs.
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            # Compute reconstruction loss
            loss = mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(vae.losses)  # Add KLD regularization loss

        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        loss_metric(loss)

        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))
Start of epoch 0
step 0: mean loss = 0.3351
step 100: mean loss = 0.1247
step 200: mean loss = 0.0988
step 300: mean loss = 0.0890
step 400: mean loss = 0.0841
step 500: mean loss = 0.0808
step 600: mean loss = 0.0786
step 700: mean loss = 0.0771
step 800: mean loss = 0.0759
step 900: mean loss = 0.0749
Start of epoch 1
step 0: mean loss = 0.0746
step 100: mean loss = 0.0739
step 200: mean loss = 0.0734
step 300: mean loss = 0.0730
step 400: mean loss = 0.0727
step 500: mean loss = 0.0723
step 600: mean loss = 0.0720
step 700: mean loss = 0.0717
step 800: mean loss = 0.0714
step 900: mean loss = 0.0712
Note that since the VAE is subclassing Model, it features built-in training loops. So you could also have trained it like this:
vae = VariationalAutoEncoder(784, 64, 32)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=2, batch_size=64)
Epoch 1/2
938/938 [==============================] - 4s 3ms/step - loss: 0.0747
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0676
<keras.callbacks.History at 0x7fd61c1b3160>
Beyond object-oriented development: the Functional API

Was this example too much object-oriented development for you? You can also build models using the Functional API. Importantly, choosing one style or another does not prevent you from leveraging components written in the other style: you can always mix and match.

For instance, the Functional API example below reuses the same Sampling layer we defined in the example above:
original_dim = 784
intermediate_dim = 64
latent_dim = 32
# Define encoder model.
original_inputs = tf.keras.Input(shape=(original_dim,), name="encoder_input")
x = layers.Dense(intermediate_dim, activation="relu")(original_inputs)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()((z_mean, z_log_var))
encoder = tf.keras.Model(inputs=original_inputs, outputs=z, name="encoder")
# Define decoder model.
latent_inputs = tf.keras.Input(shape=(latent_dim,), name="z_sampling")
x = layers.Dense(intermediate_dim, activation="relu")(latent_inputs)
outputs = layers.Dense(original_dim, activation="sigmoid")(x)
decoder = tf.keras.Model(inputs=latent_inputs, outputs=outputs, name="decoder")
# Define VAE model.
outputs = decoder(z)
vae = tf.keras.Model(inputs=original_inputs, outputs=outputs, name="vae")
# Add KL divergence regularization loss.
kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
vae.add_loss(kl_loss)
# Train.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=3, batch_size=64)
Epoch 1/3
938/938 [==============================] - 4s 3ms/step - loss: 0.0748
Epoch 2/3
938/938 [==============================] - 3s 3ms/step - loss: 0.0676
Epoch 3/3
938/938 [==============================] - 3s 3ms/step - loss: 0.0676
<keras.callbacks.History at 0x7fd6146fe6a0>
For more information, make sure to read the Functional API guide.