tf_agents.keras_layers.DynamicUnroll

Process a history of sequences that are concatenated without padding.

Given batched, batch-major inputs, DynamicUnroll unrolls an RNN using cell; at each time step it feeds a frame of inputs as input to cell.call().

If at least one tensor in inputs has rank 3 or above (shaped [batch_size, n, ...] where n is the number of time steps), the RNN will run for exactly n steps.

If n == 1 is known statically, then only a single step is executed. This is done via a static unroll without using a tf.while_loop.

If all of the tensors in inputs have rank at most 2 (i.e., shaped [batch_size] or [batch_size, d]), then a single step is assumed (i.e. n = 1) and the outputs will not have a time dimension.
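As an illustration of the rank handling described above, here is a minimal sketch. The reset_mask argument and the (outputs, final_state) return pair are assumptions drawn from typical tf_agents usage of this layer and are not specified in this section; consult the source for the exact call signature.

import tensorflow as tf
from tf_agents.keras_layers import DynamicUnroll

cell = tf.keras.layers.LSTMCell(8)
layer = DynamicUnroll(cell)

batch_size, n, d = 4, 5, 3

# Rank-3 input [batch_size, n, d]: the RNN runs for exactly n steps.
frames = tf.random.normal([batch_size, n, d])
# Assumed: reset_mask marks frames where the state should be reset to the
# cell's zero state; all-False means one uninterrupted sequence per batch row.
reset_mask = tf.zeros([batch_size, n], dtype=tf.bool)
outputs, final_state = layer(frames, reset_mask=reset_mask)
print(outputs.shape)   # (4, 5, 8) -- time dimension preserved

# Rank-2 input [batch_size, d]: a single step (n = 1) is assumed and the
# output has no time dimension.
frame = tf.random.normal([batch_size, d])
single_mask = tf.zeros([batch_size, 1], dtype=tf.bool)  # assumed mask shape
output, final_state = layer(frame, reset_mask=single_mask)
print(output.shape)    # (4, 8)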

Args

cell: A tf.nn.rnn_cell.RNNCell or Keras RNNCell (e.g. LSTMCell) whose call() method has the signature call(input, state, ...). Each tensor in the input tuple is shaped [batch_size, ...].
parallel_iterations: Parallel iterations to pass to tf.while_loop. The default value is a good trade-off between memory use and performance. See the documentation of tf.while_loop for more details.
swap_memory: Python bool. Whether to swap memory from GPU to CPU when storing activations for backprop. This may sometimes have a negligible performance impact, but can improve memory usage. See the documentation of tf.while_loop for more details.
**kwargs: Additional layer arguments, such as dtype and name.
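A construction sketch using only the arguments documented above; the GRUCell and the specific argument values are illustrative choices, not defaults.

import tensorflow as tf
from tf_agents.keras_layers import DynamicUnroll

# Any Keras RNN cell with a call(input, state, ...) method can be used.
cell = tf.keras.layers.GRUCell(16)
layer = DynamicUnroll(
    cell,
    parallel_iterations=10,  # forwarded to tf.while_loop
    swap_memory=True,        # swap activations to CPU to reduce GPU memory
    name='dynamic_unroll',   # standard Keras layer kwargs
    dtype=tf.float32)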

Raises

TypeError: If cell lacks a get_initial_state, output_size, or state_size property.

Attributes

state_size

Methods

get_initial_state
