tf_agents.environments.TrajectoryReplay

A helper that replays a policy against given Trajectory observations.

Args
policy A tf_policy.TFPolicy instance.
time_major If True, the tensors in the trajectory passed to run are assumed to have shape [time, batch, ...]. Otherwise (default) they are assumed to have shape [batch, time, ...].

Raises
ValueError If policy is not an instance of tf_policy.TFPolicy.

Methods

run

Apply the policy to trajectory steps and store actions/info.

If self.time_major == True, the tensors in trajectory are assumed to have shape [time, batch, ...]. Otherwise they are assumed to have shape [batch, time, ...].

Args
trajectory The Trajectory to run against. If the replay class was created with time_major=True, then the tensors in trajectory must be shaped [time, batch, ...]. Otherwise they must be shaped [batch, time, ...].
policy_state (Optional) A nest of Tensors holding the policy state for the initial step.

Returns
output_actions A nest of the actions that the policy took. If the replay class was created with time_major=True, then the tensors here will be shaped [time, batch, ...]. Otherwise they'll be shaped [batch, time, ...].
output_policy_info A nest of the policy info that the policy emitted. If the replay class was created with time_major=True, then the tensors here will be shaped [time, batch, ...]. Otherwise they'll be shaped [batch, time, ...].
policy_state A nest of Tensors holding the policy state after the final step.

Raises
TypeError If policy_state structure doesn't match self.policy.policy_state_spec, or trajectory structure doesn't match self.policy.trajectory_spec.
ValueError If policy_state doesn't match self.policy.policy_state_spec, or trajectory structure doesn't match self.policy.trajectory_spec.
ValueError If trajectory lacks two outer dims.
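
The following is a minimal pure-Python sketch of the idea behind run, not the library's implementation. It assumes nested lists in place of tensors and a hypothetical toy_policy function in place of a tf_policy.TFPolicy; the names replay and toy_policy are illustrative only. It ignores step types, rewards and discounts, and only shows how a policy is stepped along the time axis, how the policy state is threaded through, and how the outputs keep the same [batch, time] or [time, batch] layout as the input.

```python
from typing import List, Tuple


def toy_policy(observation: float, state: int) -> Tuple[float, str, int]:
    """Hypothetical policy: doubles the observation and counts steps seen.

    Returns (action, info, next_state). This is NOT the TFPolicy interface.
    """
    next_state = state + 1
    return 2.0 * observation, f"step {next_state}", next_state


def transpose(rows: List[list]) -> List[list]:
    """Swaps the two outer dimensions of a nested list."""
    return [list(col) for col in zip(*rows)]


def replay(policy_fn, observations, initial_state=0, time_major=False):
    """Applies policy_fn to every step of every sequence.

    observations is indexed [time][batch] if time_major is True,
    otherwise [batch][time]. Returns (actions, infos, final_states);
    actions and infos use the same layout as the input, and
    final_states holds one entry per batch item.
    """
    # Work in time-major order internally, as a loop over time steps.
    steps = observations if time_major else transpose(observations)
    states = [initial_state] * len(steps[0])
    actions, infos = [], []
    for obs_at_t in steps:
        results = [policy_fn(o, s) for o, s in zip(obs_at_t, states)]
        actions.append([r[0] for r in results])
        infos.append([r[1] for r in results])
        states = [r[2] for r in results]  # carry state to the next step
    if not time_major:
        # Restore the caller's [batch, time] layout.
        actions, infos = transpose(actions), transpose(infos)
    return actions, infos, states


# Two sequences of three steps each, in [batch, time] layout.
batch_major_obs = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
actions, infos, final_states = replay(toy_policy, batch_major_obs)
print(actions)       # same [batch, time] layout as the input
print(final_states)  # one final state per batch item
```

With the real class, the analogous call would be run(trajectory, policy_state) on a TrajectoryReplay instance, with a Trajectory whose tensors have the two outer dimensions in the order selected by time_major.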