tf_agents.policies.TFPolicy

Abstract base class for TF Policies.

The Policy represents a mapping from time_steps received from the environment to actions that can be applied to the environment.

Agents expose two policies: a policy meant for deployment and evaluation, and a collect_policy for collecting data from the environment. The collect_policy is usually stochastic, so that it explores the environment better, and it may also log auxiliary information required for training, such as log probabilities. Users can also create Policy objects directly, without using an Agent.

The main methods of TFPolicy are:

  • action: Maps a time_step from the environment to an action.
  • distribution: Maps a time_step to a distribution over actions.
  • get_initial_state: Generates the initial state for stateful policies, e.g. RNN/LSTM policies.

Example usage:

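```python
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.policies import random_tf_policy

# Any TFEnvironment works here; CartPole wrapped in a TFPyEnvironment is one example.
env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(), env.action_spec())
# Or policy = agent.policy or agent.collect_policy

policy_state = policy.get_initial_state(env.batch_size)
time_step = env.reset()

# With batch_size 1 in eager mode, is_last() holds a single boolean.
while not time_step.is_last():
  policy_step = policy.action(time_step, policy_state)
  time_step = env.step(policy_step.action)

  policy_state = policy_step.state
  # policy_step.info may contain side info for logging, such as action log
  # probabilities.
```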

Policies can be saved to disk as SavedModels (see policy_saver.py and policy_loader.py) or as TF Checkpoints.

A PyTFEagerPolicy can be used to wrap a TFPolicy so that it works with PyEnvironments.

For researchers, and those developing new Policies, the TFPolicy base class constructor also accepts a validate_args parameter. If False, this disables all spec structure, dtype, and shape checks in the public methods of these classes. It allows algorithm developers to iterate and try different input and output structures without worrying about overly restrictive requirements, or input and output states being in a certain format. However, disabling argument validation can make it very hard to identify structural input or algorithmic errors, and should not be done for final, or production-ready, Policies. Besides allowing implementations that disagree with their specs, it means the resulting Policy may no longer interact well with other parts of TF-Agents. Examples include impedance mismatches with the Actor/Learner APIs, replay buffers, and the model export functionality in PolicySaver.

Args
time_step_spec A TimeStep spec of the expected time_steps. Usually provided by the user to the subclass.
action_spec A nest of BoundedTensorSpec representing the actions. Usually provided by the user to the subclass.
policy_state_spec A nest of TensorSpec representing the policy_state. Provided by the subclass, not directly by the user.
info_spec A nest of TensorSpec representing the policy info. Provided by the subclass, not directly by the user.
clip Whether to clip actions to spec before returning them. Default True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped continuous actions for training.
emit_log_probability Emit log-probabilities of actions, if supported. If True, policy_step.info will have CommonFields.LOG_PROBABILITY set. Please consult utility methods provided in policy_step for setting and retrieving these. When working with custom policies, either provide a dictionary info_spec or a namedtuple with the field 'log_probability'.
automatic_state_reset If True, then get_initial_state is used to clear state in action() and distribution() for time steps where time_step.is_first().
observation_and_action_constraint_splitter A function used to process observations with action constraints. These constraints can indicate, for example, a mask of valid/invalid actions for a given state of the environment. The function takes in a full observation and returns a tuple consisting of 1) the part of the observation intended as input to the network and 2) the constraint. An example observation_and_action_constraint_splitter could be as simple as:

    def observation_and_action_constraint_splitter(observation):
      return observation['network_input'], observation['constraint']

Note: when using observation_and_action_constraint_splitter, make sure the provided network (for example, a q_network) is compatible with the network-specific half of the splitter's output. In particular, observation_and_action_constraint_splitter is called on the observation before it is passed to the network. If observation_and_action_constraint_splitter is None, action constraints are not applied.
validate_args Python bool. Whether to verify inputs to, and outputs of, functions like action and distribution against spec structures, dtypes, and shapes. Research code may prefer to set this value to False to allow iterating on input and output structures without being hamstrung by overly rigid checking (at the cost of harder-to-debug errors). See also TFAgent.validate_args.
name A name for this module. Defaults to the class name.
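To make the action-constraint idea concrete, here is a plain-TensorFlow sketch of how a mask returned by the splitter can restrict a greedy choice over Q-values. The observation layout, the 0/1 mask convention (1 meaning valid), and the hard-coded Q-values standing in for a network output are all assumptions for illustration; individual policies apply the mask in their own way.

```python
import tensorflow as tf


def observation_and_action_constraint_splitter(observation):
  # Hypothetical observation layout: a dict holding the network input and a 0/1 mask.
  return observation['network_input'], observation['constraint']


observation = {
    'network_input': tf.constant([[0.2, 0.7]]),
    'constraint': tf.constant([[1, 0, 1]]),  # action 1 is invalid in this state
}
network_input, mask = observation_and_action_constraint_splitter(observation)

# Stand-in for network(network_input): Q-values for 3 actions, batch of one.
q_values = tf.constant([[1.0, 5.0, 3.0]])

# Replace the Q-values of invalid actions with the lowest float32 value.
lowest = tf.fill(tf.shape(q_values), q_values.dtype.min)
masked_q = tf.where(tf.cast(mask, tf.bool), q_values, lowest)
greedy_action = tf.argmax(masked_q, axis=-1)
# greedy_action -> [2]: action 1 has the highest raw Q-value but is masked out.
```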
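The per-entry reset described for automatic_state_reset can be pictured with tf.where: batch entries whose time step is FIRST get the initial state, while the others keep their running state. This is an illustrative sketch with a hypothetical zero initial state and a two-entry batch, not the base class's actual code.

```python
import tensorflow as tf

# Hypothetical running state for a batch of two, and a zero initial state of the same shape.
policy_state = tf.constant([[1.0, 1.0], [2.0, 2.0]])
initial_state = tf.zeros_like(policy_state)

# True where a new episode has just started (time_step.is_first()).
is_first = tf.constant([True, False])

# Broadcast the per-entry flag over the state dimension and pick per entry.
reset_state = tf.where(is_first[:, tf.newaxis], initial_state, policy_state)
# reset_state -> [[0., 0.], [2., 2.]]
```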

Attributes
action_spec Describes the TensorSpecs of the Tensors expected by step(action).

action can be a single Tensor, or a nested dict, list or tuple of Tensors.

collect_data_spec Describes the Tensors written when using this policy with an environment.
emit_log_probability Whether this policy instance emits log probabilities or not.
info_spec Describes the Tensors emitted as info by action and distribution.

info can be an empty tuple, a single Tensor, or a nested dict, list or tuple of Tensors.

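observation_and_action_constraint_splitter The function passed to the constructor for splitting observations into network input and action constraints, or None if action constraints are not applied.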

policy_state_spec Describes the Tensors expected by step(_, policy_state).

policy_state can be an empty tuple, a single Tensor, or a nested dict, list or tuple of Tensors.

policy_step_spec Describes the output of action().
time_step_spec Describes the TimeStep tensors returned by step().
trajectory_spec Describes the Tensors written when using this policy with an environment.
validate_args Whether action and distribution validate input and output args.

Methods

action

Generates next action given the time_step and policy_state.

Args
time_step A TimeStep tuple corresponding to time_step_spec().
policy_state A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.
seed Seed to use if action performs sampling (optional).

Returns
A PolicyStep named tuple containing:
  • action: An action Tensor matching the action_spec.
  • state: A policy state tensor to be fed into the next call to action.
  • info: Optional side information such as action log probabilities.

Raises
RuntimeError If the subclass __init__ didn't call super().__init__.
ValueError or TypeError If validate_args is True and inputs or outputs do not match time_step_spec, policy_state_spec, or policy_step_spec.

distribution

Generates the distribution over next actions given the time_step.

Args
time_step A TimeStep tuple corresponding to time_step_spec().
policy_state A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.

Returns
A PolicyStep named tuple containing:
  • action: A tfp.distributions.Distribution capturing the distribution of next actions.
  • state: A policy state tensor for the next call to distribution.
  • info: Optional side information such as action log probabilities.

Raises
ValueError or TypeError If validate_args is True and inputs or outputs do not match time_step_spec, policy_state_spec, or policy_step_spec.

get_initial_state

Returns an initial state usable by the policy.

Args
batch_size Tensor or constant: size of the batch dimension. Can be None, in which case no batch dimension is added.

Returns
A nested object of type policy_state containing properly initialized Tensors.
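For a stateful policy, "properly initialized" typically means zero tensors shaped like policy_state_spec, with an optional leading batch dimension. The helper below is a hypothetical plain-TensorFlow sketch of that idea with an assumed LSTM-style (h, c) state spec; a given policy may override get_initial_state and produce something else.

```python
import tensorflow as tf

# Hypothetical state spec for a stateful policy: an LSTM-style (h, c) pair of size 4.
policy_state_spec = (tf.TensorSpec([4], tf.float32), tf.TensorSpec([4], tf.float32))


def zero_initial_state(spec_nest, batch_size):
  # Zero tensors shaped [batch_size] + spec.shape, or just spec.shape if batch_size is None.
  outer = [] if batch_size is None else [batch_size]
  return tf.nest.map_structure(
      lambda s: tf.zeros(outer + s.shape.as_list(), dtype=s.dtype), spec_nest)


state = zero_initial_state(policy_state_spec, batch_size=3)  # two [3, 4] zero tensors
```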

update

Update the current policy with another policy.

This includes copying the variables from the other policy.

Args
policy Another policy to update from.
tau A float scalar in [0, 1]. When tau is 1.0 (the default), a hard update (a direct copy) is performed. This is used for trainable variables.
tau_non_trainable A float scalar in [0, 1] for non-trainable variables. If None, the value of tau is used.
sort_variables_by_name A bool; when True, the variables are sorted by name before the update.

Returns
A TF op that performs the update.
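The tau parameters describe a soft (Polyak) update, target = (1 - tau) * target + tau * source, which reduces to a plain copy at tau = 1.0. The helper below is a hypothetical sketch of that rule on raw tf.Variable lists; it ignores the trainable/non-trainable split and the variable sorting that update handles.

```python
import tensorflow as tf


def soft_update(target_vars, source_vars, tau=1.0):
  # Move each target variable a fraction tau of the way toward its source variable.
  for target, source in zip(target_vars, source_vars):
    target.assign((1.0 - tau) * target + tau * source)


target = [tf.Variable([0.0, 0.0])]
source = [tf.Variable([1.0, 1.0])]
soft_update(target, source, tau=0.1)
# target[0] -> [0.1, 0.1]
```

With two policies of identical structure, the corresponding library call would look like target_policy.update(online_policy, tau=0.005), where both policy names are placeholders.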