Reshapes tensors in `sample` to have shape `[rows, num_steps, ...]`.
tf_agents.replay_buffers.reverb_replay_buffer.truncate_reshape_rows_by_num_steps(
sample, num_steps
)
This function takes a structure `sample` and, for each tensor `t`, truncates the tensor's outer dimension to the highest possible multiple of `num_steps`.
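This is done by first calculating `rows = tf.shape(t)[0] // num_steps` (the size of the outer dimension divided by `num_steps`, rounded down), then truncating the tensor to `t_trunc = t[:(rows * num_steps), ...]`. For each tensor, it returns `tf.reshape(t_trunc, [rows, num_steps, ...])`, where `...` stands for the tensor's remaining inner dimensions, which are left unchanged.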
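The truncate-then-reshape steps can be mirrored in NumPy to check the resulting shapes. This is an illustrative analogue, not the TF-Agents implementation: the function name `truncate_reshape_rows_by_num_steps_np` is hypothetical, and the sketch assumes the nest is built only from dicts, plain lists and tuples whose leaves are arrays sharing the same outer dimension.

```python
import numpy as np


def truncate_reshape_rows_by_num_steps_np(sample, num_steps):
    """Hypothetical NumPy analogue of the documented behaviour.

    Assumes `sample` is a nest of dicts, plain lists and tuples (namedtuples
    are not handled) whose leaves are arrays with a shared outer dimension.
    """
    if isinstance(sample, dict):
        return {k: truncate_reshape_rows_by_num_steps_np(v, num_steps)
                for k, v in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(truncate_reshape_rows_by_num_steps_np(v, num_steps)
                            for v in sample)
    t = np.asarray(sample)
    rows = t.shape[0] // num_steps          # number of complete windows
    t_trunc = t[: rows * num_steps, ...]    # drop leftover trailing steps
    # Split only the outer dimension; inner dimensions are kept as-is.
    return t_trunc.reshape((rows, num_steps) + t.shape[1:])


sample = {
    "observation": np.arange(10 * 2).reshape(10, 2),  # 10 steps, 2 features
    "reward": np.arange(10.0),                        # 10 steps, scalar
}
out = truncate_reshape_rows_by_num_steps_np(sample, num_steps=3)
print(out["observation"].shape)  # (3, 3, 2)
print(out["reward"].shape)       # (3, 3)
```

With 10 steps and `num_steps=3`, `rows = 10 // 3 = 3`, so the tenth step is dropped and every leaf gains a leading `[3, 3]` shape while its inner dimensions are preserved.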
| Args | |
|---|---|
| `sample` | Nest of tensors. |
| `num_steps` | Python integer. |
| Returns | |
|---|---|
| A nest with tensors reshaped to `[rows, num_steps, ...]`. | |