Keras 中的遮盖和填充

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

设置

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
2022-12-14 22:02:38.381529: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:02:38.381624: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:02:38.381633: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

简介

遮盖的作用是告知序列处理层输入中有某些时间步骤丢失,因此在处理数据时应将其跳过。

填充是遮盖的一种特殊形式,其中被遮盖的步骤位于序列的起点或开头。填充是出于将序列数据编码成连续批次的需要:为了使批次中的所有序列适合给定的标准长度,有必要填充或截断某些序列。

让我们仔细看看。

填充序列数据

在处理序列数据时,各个样本常常具有不同长度。请考虑以下示例(文本被切分为单词):

[
  ["Hello", "world", "!"],
  ["How", "are", "you", "doing", "today"],
  ["The", "weather", "will", "be", "nice", "tomorrow"],
]

进行词汇查询后,数据可能会被向量化为整数,例如:

[
  [71, 1331, 4231]
  [73, 8, 3215, 55, 927],
  [83, 91, 1, 645, 1253, 927],
]

此数据是一个嵌套列表,其中各个样本的长度分别为 3、5 和 6。由于深度学习模型的输入数据必须为单一张量(例如在此例中形状为 (batch_size, 6, vocab_size)),短于最长条目的样本需要用占位符值进行填充(或者,也可以在填充短样本前截断长样本)。

Keras 提供了一个效用函数来截断和填充 Python 列表,使其具有相同长度:tf.keras.preprocessing.sequence.pad_sequences

raw_inputs = [
    [711, 632, 71],
    [73, 8, 3215, 55, 927],
    [83, 91, 1, 645, 1253, 927],
]

# By default, this will pad using 0s; it is configurable via the
# "value" parameter.
# Note that you could "pre" padding (at the beginning) or
# "post" padding (at the end).
# We recommend using "post" padding when working with RNN layers
# (in order to be able to use the
# CuDNN implementation of the layers).
padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(
    raw_inputs, padding="post"
)
print(padded_inputs)
[[ 711  632   71    0    0    0]
 [  73    8 3215   55  927    0]
 [  83   91    1  645 1253  927]]

遮盖

既然所有样本现在都具有了统一长度,那就必须告知模型,数据的某些部分实际上是填充,应该忽略。这种机制就是遮盖

在 Keras 模型中引入输入掩码有三种方式:

掩码生成层:EmbeddingMasking

这些层将在后台创建一个掩码张量(形状为 (batch, sequence_length) 的二维张量),并将其附加到由 MaskingEmbedding 层返回的张量输出上。

embedding = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)
masked_output = embedding(padded_inputs)

print(masked_output._keras_mask)

masking_layer = layers.Masking()
# Simulate the embedding lookup by expanding the 2D input to 3D,
# with embedding dimension of 10.
unmasked_embedding = tf.cast(
    tf.tile(tf.expand_dims(padded_inputs, axis=-1), [1, 1, 10]), tf.float32
)

masked_embedding = masking_layer(unmasked_embedding)
print(masked_embedding._keras_mask)
tf.Tensor(
[[ True  True  True False False False]
 [ True  True  True  True  True False]
 [ True  True  True  True  True  True]], shape=(3, 6), dtype=bool)
tf.Tensor(
[[ True  True  True False False False]
 [ True  True  True  True  True False]
 [ True  True  True  True  True  True]], shape=(3, 6), dtype=bool)

您可以在输出结果中看到,该掩码是一个形状为 (batch_size, sequence_length) 的二维布尔张量,其中每个 False 条目表示对应的时间步骤应在处理时忽略。

函数式 API 和序列式 API 中的掩码传播

在使用函数式 API 或序列式 API 时,由 EmbeddingMasking 层生成的掩码将通过网络传播给任何能够使用它们的层(如 RNN 层)。Keras 将自动提取与输入相对应的掩码,并将其传递给任何知道该掩码使用方法的层。

例如,在下面的序贯模型中,LSTM 层将自动接收掩码,这意味着它将忽略填充的值:

