tf_agents.agents.ppo.ppo_policy.PPOPolicy

An ActorPolicy that also returns policy_info needed for PPO training.

Inherits From: ActorPolicy, TFPolicy

This policy requires two networks: the usual actor_network and the additional value_network. The value network can be executed with the apply_value_network() method.

When the networks have state (RNNs, LSTMs) you must be careful to pass the state for the actor network to action() and the state of the value network to apply_value_network(). Use get_initial_value_state() to access the state of the value network.
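The state bookkeeping described above can be illustrated with a dependency-free toy. `ToyStatefulPolicy` and `rollout` are hypothetical names invented for this sketch. The real PPOPolicy.action() takes a TimeStep, and apply_value_network() takes observation and step-type tensors; the toy reduces both to plain lists of floats. The only point it makes is that the actor state and the value state are initialized separately and never mixed.

```python
from typing import List, Tuple


class ToyStatefulPolicy:
    """Hypothetical stand-in: both 'networks' are running-sum RNNs."""

    def get_initial_state(self, batch_size: int) -> List[float]:
        return [0.0] * batch_size

    def get_initial_value_state(self, batch_size: int) -> List[float]:
        return [0.0] * batch_size

    def action(self, obs: List[float], state: List[float]) -> Tuple[List[float], List[float]]:
        # Actor "RNN": accumulate observations; act on the sign of the sum.
        new_state = [s + o for s, o in zip(state, obs)]
        actions = [1.0 if s > 0 else -1.0 for s in new_state]
        return actions, new_state

    def apply_value_network(self, obs: List[float], value_state: List[float]) -> Tuple[List[float], List[float]]:
        # Value "RNN": a different recurrence with its own state.
        new_state = [s + 2.0 * o for s, o in zip(value_state, obs)]
        values = [0.5 * s for s in new_state]
        return values, new_state


def rollout(policy: ToyStatefulPolicy, observations: List[List[float]]):
    batch_size = len(observations[0])
    policy_state = policy.get_initial_state(batch_size)
    value_state = policy.get_initial_value_state(batch_size)
    actions, values = [], []
    for obs in observations:
        # Actor state goes only to action(); value state only to apply_value_network().
        act, policy_state = policy.action(obs, policy_state)
        val, value_state = policy.apply_value_network(obs, value_state)
        actions.append(act)
        values.append(val)
    return actions, values, policy_state, value_state
```

Passing policy_state into apply_value_network() here would silently produce wrong values, which is the mistake the paragraph above warns about.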

Args
time_step_spec A TimeStep spec of the expected time_steps.
action_spec A nest of BoundedTensorSpec representing the actions.
actor_network An instance of a tf_agents.networks.network.Network, with call(observation, step_type, network_state). Network should return one of the following: 1. a nested tuple of tfp.distributions objects matching action_spec, or 2. a nested tuple of tf.Tensors representing actions.
value_network An instance of a tf_agents.networks.network.Network, with call(observation, step_type, network_state). Network should return value predictions for the input state.
observation_normalizer An object to use for observation normalization.
clip Whether to clip actions to spec before returning them. Default True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped continuous actions for training.
collect If True, creates ops for actions_log_prob, value_preds, and action_distribution_params. (default True)
compute_value_and_advantage_in_train A bool to indicate where value prediction and advantage calculation happen. If True, both happen in agent.train(), therefore no need to save the value prediction inside of policy info. If False, value prediction is computed during data collection. This argument must be set to False if mini batch learning is enabled.
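To make the clip argument concrete, here is a minimal sketch of element-wise clipping to spec bounds. `clip_to_bounds` is a hypothetical helper that works on plain Python lists and is not the TF-Agents code path; it assumes a flat action with per-element minimum and maximum, as a BoundedTensorSpec provides.

```python
from typing import List, Sequence


def clip_to_bounds(action: Sequence[float],
                   minimum: Sequence[float],
                   maximum: Sequence[float]) -> List[float]:
    """Hypothetical helper: clip each action element into [minimum, maximum]."""
    return [min(max(a, lo), hi) for a, lo, hi in zip(action, minimum, maximum)]


# A sample from an unbounded distribution (e.g. a Gaussian) can leave the spec range.
raw_action = [1.7, -0.2, -3.0]
clipped = clip_to_bounds(raw_action, minimum=[-1.0, -1.0, -1.0], maximum=[1.0, 1.0, 1.0])
```

Here clipped is [1.0, -0.2, -1.0]: only the out-of-range elements change. The raw sample differs from the clipped one, which is why the choice between clipped and unclipped actions matters for training.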

Raises
TypeError if actor_network or value_network is not of type tf_agents.networks.Network.
ValueError if actor_network or value_network do not emit valid outputs. For example, actor_network must either be a (legacy style) DistributionNetwork, or explicitly emit a nest of tfp.distributions.Distribution objects.

Attributes
action_spec Describes the TensorSpecs of the Tensors expected by step(action).

action can be a single Tensor, or a nested dict, list or tuple of Tensors.

collect_data_spec Describes the Tensors written when using this policy with an environment.
emit_log_probability Whether this policy instance emits log probabilities or not.
info_spec Describes the Tensors emitted as info by action and distribution.

info can be an empty tuple, a single Tensor, or a nested dict, list or tuple of Tensors.

observation_and_action_constraint_splitter

observation_normalizer

policy_state_spec Describes the Tensors expected by step(_, policy_state).

policy_state can be an empty tuple, a single Tensor, or a nested dict, list or tuple of Tensors.

policy_step_spec Describes the output of action().
time_step_spec Describes the TimeStep tensors returned by step().
trajectory_spec Describes the Tensors written when using this policy with an environment.
validate_args Whether action & distribution validate input and output args.

Methods

action

Generates next action given the time_step and policy_state.

Args
time_step A TimeStep tuple corresponding to time_step_spec().
policy_state A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.
seed Seed to use if action performs sampling (optional).

Returns
A PolicyStep named tuple containing:

  • action: An action Tensor matching the action_spec.
  • state: A policy state tensor to be fed into the next call to action.
  • info: Optional side information such as action log probabilities.

Raises
RuntimeError If the subclass __init__ didn't call super().__init__.
ValueError or TypeError If validate_args is True and inputs or outputs do not match time_step_spec, policy_state_spec, or policy_step_spec.

apply_value_network

Apply value network to time_step, potentially a sequence.

If observation_normalizer is not None, applies observation normalization.

Args
observations A (possibly nested) observation tensor with outer_dims either (batch_size,) or (batch_size, time_index). If observations is a time series and the network is an RNN, the RNN is stepped over the time series.
step_types A (possibly nested) step_types tensor with the same outer_dims as observations.
value_state Optional. Initial state for the value_network. If not provided, the behavior depends on the value network itself.
training Whether the output value is going to be used for training.

Returns
The output of value_net, which is a tuple of:

  • value_preds with same outer_dims as time_step
  • value_state at the end of the time series

distribution

Generates the distribution over next actions given the time_step.

Args
time_step A TimeStep tuple corresponding to time_step_spec().
policy_state A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.

Returns
A PolicyStep named tuple containing:

  • action: A tfp.distributions.Distribution capturing the distribution of next actions.
  • state: A policy state tensor for the next call to distribution.
  • info: Optional side information such as action log probabilities.

Raises
ValueError or TypeError: If validate_args is True and inputs or outputs do not match time_step_spec, policy_state_spec, or policy_step_spec.
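distribution() returns the distribution itself, while action() draws a sample from it for a stochastic policy, optionally using seed. The sketch below mimics that split with the standard library's statistics.NormalDist. `toy_distribution` and `toy_action` are hypothetical, and the fixed sigma of 0.1 is arbitrary.

```python
import math
from statistics import NormalDist


def toy_distribution(observation: float) -> NormalDist:
    """Hypothetical actor head: a Gaussian whose mean is the observation."""
    return NormalDist(mu=observation, sigma=0.1)


def toy_action(observation: float, seed: int) -> float:
    """Sample one action from toy_distribution; a fixed seed makes it repeatable."""
    return toy_distribution(observation).samples(1, seed=seed)[0]


dist = toy_distribution(2.0)
a1 = toy_action(2.0, seed=42)
a2 = toy_action(2.0, seed=42)
# The kind of side information a policy might report alongside the action.
log_prob = math.log(dist.pdf(a1))
```

Keeping the distribution available, instead of only the sampled action, is what allows log probabilities such as log_prob to be computed for training.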

get_initial_state

Returns an initial state usable by the policy.

Args
batch_size Tensor or constant: size of the batch dimension. Can be None, in which case no batch dimension is added.

Returns
A nested object of type policy_state containing properly initialized Tensors.

get_initial_value_state

Returns the initial state of the value network.

Args
batch_size A constant or Tensor holding the batch size. Can be None, in which case the state will not have a batch dimension added.

Returns
A nest of zero tensors matching the spec of the value network state.
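The effect of batch_size on the returned nest of zeros can be sketched with plain lists. `zeros_like_spec` is a hypothetical helper that models a state spec as a dict of vector sizes (for example, the h and c vectors of an LSTM-like cell). The real method returns tensors built from the value network's state spec.

```python
from typing import Dict, List, Optional, Union

State = Dict[str, Union[List[float], List[List[float]]]]


def zeros_like_spec(spec: Dict[str, int], batch_size: Optional[int]) -> State:
    """Hypothetical helper: a nest of zeros, with a batch dim only if batch_size is set."""
    state: State = {}
    for name, size in spec.items():
        if batch_size is None:
            state[name] = [0.0] * size  # no batch dimension added
        else:
            # One independent zero row per batch entry.
            state[name] = [[0.0] * size for _ in range(batch_size)]
    return state


lstm_like_spec = {"h": 3, "c": 3}
unbatched = zeros_like_spec(lstm_like_spec, batch_size=None)
batched = zeros_like_spec(lstm_like_spec, batch_size=2)
```

unbatched["h"] is a single vector of three zeros, while batched["h"] has a leading dimension of 2.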

update

Update the current policy with another policy.

This would include copying the variables from the other policy.

Args
policy Another policy it can update from.
tau A float scalar in [0, 1]. When tau is 1.0 (the default), we do a hard update. This is used for trainable variables.
tau_non_trainable A float scalar in [0, 1] for non_trainable variables. If None, will copy from tau.
sort_variables_by_name A bool; when True, the variables are sorted by name before doing the update.

Returns
A TF op to do the update.
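Here is a minimal sketch of the update arithmetic. It assumes the standard soft-update rule target = (1 - tau) * target + tau * source, under which tau = 1.0 reduces to a hard copy. `soft_update` and `update_variables` are hypothetical names that operate on flat lists of floats rather than TF variables. As described above, tau_non_trainable falls back to tau when it is None.

```python
from typing import List, Optional, Tuple


def soft_update(target: List[float], source: List[float], tau: float) -> List[float]:
    """Return (1 - tau) * target + tau * source, element-wise."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be in [0, 1]")
    return [(1.0 - tau) * t + tau * s for t, s in zip(target, source)]


def update_variables(
    target_trainable: List[float],
    source_trainable: List[float],
    target_non_trainable: List[float],
    source_non_trainable: List[float],
    tau: float = 1.0,
    tau_non_trainable: Optional[float] = None,
) -> Tuple[List[float], List[float]]:
    """Hypothetical analogue of update(): separate rates for the two variable groups."""
    if tau_non_trainable is None:
        tau_non_trainable = tau  # "If None, will copy from tau."
    return (
        soft_update(target_trainable, source_trainable, tau),
        soft_update(target_non_trainable, source_non_trainable, tau_non_trainable),
    )
```

The sketch pairs target and source values by position, so both lists must be in the same order. The sort_variables_by_name option appears to address that same ordering concern for real variables.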