Facilitates flattening and unflattening batch dims of a tensor.
tf_agents.networks.utils.BatchSquash(
batch_dims
)
Exposes a matched pair of flatten and unflatten methods. After flattening, only one batch dimension remains. This makes it possible to evaluate networks that expect inputs with exactly one batch dimension.
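The flatten/evaluate/unflatten pattern can be sketched in NumPy. This is a minimal, hypothetical stand-in, not the tf_agents implementation: the class name `BatchSquashSketch` and the toy `dense` function are illustrative assumptions. It also assumes that unflatten restores the cached batch dimensions while keeping the tensor's current trailing dimensions.

```python
import numpy as np


class BatchSquashSketch:
    """Hypothetical NumPy stand-in for BatchSquash (not the tf_agents code)."""

    def __init__(self, batch_dims):
        if batch_dims < 1:
            raise ValueError("batch_dims must be >= 1")
        self._batch_dims = batch_dims
        self._batch_shape = None  # filled in by flatten()

    def flatten(self, x):
        # Cache the leading batch shape, e.g. (B, T) for batch_dims=2.
        self._batch_shape = x.shape[: self._batch_dims]
        # Merge all batch dims into one leading dim; keep the rest as-is.
        return x.reshape((-1,) + x.shape[self._batch_dims :])

    def unflatten(self, x):
        if self._batch_shape is None:
            raise ValueError("call flatten() before unflatten()")
        # Assumption: restore cached batch dims, keep x's current trailing dims.
        return x.reshape(self._batch_shape + x.shape[1:])


def dense(x, w):
    # Toy "network" that only accepts inputs with exactly one batch dim.
    assert x.ndim == 2, "expects [batch, features]"
    return x @ w


obs = np.ones((4, 5, 3))  # [batch, time, features]
w = np.ones((3, 8))
squash = BatchSquashSketch(batch_dims=2)
out = squash.unflatten(dense(squash.flatten(obs), w))
print(out.shape)  # (4, 5, 8)
```

The network only ever sees a `[20, 3]` input, yet the caller gets back a `[4, 5, 8]` result with its original batch and time dimensions.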
Methods
flatten
flatten(
tensor
)
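Merges the leading batch_dims dimensions of tensor into a single batch dimension and caches the original shape so that unflatten can restore it.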
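The shape arithmetic behind flatten can be illustrated with plain tuples. `flattened_shape` is a hypothetical helper written for this illustration, not part of tf_agents: the first `batch_dims` sizes are multiplied into one, and the remaining dimensions pass through unchanged.

```python
from math import prod


def flattened_shape(shape, batch_dims):
    """Hypothetical helper: shape after merging the first batch_dims dims."""
    batch, rest = shape[:batch_dims], shape[batch_dims:]
    # All batch dims collapse into one leading dim of size prod(batch).
    return (prod(batch),) + tuple(rest)


print(flattened_shape((4, 5, 3), batch_dims=2))          # (20, 3)
print(flattened_shape((2, 3, 7, 84, 84), batch_dims=3))  # (42, 84, 84)
```

With `batch_dims=1` the shape is returned unchanged, since there is only one batch dimension to begin with.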
unflatten
unflatten(
tensor
)
Restores the batch dimensions of tensor using the shape cached by the most recent call to flatten.
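A sketch of the shape unflatten would produce, assuming the leading batch dimensions come from the cached shape and the trailing dimensions come from the tensor passed in (so a network may change the feature size in between). `unflattened_shape` is a hypothetical helper, and its explicit size check is an illustrative addition: it shows that the network must preserve the number of rows in the flattened batch dimension for the restore to succeed.

```python
from math import prod


def unflattened_shape(cached_shape, batch_dims, current_shape):
    """Hypothetical helper: the shape unflatten() would restore."""
    batch = tuple(cached_shape[:batch_dims])
    # The flattened leading dim must still hold exactly prod(batch) rows.
    if current_shape[0] != prod(batch):
        raise ValueError(f"leading dim {current_shape[0]} != {prod(batch)}")
    # Batch dims come from the cache; trailing dims from the current tensor.
    return batch + tuple(current_shape[1:])


print(unflattened_shape((4, 5, 3), 2, (20, 8)))  # (4, 5, 8)
```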