model = keras.Sequential(
    [layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True), layers.LSTM(32),]
)

对以下函数式 API 的情况也是如此:

inputs = keras.Input(shape=(None,), dtype="int32")
x = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs)
outputs = layers.LSTM(32)(x)

model = keras.Model(inputs, outputs)

将掩码张量直接传递给层

能够处理掩码的层(如 LSTM 层)在其 __call__ 方法中有一个 mask 参数。

同时,生成掩码的层(如 Embedding)会公开一个 compute_mask(input, previous_mask) 方法,供您调用。

因此,您可以将掩码生成层的 compute_mask() 方法的输出传递给掩码使用层的 __call__ 方法,如下所示:

class MyLayer(layers.Layer):
    def __init__(self, **kwargs):
        super(MyLayer, self).__init__(**kwargs)
        self.embedding = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)
        self.lstm = layers.LSTM(32)

    def call(self, inputs):
        x = self.embedding(inputs)
        # Note that you could also prepare a `mask` tensor manually.
        # It only needs to be a boolean tensor
        # with the right shape, i.e. (batch_size, timesteps).
        mask = self.embedding.compute_mask(inputs)
        output = self.lstm(x, mask=mask)  # The layer will ignore the masked values
        return output


layer = MyLayer()
x = np.random.random((32, 10)) * 100
x = x.astype("int32")
layer(x)
<tf.Tensor: shape=(32, 32), dtype=float32, numpy=
array([[ 0.0068582 ,  0.00841713, -0.0079247 , ...,  0.00052032,
        -0.00937904,  0.0065177 ],
       [-0.00286922, -0.00298016, -0.00316756, ..., -0.00044272,
        -0.01275316,  0.00332524],
       [-0.00615364,  0.00443598, -0.00481584, ...,  0.00112707,
        -0.00541872, -0.00483814],
       ...,
       [ 0.00323759, -0.0020246 , -0.00310902, ..., -0.00167182,
        -0.00273533, -0.00789348],
       [ 0.00154623, -0.00299895, -0.00288218, ...,  0.00081357,
         0.00665191, -0.00701018],
       [-0.00209733, -0.00161297,  0.00302829, ..., -0.0003766 ,
        -0.00618952, -0.00943652]], dtype=float32)>

在自定义层中支持遮盖

有时,您可能需要编写生成掩码的层(如 Embedding),或者需要修改当前掩码的层。

例如,任何生成与其输入具有不同时间维度的张量的层(如在时间维度上进行连接的 Concatenate 层)都需要修改当前掩码,这样下游层才能正确顾及被遮盖的时间步骤。

为此,您的层应实现 layer.compute_mask() 方法,该方法会根据输入和当前掩码生成新的掩码。

以下是需要修改当前掩码的 TemporalSplit 层的示例。

class TemporalSplit(keras.layers.Layer):
    """Split the input tensor into 2 tensors along the time dimension."""

    def call(self, inputs):
        # Expect the input to be 3D and mask to be 2D, split the input tensor into 2
        # subtensors along the time axis (axis 1).
        return tf.split(inputs, 2, axis=1)

    def compute_mask(self, inputs, mask=None):
        # Also split the mask into 2 if it presents.
        if mask is None:
            return None
        return tf.split(mask, 2, axis=1)


first_half, second_half = TemporalSplit()(masked_embedding)
print(first_half._keras_mask)
print(second_half._keras_mask)
tf.Tensor(
[[ True  True  True]
 [ True  True  True]
 [ True  True  True]], shape=(3, 3), dtype=bool)
tf.Tensor(
[[False False False]
 [ True  True False]
 [ True  True  True]], shape=(3, 3), dtype=bool)

下面是关于 CustomEmbedding 层的另一个示例,该层能够根据输入值生成掩码:

class CustomEmbedding(keras.layers.Layer):
    def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):
        super(CustomEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.mask_zero = mask_zero

    def build(self, input_shape):
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            initializer="random_normal",
            dtype="float32",
        )

    def call(self, inputs):
        return tf.nn.embedding_lookup(self.embeddings, inputs)

    def compute_mask(self, inputs, mask=None):
        if not self.mask_zero:
            return None
        return tf.not_equal(inputs, 0)


