tf_agents.drivers.dynamic_step_driver.DynamicStepDriver

A driver that takes N steps in an environment using a tf.while_loop.

Inherits From: Driver


The while loop takes num_steps steps in the environment, counting only steps that result in an environment transition, i.e. (time_step, action, next_time_step). A step that resets the environment, i.e. time_step.is_last() and next_time_step.is_first() (traj.is_boundary()), is not counted toward num_steps.
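
To make the counting rule concrete, here is a minimal pure-Python model. It is not the TF-Agents implementation. Each step is reduced to a hypothetical pair of flags (whether time_step is last, whether next_time_step is first), and a step counts toward num_steps only if it is not a reset boundary.

```python
from typing import List, Tuple

# Hypothetical per-step flags: (time_step.is_last(), next_time_step.is_first()).
StepFlags = Tuple[bool, bool]


def count_transitions(steps: List[StepFlags]) -> int:
    """Count steps toward num_steps, skipping environment-reset boundaries."""
    count = 0
    for time_step_is_last, next_is_first in steps:
        # A reset step (LAST -> FIRST) is a boundary and is not counted.
        is_boundary = time_step_is_last and next_is_first
        if not is_boundary:
            count += 1
    return count


# Three transitions, one reset boundary, then two more transitions.
steps = [(False, False)] * 3 + [(True, True)] + [(False, False)] * 2
print(count_transitions(steps))  # 5 of the 6 steps are counted
```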

Because environments run batched time_steps, the counters for all batch elements are summed, and execution stops once the total reaches or exceeds num_steps. When batch_size > 1, there is no guarantee that exactly num_steps steps are taken: the total may be more, but never less.
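
The overshoot follows from the loop checking the summed counter only between iterations, while each iteration advances every batch element. The sketch below is a pure-Python model, with a hypothetical function name, under the simplifying assumption that no step is a reset boundary, so the total grows by exactly batch_size per iteration.

```python
def total_steps_taken(batch_size: int, num_steps: int) -> int:
    """Model the stop rule: iterate while the summed counter is below num_steps.

    Assumes every batch element produces a counted (non-boundary) transition
    on every iteration, so the total grows by batch_size each time.
    """
    counter = [0] * batch_size  # one step counter per batch element
    while sum(counter) < num_steps:
        # A single loop iteration steps all batch elements at once.
        counter = [c + 1 for c in counter]
    return sum(counter)


print(total_steps_taken(batch_size=1, num_steps=10))  # 10: exact
print(total_steps_taken(batch_size=4, num_steps=10))  # 12: overshoots by 2
```

Under this assumption the overshoot is at most batch_size - 1 steps.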

This termination condition can be overridden in subclasses by implementing the self._loop_condition_fn() method.
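
The override hook follows a template-method pattern: the base class owns the loop and asks _loop_condition_fn() for the predicate to evaluate before each iteration. The classes below are hypothetical pure-Python stand-ins, not the TF-Agents classes, and assume that _loop_condition_fn() returns such a predicate over the per-element counters. The subclass swaps the summed-total rule for one that waits until every batch element has taken num_steps steps.

```python
from typing import Callable, List


class ToyStepDriver:
    """Hypothetical stand-in for a step driver; not the TF-Agents class."""

    def __init__(self, batch_size: int, num_steps: int):
        self.batch_size = batch_size
        self.num_steps = num_steps

    def _loop_condition_fn(self) -> Callable[[List[int]], bool]:
        # Default rule: keep looping while the summed counter is below num_steps.
        def loop_cond(counter: List[int]) -> bool:
            return sum(counter) < self.num_steps
        return loop_cond

    def run(self) -> List[int]:
        counter = [0] * self.batch_size
        cond = self._loop_condition_fn()
        while cond(counter):
            counter = [c + 1 for c in counter]  # every element takes one step
        return counter


class PerElementStepDriver(ToyStepDriver):
    """Overrides only the stop rule: loop until each element has num_steps."""

    def _loop_condition_fn(self) -> Callable[[List[int]], bool]:
        def loop_cond(counter: List[int]) -> bool:
            return min(counter) < self.num_steps
        return loop_cond


print(ToyStepDriver(4, 10).run())         # [3, 3, 3, 3]
print(PerElementStepDriver(4, 10).run())  # [10, 10, 10, 10]
```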

Args
env: A tf_environment.Base environment.
policy: A tf_policy.TFPolicy policy.
observers: A list of observers that are updated after every step in the environment. Each observer is a callable(trajectory.Trajectory).
transition_observers: A list of observers that are updated after every step in the environment. Each observer is a callable((TimeStep, PolicyStep, NextTimeStep)).
num_steps: The number of steps to take in the environment. For batched or parallel environments, this is the total number of steps taken, summed across all environments.
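
The two observer lists differ only in what each callable receives. This pure-Python sketch uses hypothetical, stripped-down namedtuples in place of the TF-Agents TimeStep, PolicyStep and Trajectory types (the real types carry more fields) to show the calling convention: an observer gets one trajectory, while a transition observer gets a single 3-tuple argument.

```python
from collections import namedtuple
from typing import Callable, List

# Hypothetical stripped-down records standing in for the TF-Agents types.
TimeStep = namedtuple("TimeStep", ["step_type", "observation"])
PolicyStep = namedtuple("PolicyStep", ["action"])
Trajectory = namedtuple(
    "Trajectory", ["step_type", "observation", "action", "next_step_type"])


def notify_observers(time_step: TimeStep, policy_step: PolicyStep,
                     next_time_step: TimeStep,
                     observers: List[Callable],
                     transition_observers: List[Callable]) -> None:
    """Call both observer lists once for a single environment step."""
    traj = Trajectory(time_step.step_type, time_step.observation,
                      policy_step.action, next_time_step.step_type)
    for observer in observers:
        observer(traj)  # one Trajectory argument
    for observer in transition_observers:
        observer((time_step, policy_step, next_time_step))  # one tuple argument


seen_trajectories = []
seen_transitions = []
# Toy step types: 0 for a first step, 1 for a mid-episode step.
notify_observers(TimeStep(0, 1.0), PolicyStep(2), TimeStep(1, 1.5),
                 observers=[seen_trajectories.append],
                 transition_observers=[seen_transitions.append])
print(seen_trajectories[0].action)  # 2
print(len(seen_transitions[0]))     # 3
```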

Raises
ValueError: If env is not a tf_environment.Base or policy is not an instance of tf_policy.TFPolicy.

Attributes
env
info_observers
observers
policy
transition_observers

Methods

run

Takes steps in the environment using the policy while updating observers.

Args
time_step: Optional initial time_step. If None, the environment's current_time_step() is used. Elements should have shape [batch_size, ...].
policy_state: Optional initial state for the policy.
maximum_iterations: Optional maximum number of iterations of the while loop to run. If provided, the loop's cond output is AND-ed with an additional condition ensuring that the number of iterations executed is no greater than maximum_iterations.
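
The effect of maximum_iterations can be modeled in plain Python: the loop continues only while both the driver's own condition and the iteration cap allow it, so whichever limit is hit first ends the loop. This is a hypothetical sketch, not the TF-Agents code, and it assumes every step is a counted transition.

```python
from typing import Optional, Tuple


def run_with_cap(num_steps: int, batch_size: int,
                 maximum_iterations: Optional[int] = None) -> Tuple[int, int]:
    """Return (iterations executed, total steps counted)."""
    counter = [0] * batch_size
    iterations = 0

    def step_cond() -> bool:
        # The driver's own termination rule: summed counter below num_steps.
        return sum(counter) < num_steps

    def cap_cond() -> bool:
        # Extra condition AND-ed in when maximum_iterations is provided.
        return maximum_iterations is None or iterations < maximum_iterations

    while step_cond() and cap_cond():
        counter = [c + 1 for c in counter]  # all batch elements step together
        iterations += 1
    return iterations, sum(counter)


print(run_with_cap(10, batch_size=1))                        # (10, 10)
print(run_with_cap(10, batch_size=1, maximum_iterations=3))  # (3, 3): cap wins
print(run_with_cap(10, batch_size=4, maximum_iterations=5))  # (3, 12): num_steps wins
```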

Returns
time_step: TimeStep named tuple with the final observation, reward, etc.
policy_state: Tensor with the final step policy state.