Returns a Keras layer that reshapes the inner dimensions of tensors.
tf_agents.keras_layers.InnerReshape(
    current_shape: tf_agents.typing.types.Shape,
    new_shape: tf_agents.typing.types.Shape,
    **kwargs
) -> tf.keras.layers.Layer
Each tensor passed to an instance of InnerReshape is reshaped to:
shape(tensor)[:-len(current_shape)] + new_shape
after its inner shape has been validated against current_shape.
Note: current_shape may contain None (unknown) dimension values.
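The reshape rule can be modeled in plain Python to see how the outer dimensions are preserved and how None and -1 are handled. The sketch below is a minimal, hypothetical model (inner_reshape_shape is not part of TF-Agents) that works on shape lists rather than tensors, with None standing in for an unknown dimension. It splits at len(input_shape) - len(current_shape) instead of applying the [:-len(current_shape)] slice literally, because in Python that slice is empty when current_shape is [].

```python
import math
from typing import List, Optional, Sequence

Dim = Optional[int]  # None means "unknown", like an undefined TensorShape dim


def inner_reshape_shape(
    input_shape: Sequence[Dim],
    current_shape: Sequence[Dim],
    new_shape: Sequence[int],
) -> List[Dim]:
    """Hypothetical model of the static output shape of an inner reshape."""
    k = len(current_shape)
    if len(input_shape) < k:
        raise ValueError("input rank is smaller than len(current_shape)")
    # Split explicitly: input_shape[:-0] would be empty when current_shape == [].
    split = len(input_shape) - k
    outer, inner = list(input_shape[:split]), list(input_shape[split:])

    # Validate the inner dims; None on either side acts as a wildcard.
    for have, want in zip(inner, current_shape):
        if have is not None and want is not None and have != want:
            raise ValueError(
                f"inner shape {inner} is incompatible with {list(current_shape)}")
    merged = [want if want is not None else have
              for have, want in zip(inner, current_shape)]

    # No flexible dim: new_shape is used as-is after the outer dims.
    if -1 not in new_shape:
        return outer + list(new_shape)
    # A -1 can only be resolved when every inner dim is known.
    if any(d is None for d in merged):
        return outer + [None if d == -1 else d for d in new_shape]
    total = math.prod(merged)
    known = math.prod(d for d in new_shape if d != -1)
    if known == 0 or total % known:
        raise ValueError(f"cannot reshape {merged} into {list(new_shape)}")
    return outer + [total // known if d == -1 else d for d in new_shape]
```

For example, inner_reshape_shape([None, 32], [None], [4, 4, 2]) returns [None, 4, 4, 2], and the outer [None] passes through untouched.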
This is useful for switching between Dense, ConvXd, and RNN layers in TF-Agents networks without depending on whether the input has [batch_size] or [batch_size, time] outer dimensions: only the inner dimensions described by current_shape are touched.
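To illustrate why the outer dimensions do not matter, here is a hypothetical pure-Python sketch that applies an inner reshape to nested lists instead of tensors. inner_reshape_values and its inner_rank argument (standing in for len(current_shape)) are illustrative names, not TF-Agents API; the sketch assumes rectangular, non-empty nested lists and row-major element order.

```python
from typing import Any, List, Sequence


def _depth(x: Any) -> int:
    """Nesting depth of a non-empty, rectangular nested list."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        x = x[0]
    return depth


def _flatten(x: Any) -> List[Any]:
    """Row-major flattening of a nested list (a scalar becomes [scalar])."""
    if not isinstance(x, list):
        return [x]
    out: List[Any] = []
    for item in x:
        out.extend(_flatten(item))
    return out


def _unflatten(flat: List[Any], shape: Sequence[int]) -> Any:
    """Inverse of _flatten for a fully known shape."""
    if not shape:
        return flat[0]
    step = len(flat) // shape[0]
    return [_unflatten(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


def inner_reshape_values(x: Any, inner_rank: int, new_shape: Sequence[int]) -> Any:
    """Hypothetical: reshape only the innermost `inner_rank` dims of a nested list."""
    outer_rank = _depth(x) - inner_rank
    if outer_rank < 0:
        raise ValueError("input has fewer dims than inner_rank")

    def apply(node: Any, level: int) -> Any:
        if level == outer_rank:
            flat = _flatten(node)
            known = 1
            for d in new_shape:
                if d != -1:
                    known *= d
            # Resolve a single -1 from the number of inner elements.
            shape = [len(flat) // known if d == -1 else d for d in new_shape]
            return _unflatten(flat, shape)
        # Outer dims ([batch] or [batch, time]) are walked but never changed.
        return [apply(child, level + 1) for child in node]

    return apply(x, 0)
```

The same call works unchanged on a [2, 4] input and on a [1, 2, 4] input: inner_reshape_values(x, 1, [2, 2]) turns each length-4 inner vector into a 2x2 block while leaving the leading dimensions alone.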
For example, to switch between Dense, Conv2D, and GRU layers:
import tensorflow as tf
import tf_agents.keras_layers
import tf_agents.networks

net = tf_agents.networks.Sequential([
    tf.keras.layers.Dense(32),
    # Convert inner dim from [32] to [4, 4, 2] for Conv2D.
    tf_agents.keras_layers.InnerReshape([None], new_shape=[4, 4, 2]),
    # 2 filters, 3x3 kernel; default 'valid' padding maps [4, 4, 2] to [2, 2, 2].
    tf.keras.layers.Conv2D(2, 3),
    # Convert inner HWC dims [?, ?, 2] to [8] for Dense/RNN.
    tf_agents.keras_layers.InnerReshape([None, None, 2], new_shape=[-1]),
    tf.keras.layers.GRU(2, return_state=True, return_sequences=True)
])
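As a sanity check on the inline comments, the inner shapes through the first four layers can be traced by hand. This sketch assumes Conv2D's Keras defaults (strides=1, padding='valid'), under which a 3x3 kernel turns a 4x4 input into 2x2; trace_example and conv2d_valid_hw are hypothetical helpers that only do shape arithmetic, and the GRU is not modeled.

```python
import math
from typing import List, Optional, Tuple

Dim = Optional[int]


def conv2d_valid_hw(h: int, w: int, kernel: int) -> Tuple[int, int]:
    """Spatial size after a stride-1 Conv2D with padding='valid' (assumed defaults)."""
    return h - kernel + 1, w - kernel + 1


def trace_example(outer: List[Dim]) -> List[Tuple[str, List[Dim]]]:
    """Hypothetical shape trace of the first four layers of the example network."""
    trace = []
    inner: List[int] = [32]
    trace.append(("Dense(32)", outer + inner))

    new_inner = [4, 4, 2]
    assert math.prod(inner) == math.prod(new_inner)  # 32 == 4 * 4 * 2
    inner = new_inner
    trace.append(("InnerReshape([None], [4, 4, 2])", outer + inner))

    h, w = conv2d_valid_hw(inner[0], inner[1], kernel=3)
    inner = [h, w, 2]  # Conv2D(2, 3) has 2 output filters
    trace.append(("Conv2D(2, 3)", outer + inner))

    inner = [math.prod(inner)]  # the -1 resolves to 2 * 2 * 2
    trace.append(("InnerReshape([None, None, 2], [-1])", outer + inner))
    return trace


# The outer dims are either [batch_size] or [batch_size, time].
for outer in ([None], [None, None]):
    for name, shape in trace_example(outer):
        print(f"{name:38} {shape}")
```

Both traces end with an inner dimension of 8, which is what the GRU receives regardless of whether a time dimension is present.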
Args | |
---|---|
current_shape | The current (partial) shape for the inner dims. This should be a list, tuple, or tf.TensorShape with known rank. The given current_shape must be compatible with the inner shape of the input. Examples: [], [None], [None] * 3, [3, 3, 4], [3, None, 4]. |
new_shape | The new shape for the inner dims. The length of new_shape need not match the length of current_shape, but if both shapes are fully defined then the total number of elements must match. It may have at most one flexible (-1) dimension. Examples: [3], [], [-1], [-1, 3]. |
**kwargs | Additional args to the Keras core layer constructor, e.g. name. |
Returns | |
---|---|
A new Keras Layer that performs the requested reshape on incoming tensors. |
Raises | |
---|---|
ValueError | If current_shape has unknown rank. |
ValueError | If both shapes are fully defined and the number of elements doesn't match. |
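The two error conditions above can be expressed as a small constructor-time check. This is a hypothetical sketch (check_inner_reshape_args is not a TF-Agents function) that uses Python lists in place of TensorShapes: passing None for current_shape stands in for "unknown rank", and any -1 in new_shape means it is not fully defined, so the element-count check is skipped.

```python
import math
from typing import Optional, Sequence


def check_inner_reshape_args(
    current_shape: Optional[Sequence[Optional[int]]],
    new_shape: Sequence[int],
) -> None:
    """Hypothetical checks mirroring the documented ValueError conditions."""
    # None stands in for a current_shape of unknown rank.
    if current_shape is None:
        raise ValueError("current_shape must have known rank")
    current_defined = all(d is not None for d in current_shape)
    # A -1 (or None) dimension means new_shape is not fully defined.
    new_defined = all(d is not None and d >= 0 for d in new_shape)
    if current_defined and new_defined:
        if math.prod(current_shape) != math.prod(new_shape):
            raise ValueError(
                f"element count mismatch: {list(current_shape)} vs {list(new_shape)}")
```

With this sketch, ([None], [4, 4, 2]) and ([3, 3, 4], [-1, 3]) pass, while ([3, 3, 4], [6]) raises because 36 elements cannot become 6.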