View source on GitHub
Facilitates flattening and unflattening batch dims of a tensor.
tf_agents.networks.utils.BatchSquash(
batch_dims
)
Exposes a matched pair of flatten and unflatten methods. After flattening, only one batch dimension is left. This makes it possible to evaluate networks that expect inputs with a single batch dimension.
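The flatten/unflatten pairing can be illustrated with a minimal NumPy sketch. This is not the TF-Agents implementation, which operates on TensorFlow tensors. The class `NumpyBatchSquash` and the names `observations` and `weights` are hypothetical and exist only to show how the shapes change.

```python
import numpy as np


class NumpyBatchSquash:
    """Hypothetical NumPy stand-in that mirrors the flatten/unflatten pairing."""

    def __init__(self, batch_dims):
        if batch_dims < 0:
            raise ValueError("batch_dims must be non-negative, got %d" % batch_dims)
        self._batch_dims = batch_dims
        self._batch_shape = None  # cached by flatten(), read by unflatten()

    def flatten(self, array):
        # Remember the leading batch dimensions so unflatten() can restore them.
        self._batch_shape = array.shape[:self._batch_dims]
        # Collapse all batch dimensions into a single leading dimension.
        return array.reshape((-1,) + array.shape[self._batch_dims:])

    def unflatten(self, array):
        if self._batch_shape is None:
            raise ValueError("call flatten() before unflatten()")
        # Restore the cached batch dimensions; trailing dimensions may differ
        # from the input because a network may have transformed them.
        return array.reshape(self._batch_shape + array.shape[1:])


# A [batch=4, time=5, features=3] input with two batch dimensions.
observations = np.zeros((4, 5, 3))
squash = NumpyBatchSquash(batch_dims=2)

flat = squash.flatten(observations)         # shape (20, 3)
weights = np.ones((3, 7))                   # stand-in for a layer expecting [batch, features]
outputs = squash.unflatten(flat @ weights)  # shape (4, 5, 7)
```

The sketch uses the same object for both calls because unflatten depends on the shape cached by the preceding flatten.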
| Args | |
|---|---|
| batch_dims | Number of batch dimensions the flatten/unflatten ops should handle. |
| Raises | |
|---|---|
| ValueError | If batch_dims is negative. |
Methods
flatten
flatten(
tensor
)
Flattens the tensor's batch_dims into a single batch dimension and caches the original shape.
unflatten
unflatten(
tensor
)
Unflattens the tensor's batch_dims using the cached shape.