layer = CustomEmbedding(10, 32, mask_zero=True)
x = np.random.random((3, 10)) * 9
x = x.astype("int32")

y = layer(x)
mask = layer.compute_mask(x)

print(mask)
tf.Tensor(
[[ True  True  True  True  True  True  True  True  True  True]
 [ True False  True  True  True False  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True]], shape=(3, 10), dtype=bool)

在兼容层上选择启用掩码传播

大多数层都不会修改时间维度,因此无需修改当前掩码。但是,这些层可能仍希望能够将当前掩码不加更改地传播到下一层。这是一种可以选择启用的行为。默认情况下,自定义层将破坏当前掩码(因为框架无法确定传播该掩码是否安全)。

如果您有一个不会修改时间维度的自定义层,且您希望它能够传播当前的输入掩码,您应该在层构造函数中设置 self.supports_masking = True。在这种情况下,compute_mask() 的默认行为是仅传递当前掩码。

下面是被列入掩码传播白名单的层的示例:

class MyActivation(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MyActivation, self).__init__(**kwargs)
        # Signal that the layer is safe for mask propagation
        self.supports_masking = True

    def call(self, inputs):
        return tf.nn.relu(inputs)

现在,您可以在掩码生成层(如 Embedding)和掩码使用层(如 LSTM)之间使用此自定义层,它会将掩码一路传递到掩码使用层。

inputs = keras.Input(shape=(None,), dtype="int32")
x = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs)
x = MyActivation()(x)  # Will pass the mask along
print("Mask found:", x._keras_mask)
outputs = layers.LSTM(32)(x)  # Will receive the mask

model = keras.Model(inputs, outputs)
Mask found: KerasTensor(type_spec=TensorSpec(shape=(None, None), dtype=tf.bool, name=None), name='Placeholder_1:0')

编写需要掩码信息的层

有些层是掩码使用者:他们会在 call 中接受 mask 参数,并使用该参数来决定是否跳过某些时间步骤。

要编写这样的层,您只需在 call 签名中添加一个 mask=None 参数。与输入关联的掩码只要可用就会被传递到您的层。

下面是一个简单示例:示例中的层在输入序列的时间维度(轴 1)上计算 Softmax,同时丢弃遮盖的时间步骤。

class TemporalSoftmax(keras.layers.Layer):
    def call(self, inputs, mask=None):
        broadcast_float_mask = tf.expand_dims(tf.cast(mask, "float32"), -1)
        inputs_exp = tf.exp(inputs) * broadcast_float_mask
        inputs_sum = tf.reduce_sum(
            inputs_exp * broadcast_float_mask, axis=-1, keepdims=True
        )
        return inputs_exp / inputs_sum


inputs = keras.Input(shape=(None,), dtype="int32")
x = layers.Embedding(input_dim=10, output_dim=32, mask_zero=True)(inputs)
x = layers.Dense(1)(x)
outputs = TemporalSoftmax()(x)

model = keras.Model(inputs, outputs)
y = model(np.random.randint(0, 10, size=(32, 100)), np.random.random((32, 100, 1)))

总结

以上是您需要了解的关于 Keras 中填充和遮盖的所有信息。回顾一下:

  • “遮盖”是层得知何时应该跳过/忽略序列输入中的某些时间步骤的方式。
  • 有些层是掩码生成者:Embedding 可以通过输入值来生成掩码(如果 mask_zero=True),Masking 层也可以。
  • 有些层是掩码使用者:它们会在其 __call__ 方法中公开 mask 参数。RNN 层就是如此。
  • 在函数式 API 和序列式 API 中,掩码信息会自动传播。
  • 单独使用层时,您可以将 mask 参数手动传递给层。
  • 您可以轻松编写会修改当前掩码的层、生成新掩码的层,或使用与输入关联的掩码的层。