Squash the outer dimensions of input tensors; unsquash outputs.
tf_agents.keras_layers.SquashedOuterWrapper(
wrapped: tf.keras.layers.Layer, inner_rank: int, **kwargs
)
This layer wraps a Keras layer, `wrapped`, that cannot handle more than one
batch dimension. It squashes the inputs' outer dimensions into a single larger
batch dimension, then unsquashes the outputs of `wrapped`.
The outer dimensions are the leftmost `rank(inputs) - inner_rank` dimensions.
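The shape arithmetic above can be sketched in plain Python. This is only an illustration of which dimensions count as "outer" and what shape the wrapped layer would see; it is not the library's implementation. The helper names `split_outer_inner` and `squashed_shape` are hypothetical, and the sketch assumes the input has at least one outer dimension.

```python
import math


def split_outer_inner(shape, inner_rank):
    """Split a shape into (outer_dims, inner_dims).

    Hypothetical helper: the outer dims are the leftmost
    len(shape) - inner_rank entries, the inner dims are the rest.
    """
    num_outer = len(shape) - inner_rank
    if num_outer < 1:
        # Assumption made for this sketch only: require an outer dimension.
        raise ValueError("shape must have more than inner_rank dimensions")
    return tuple(shape[:num_outer]), tuple(shape[num_outer:])


def squashed_shape(shape, inner_rank):
    """Shape after collapsing all outer dims into one batch dim."""
    outer, inner = split_outer_inner(shape, inner_rank)
    return (math.prod(outer),) + inner


# A [B, T, H, W, C] input with B=2, T=5 and inner_rank=3:
outer, inner = split_outer_inner((2, 5, 8, 8, 3), inner_rank=3)
print(outer, inner)                                    # (2, 5) (8, 8, 3)
print(squashed_shape((2, 5, 8, 8, 3), inner_rank=3))   # (10, 8, 8, 3)
```

After the wrapped layer runs on the squashed shape, restoring the saved outer dimensions in place of the single batch dimension gives back the original outer layout.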
Examples:
import tensorflow as tf
from tf_agents.keras_layers import SquashedOuterWrapper

# Illustrative sizes only; any positive integers work.
B, B1, B2, T, H, W, C = 4, 2, 3, 5, 8, 8, 3

batch_norm = tf.keras.layers.BatchNormalization(axis=-1)
layer = SquashedOuterWrapper(wrapped=batch_norm, inner_rank=3)

inputs_0 = tf.random.normal((B, H, W, C))
# batch_norm sees a tensor of shape [B, H, W, C]
# outputs_0 shape is [B, H, W, C]
outputs_0 = layer(inputs_0)

inputs_1 = tf.random.normal((B, T, H, W, C))
# batch_norm sees a tensor of shape [B * T, H, W, C]
# outputs_1 shape is [B, T, H, W, C]
outputs_1 = layer(inputs_1)

inputs_2 = tf.random.normal((B1, B2, T, H, W, C))
# batch_norm sees a tensor of shape [B1 * B2 * T, H, W, C]
# outputs_2 shape is [B1, B2, T, H, W, C]
outputs_2 = layer(inputs_2)
Attributes | |
---|---|
`inner_rank` | |
`wrapped` | |