
tf_agents.keras_layers.InnerReshape

Returns a Keras layer that reshapes the inner dimensions of tensors.

Each tensor passed to an instance of InnerReshape will be reshaped to:

shape(tensor)[:-len(current_shape)] + new_shape

(after its inner shape is validated against current_shape). Note that current_shape may contain None (unknown) dimension values, which are compatible with any size.
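To make the formula concrete, here is a minimal pure-Python sketch that works on static shapes only. `inner_reshape_target` is a hypothetical helper for illustration, not part of TF-Agents; the real layer operates on runtime tensor shapes. It keeps the outer dims, checks the inner dims against current_shape (treating None as a wildcard), and appends new_shape. Taken literally in Python, `[:-len(current_shape)]` would return an empty prefix when current_shape is `[]` (because `[:-0]` is `[:0]`), so the sketch slices at an explicit index instead.

```python
from typing import List, Optional, Sequence


def inner_reshape_target(
    input_shape: Sequence[int],
    current_shape: Sequence[Optional[int]],
    new_shape: Sequence[int],
) -> List[int]:
    """Hypothetical helper (not TF-Agents API): static target shape of an inner reshape."""
    n_inner = len(current_shape)
    if n_inner > len(input_shape):
        raise ValueError("input rank is smaller than len(current_shape)")
    # Slice at an explicit index instead of [:-n_inner]: when current_shape is []
    # (n_inner == 0), [:-0] would return an empty prefix and drop every outer dim.
    split = len(input_shape) - n_inner
    outer, inner = list(input_shape[:split]), list(input_shape[split:])
    # Validate the inner dims against current_shape; None matches any size.
    for actual, expected in zip(inner, current_shape):
        if expected is not None and actual != expected:
            raise ValueError(f"inner shape {inner} is incompatible with {list(current_shape)}")
    # shape(tensor)[:-len(current_shape)] + new_shape
    return outer + list(new_shape)


print(inner_reshape_target([5, 7, 32], [None], [4, 4, 2]))         # [5, 7, 4, 4, 2]
print(inner_reshape_target([5, 2, 2, 2], [None, None, 2], [-1]))  # [5, -1]
```

A -1 left in the result is resolved later by the actual reshape operation, which infers that dimension from the total element count.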

This can be helpful when switching between Dense, ConvXd, and RNN layers in TF-Agents networks, because the same reshape works whether the input has [batch_size] or [batch_size, time] outer dimensions.
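The outer-dimension agnosticism can be illustrated with a small NumPy analogue. This is a sketch, not the TF-Agents implementation: `inner_reshape` is a hypothetical function that takes the number of inner dims instead of a full current_shape and skips validation. The same call reshapes a [batch_size, 32] array and a [batch_size, time, 32] array, and a second call with new_shape=[-1] flattens the inner dims back.

```python
import numpy as np


def inner_reshape(x: np.ndarray, num_inner_dims: int, new_shape) -> np.ndarray:
    """Hypothetical NumPy analogue: keep the outer dims and reshape the last
    num_inner_dims dims to new_shape (which may contain one -1)."""
    outer = x.shape[: x.ndim - num_inner_dims]
    return x.reshape(outer + tuple(new_shape))


# The same call works for [batch_size, 32] and [batch_size, time, 32] inputs.
flat = np.arange(3 * 32, dtype=np.float32).reshape(3, 32)
seq = np.arange(3 * 5 * 32, dtype=np.float32).reshape(3, 5, 32)
print(inner_reshape(flat, 1, [4, 4, 2]).shape)  # (3, 4, 4, 2)
print(inner_reshape(seq, 1, [4, 4, 2]).shape)   # (3, 5, 4, 4, 2)
```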

For example, to switch between Dense, Conv2D, and GRU layers:

import tensorflow as tf
import tf_agents

net = tf_agents.networks.Sequential([
  tf.keras.layers.Dense(32),
  # Convert inner dim from [32] to [4, 4, 2] for Conv2D.
  tf_agents.keras_layers.InnerReshape([None], new_shape=[4, 4, 2]),
  tf.keras.layers.Conv2D(2, 3),
  # Convert inner HWC dims [?, ?, 2] to [8] for Dense/RNN.
  tf_agents.keras_layers.InnerReshape([None, None, 2], new_shape=[-1]),
  tf.keras.layers.GRU(2, return_state=True, return_sequences=True)
])
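The "[8]" in the second comment follows from the convolution arithmetic. The sketch below traces the static shapes through the example up to the GRU input, assuming Keras Conv2D defaults (padding='valid', strides=1) and an arbitrary batch size of 8 chosen only for illustration. With [batch_size, time] outer dims, each shape would simply carry an extra time axis after the batch axis; the inner dims are unchanged.

```python
def conv_valid_output_size(size: int, kernel: int, stride: int = 1) -> int:
    # Output length of an unpadded ('valid') convolution along one spatial axis.
    return (size - kernel) // stride + 1


batch = 8                                     # arbitrary batch size for illustration
shapes = [[batch, 32]]                        # after Dense(32)
shapes.append([batch, 4, 4, 2])               # InnerReshape([None], new_shape=[4, 4, 2])
h = conv_valid_output_size(4, 3)              # 3x3 kernel on a 4x4 input -> 2
w = conv_valid_output_size(4, 3)
shapes.append([batch, h, w, 2])               # Conv2D(2, 3): 2 filters, 3x3 kernel
shapes.append([batch, h * w * 2])             # InnerReshape([None, None, 2], new_shape=[-1])
for s in shapes:
    print(s)
```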

Args:
  current_shape: The current (partial) shape for the inner dims. This should be a list, tuple, or tf.TensorShape with known rank. The given current_shape must be compatible with the inner shape of the input. Examples: [], [None], [None] * 3, [3, 3, 4], [3, None, 4].
  new_shape: The new shape for the inner dims. The length of new_shape need not match the length of current_shape, but if both shapes are fully defined, their total numbers of elements must match. It may have at most one flexible (-1) dimension. Examples: [3], [], [-1], [-1, 3].
  **kwargs: Additional args to the Keras core layer constructor, e.g. name.

Returns:
  A new Keras layer that performs the requested reshape on incoming tensors.

Raises:
  ValueError: If current_shape has unknown rank.
  ValueError: If both shapes are fully defined and the number of elements doesn't match.
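The constructor-time constraints above can be summarized as a few static checks. This is a hedged sketch, not the layer's actual code: `validate_inner_reshape` is hypothetical, None stands in for "unknown rank", and a new_shape containing -1 is treated as not fully defined. The "at most one -1" check is an assumption drawn from the new_shape description; the Raises list does not say where that rule is enforced.

```python
import math
from typing import Optional, Sequence


def validate_inner_reshape(current_shape: Optional[Sequence[Optional[int]]],
                           new_shape: Sequence[int]) -> None:
    """Hypothetical static checks mirroring the documented constraints."""
    # Unknown rank (modeled here as None) is rejected.
    if current_shape is None:
        raise ValueError("current_shape must have known rank")
    # Assumption: enforce the documented "at most one -1" rule up front.
    if sum(1 for d in new_shape if d == -1) > 1:
        raise ValueError("new_shape may contain at most one -1 dimension")
    fully_defined = (all(d is not None for d in current_shape)
                     and all(d != -1 for d in new_shape))
    # Element counts can only be compared when both shapes are fully defined.
    if fully_defined and math.prod(current_shape) != math.prod(new_shape):
        raise ValueError(f"{list(current_shape)} and {list(new_shape)} "
                         "have different numbers of elements")


validate_inner_reshape([3, 3, 4], [36])       # OK: 36 == 36
validate_inner_reshape([3, None, 4], [-1, 3])  # OK: not fully defined, count not checked
```

The real layer still checks each incoming tensor's inner shape against current_shape at call time; these static checks cannot replace that.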