tf_agents.trajectories.Transition

A tuple that represents a transition.

A Transition represents an S, A, S' sequence: a state, the action taken in that state, and the resulting next state. Tensors within a Transition are typically shaped [B, ...], where B is the batch size.

In some cases Transition objects are used to store time-shifted intermediate values for RNN computations, in which case the stored tensors are shaped [B, T, ...].

In other cases, Transition objects store n-step transitions S_t, A_t, S_{t+N} where the associated reward and discount in next_time_step are calculated as:

next_time_step.reward = r_t +
                        g^{1} * d_t * r_{t+1} +
                        g^{2} * d_t * d_{t+1} * r_{t+2} +
                        g^{3} * d_t * d_{t+1} * d_{t+2} * r_{t+3} +
                        ... +
                        g^{N-1} * d_t * ... * d_{t+N-2} * r_{t+N-1}

next_time_step.discount = g^{N-1} * d_t * d_{t+1} * ... * d_{t+N-1}.
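
The two formulas above can be checked by hand with a small sketch. The helper below is hypothetical (it is not part of tf_agents) and uses plain Python floats for a single, unbatched sequence. It assumes rewards = [r_t, ..., r_{t+N-1}], discounts = [d_t, ..., d_{t+N-1}], and gamma = g, and it follows the expressions exactly as written in this section.

```python
def n_step_reward_and_discount(rewards, discounts, gamma):
    """Hypothetical helper mirroring the n-step formulas in this section.

    rewards   = [r_t, r_{t+1}, ..., r_{t+N-1}]
    discounts = [d_t, d_{t+1}, ..., d_{t+N-1}]
    gamma     = g, the discount factor
    """
    if not rewards or len(rewards) != len(discounts):
        raise ValueError("rewards and discounts must be non-empty and equal in length")

    n = len(rewards)
    total_reward = 0.0
    # weight holds g^k * d_t * ... * d_{t+k-1}; it is 1 for the first term.
    weight = 1.0
    for k in range(n):
        total_reward += weight * rewards[k]
        if k < n - 1:
            weight *= gamma * discounts[k]

    # After the loop, weight = g^{N-1} * d_t * ... * d_{t+N-2}.
    # The discount formula multiplies in the final d_{t+N-1} as well.
    total_discount = weight * discounts[n - 1]
    return total_reward, total_discount


reward, discount = n_step_reward_and_discount(
    rewards=[1.0, 2.0, 3.0], discounts=[1.0, 1.0, 1.0], gamma=0.5)
print(reward, discount)
```

With N = 3 and g = 0.5, the reward is 1 + 0.5 * 2 + 0.25 * 3 and the discount is 0.25. A zero in discounts (an episode boundary) cuts off every later reward term and zeroes the resulting discount.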

See to_n_step_transition for an example that converts Trajectory objects to this format.

Attributes

time_step The initial state, reward, and discount.
action_step The action taken, the policy info, and possibly the policy state. (Note: action_step.state should typically not be stored in, e.g., a replay buffer. The exception is a copy kept inside policy_step.info, as a special case for algorithms that choose to do this.)
next_time_step The final state, reward, and discount.

Methods

replace

Exposes namedtuple._replace as a public method.

Usage:

new_transition = transition.replace(action_step=())

This returns a new transition with an empty action_step.

Args
**kwargs Key/value pairs of the fields in the transition to replace.

Returns
A new Transition.
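
The behavior of replace can be illustrated without installing tf_agents. The sketch below uses a hypothetical stand-in class, SketchTransition, built on collections.namedtuple with the same three field names. It is not the library's class, and the string field values are placeholders. It only shows the namedtuple._replace semantics described above: a new tuple is returned and the original is left unchanged.

```python
import collections


class SketchTransition(
    collections.namedtuple(
        "SketchTransition", ["time_step", "action_step", "next_time_step"])):
    """Hypothetical stand-in for illustration; not tf_agents' Transition."""

    __slots__ = ()

    def replace(self, **kwargs):
        # Forward to namedtuple._replace, which returns a new tuple.
        return self._replace(**kwargs)


# Placeholder values; real transitions hold TimeStep and PolicyStep structures.
original = SketchTransition(
    time_step="s_t", action_step="a_t", next_time_step="s_t_plus_1")
updated = original.replace(action_step=())

print(updated)
print(original.action_step)
```

Passing a name that is not a field raises an error in the underlying namedtuple._replace, so misspelled field names fail instead of being silently ignored.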