tf_agents.agents.PPOAgent

A PPO Agent.

Inherits From: TFAgent

Args
time_step_spec A TimeStep spec of the expected time_steps.
action_spec A nest of BoundedTensorSpec representing the actions.
optimizer Optimizer to use for the agent. Defaults to tf.compat.v1.train.AdamOptimizer.
actor_net A network.DistributionNetwork which maps observations to action distributions. Commonly, it is set to actor_distribution_network.ActorDistributionNetwork.
value_net A Network which returns the value prediction for input states, with call(observation, step_type, network_state). Commonly, it is set to value_network.ValueNetwork.
greedy_eval Whether the evaluation policy uses argmax (greedy) action selection or samples from the original action distribution. For environments such as ProcGen, stochastic evaluation is much better than greedy.
importance_ratio_clipping Epsilon in the clipped surrogate PPO objective. For more detail, see the explanation at the top of the doc.
lambda_value Lambda parameter for TD-lambda computation.
discount_factor Discount factor for return computation. Defaults to 0.99, the value used for all environments in (Schulman, 2017).
entropy_regularization Coefficient for the entropy regularization loss term. Defaults to 0.0 because no entropy bonus was used in (Schulman, 2017).
policy_l2_reg Coefficient for L2 regularization of unshared actor_net weights. Defaults to 0.0 because no L2 regularization was applied to the policy network weights in (Schulman, 2017).
value_function_l2_reg Coefficient for L2 regularization of unshared value function weights. Defaults to 0.0 because no L2 regularization was applied to the value function weights in (Schulman, 2017).
shared_vars_l2_reg Coefficient for L2 regularization of weights shared between actor_net and value_net. Defaults to 0.0 because no L2 regularization was applied to the policy network or value network weights in (Schulman, 2017).
value_pred_loss_coef Multiplier for the value prediction loss, to balance it with the policy gradient loss. Defaults to 0.5, which was used for all environments in the OpenAI Baselines implementation. This parameter is irrelevant unless you share part of actor_net and value_net. In that case, you should tune this coefficient, whose value depends on your network architecture.
num_epochs Number of epochs for computing policy updates. (Schulman, 2017) sets this to 10 for Mujoco, 15 for Roboschool and 3 for Atari.
use_gae If True (default False), uses generalized advantage estimation for computing per-timestep advantage. Else, just subtracts value predictions from empirical return.
use_td_lambda_return If True (default False), uses td_lambda_return for training the value function, where td_lambda_return = gae_advantage + value_predictions. use_gae must also be set to True to enable TD-lambda returns. If use_td_lambda_return is True while use_gae is False, the empirical return is used and a warning is logged.
normalize_rewards If True, keeps a moving variance of rewards and normalizes incoming rewards. While not mentioned directly in (Schulman, 2017), reward normalization was implemented in OpenAI Baselines, and (Ilyas et al., 2018) pointed out that it substantially improves performance. See Figure 1 of https://arxiv.org/pdf/1811.02553.pdf for a comparison with and without reward scaling.
reward_norm_clipping Value above and below which normalized rewards are clipped. An additional optimization proposed in (Ilyas et al., 2018), set to 5 or 10.
normalize_observations If True, keeps a moving mean and variance of observations and normalizes incoming observations. An additional optimization proposed in (Ilyas et al., 2018). If True and the observation spec is not tf.float32 (as with Atari), manually convert the observation spec received from the environment to tf.float32 before creating the networks. Otherwise, the normalized input to the network (float32) will have a different dtype than the network expects, resulting in a mismatch error. Example usage:

```python
import tensorflow as tf
from tf_agents.agents.ppo import ppo_clip_agent
from tf_agents.networks import actor_distribution_network, value_network
from tf_agents.train.utils import spec_utils

observation_tensor_spec, action_spec, time_step_tensor_spec = (
    spec_utils.get_tensor_specs(env))

# Re-declare every observation spec as float32 so the networks accept
# the normalized (float32) observations.
normalized_observation_tensor_spec = tf.nest.map_structure(
    lambda s: tf.TensorSpec(dtype=tf.float32, shape=s.shape, name=s.name),
    observation_tensor_spec)

actor_net = actor_distribution_network.ActorDistributionNetwork(
    normalized_observation_tensor_spec, ...)
value_net = value_network.ValueNetwork(
    normalized_observation_tensor_spec, ...)

# Note that the agent still uses the original time_step_tensor_spec
# from the environment.
agent = ppo_clip_agent.PPOClipAgent(
    time_step_tensor_spec,
    action_spec,
    actor_net=actor_net,
    value_net=value_net,
    # ... remaining arguments.
)
```
log_prob_clipping +/- value for clipping log probs to prevent inf / NaN values. Default: no clipping.
kl_cutoff_factor Only meaningful when kl_cutoff_coef > 0.0. A multiplier used for calculating the KL cutoff (= kl_cutoff_factor * adaptive_kl_target). If the policy KL averaged across the batch exceeds the cutoff, a squared cutoff loss is added to the loss function.
kl_cutoff_coef kl_cutoff_coef and kl_cutoff_factor are additional parameters for using a KL cutoff loss term in addition to the adaptive KL loss term. Defaults to 0.0, which disables the KL cutoff loss term, as it was not used in the paper. kl_cutoff_coef is the coefficient that multiplies the KL cutoff loss term before it is added to the total loss.
initial_adaptive_kl_beta Initial value for beta coefficient of adaptive KL penalty. This initial value is not important in practice because the algorithm quickly adjusts to it. A common default is 1.0.
adaptive_kl_target Desired KL target for policy updates. If actual KL is far from this target, adaptive_kl_beta will be updated. You should tune this for your environment. 0.01 was found to perform well for Mujoco.
adaptive_kl_tolerance A tolerance for adaptive_kl_beta. A mean KL above (1 + adaptive_kl_tolerance) * adaptive_kl_target, or below (1 - adaptive_kl_tolerance) * adaptive_kl_target, causes adaptive_kl_beta to be updated. 0.5 was chosen heuristically in the paper, but the algorithm is not very sensitive to it.
gradient_clipping Norm length to clip gradients. Default: no clipping.
value_clipping Differences between new and old value predictions are clipped to this threshold. Value clipping can be helpful when training very deep networks. Default: no clipping.
check_numerics If True, adds tf.debugging.check_numerics to help find NaN / Inf values. For debugging only.
compute_value_and_advantage_in_train A bool indicating where value prediction and advantage calculation happen. If True, both happen in agent.train(). If False, value prediction is computed during data collection. This argument must be set to False if mini-batch learning is enabled.
update_normalizers_in_train A bool indicating whether normalizers are updated as part of the train method. Set to False if mini-batch learning is enabled, or if train is called on multiple iterations of the same trajectories. In that case, use PPOLearner (which updates all the normalizers outside of the agent). This ensures that normalizers are updated in the same way as in (Schulman, 2017).
aggregate_losses_across_replicas Only applicable to setups using multiple replicas. Defaults to aggregating across multiple cores using tf_agents.utils.common.aggregate_losses. If set to False, uses reduce_mean directly, which is faster but may affect learning results.
debug_summaries A bool to gather debug summaries.
summarize_grads_and_vars If True, gradient summaries will be written.
train_step_counter An optional counter to increment every time the train op is run. Defaults to the global_step.
name The name of this agent. All variables in this module will fall under that name. Defaults to the class name.
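
The normalize_rewards, reward_norm_clipping and normalize_observations arguments all rely on moving statistics. As a rough illustration only, the following pure-Python sketch tracks a running mean and variance with Welford's online algorithm, then normalizes and clips one value at a time. The RunningNormalizer class and its center_mean and clip_value arguments are hypothetical names for this sketch, not the TF-Agents normalizer API, and the exact statistics TF-Agents keeps may differ.

```python
import math


class RunningNormalizer:
    """Running mean/variance tracker (hypothetical helper, not a TF-Agents class).

    Uses Welford's online algorithm so statistics can be updated one value
    at a time without storing past data.
    """

    def __init__(self, epsilon=1e-8):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # Running sum of squared deviations from the mean.
        self.epsilon = epsilon  # Guards against division by zero.

    def update(self, values):
        for x in values:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self._m2 += delta * (x - self.mean)

    @property
    def variance(self):
        # Population variance; fall back to 1.0 until two samples are seen.
        return self._m2 / self.count if self.count > 1 else 1.0

    def normalize(self, x, center_mean=True, clip_value=None):
        std = math.sqrt(self.variance + self.epsilon)
        y = (x - self.mean) / std if center_mean else x / std
        if clip_value is not None:
            # Clip to [-clip_value, clip_value], as reward_norm_clipping describes.
            y = max(-clip_value, min(clip_value, y))
        return y


normalizer = RunningNormalizer()
normalizer.update([1.0, 2.0, 3.0, 4.0])
scaled_reward = normalizer.normalize(100.0, center_mean=False, clip_value=5.0)
```

Here the reward is only scaled (center_mean=False), which preserves its sign; whether to center is left to the caller in this sketch.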

Raises
TypeError If actor_net or value_net is not of type tf_agents.networks.Network.

Attributes
action_spec TensorSpec describing the action produced by the agent.
actor_net Returns actor_net TensorFlow template function.
collect_data_context

collect_data_spec Returns a Trajectory spec, as expected by the collect_policy.
collect_policy Return a policy that can be used to collect data from the environment.
data_context

debug_summaries

policy Return the current policy held by the agent.
summaries_enabled

summarize_grads_and_vars

time_step_spec Describes the TimeStep tensors expected by the agent.
train_sequence_length The number of time steps needed in experience tensors passed to train.

Train requires experience to be a Trajectory containing tensors shaped [B, T, ...]. This argument describes the value of T required.

For example, for non-RNN DQN training, T=2 because DQN requires single transitions.

If this value is None, then train can handle an unknown T (it can be determined at runtime from the data). Most RNN-based agents fall into this category.

train_step_counter

training_data_spec Returns a trajectory spec, as expected by the train() function.

Methods

adaptive_kl_loss

compute_advantages

Compute advantages, optionally using GAE.

Based on the OpenAI Baselines ppo1 implementation. Removes the final timestep, because it is needed for next-step value prediction in the TD error computation.

Args
rewards Tensor of per-timestep rewards.
returns Tensor of per-timestep returns.
discounts Tensor of per-timestep discounts. Zero for terminal timesteps.
value_preds Cached value estimates from the data-collection policy.

Returns
advantages Tensor of length (len(rewards) - 1), because the final timestep is just used for next-step value prediction.
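
With use_gae=True, the advantage is a generalized advantage estimate. The sketch below is a minimal, unbatched pure-Python version of GAE that follows the convention described above: inputs have one entry per timestep, and the last timestep only supplies the bootstrap value, so T inputs give T - 1 advantages. It assumes discounts[t] already includes the discount factor (and is 0.0 at terminal steps); compute_gae is a hypothetical name, not the library's function.

```python
def compute_gae(rewards, discounts, value_preds, lambda_value=0.95):
    """Generalized advantage estimation for one trajectory (no batch dimension).

    All three inputs have length T. discounts[t] is assumed to already include
    the discount factor (0.0 when step t is terminal). The final timestep only
    provides value_preds[-1] for bootstrapping, so the result has length T - 1.
    """
    num_steps = len(rewards) - 1
    advantages = [0.0] * num_steps
    running = 0.0
    for t in reversed(range(num_steps)):
        # One-step TD error: r_t + discount_t * V(s_{t+1}) - V(s_t).
        delta = rewards[t] + discounts[t] * value_preds[t + 1] - value_preds[t]
        # Exponentially weighted (by discount * lambda) sum of future TD errors.
        running = delta + discounts[t] * lambda_value * running
        advantages[t] = running
    return advantages


advantages = compute_gae(
    rewards=[1.0, 0.0, 0.0],
    discounts=[0.5, 0.5, 0.5],
    value_preds=[1.0, 2.0, 3.0],
    lambda_value=0.0,
)
```

Adding value_preds[:-1] back to these advantages gives the td_lambda_return = gae_advantage + value_predictions described under use_td_lambda_return.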

compute_return_and_advantage

Compute the Monte Carlo return and advantage.

Args
next_time_steps batched tensor of TimeStep tuples after action is taken.
value_preds Batched value prediction tensor. Should have one more entry in time index than time_steps, with the final value corresponding to the value prediction of the final state.

Returns
tuple of (return, advantage), both are batched tensors.
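
Without GAE, the advantage is simply the empirical return minus the value prediction. A minimal pure-Python sketch for a single trajectory follows, assuming discounts[t] already includes the discount factor and is 0.0 at terminal steps. The final_value argument bootstraps returns of trajectories that do not end in a terminal step; passing the extra final entry of value_preds there is one reading of the value_preds description above, not a confirmed detail of this method. Both function names are hypothetical.

```python
def discounted_returns(rewards, discounts, final_value=0.0):
    """Monte Carlo (discounted) return for each timestep of one trajectory.

    discounts[t] is assumed to already include the discount factor and to be
    0.0 at terminal steps; final_value bootstraps past the last step.
    """
    returns = [0.0] * len(rewards)
    running = final_value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + discounts[t] * running
        returns[t] = running
    return returns


def empirical_advantages(returns, value_preds):
    """Advantage without GAE: empirical return minus predicted value."""
    return [g - v for g, v in zip(returns, value_preds)]


returns = discounted_returns([1.0, 1.0, 1.0], [0.5, 0.5, 0.0])
advantages = empirical_advantages(returns, [1.0, 1.0, 1.0])
```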

entropy_regularization_loss

Create regularization loss tensor based on agent parameters.
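
An entropy bonus is usually expressed as a negative loss, so that minimizing the total loss increases policy entropy. The sketch below assumes discrete (categorical) action distributions given as probability lists and averages the weighted entropies over the batch; the function names are hypothetical and this is not the TF-Agents implementation. With the default entropy_regularization=0.0, the term is zero.

```python
import math


def categorical_entropy(probs):
    """Entropy (in nats) of one discrete action distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def entropy_regularization_loss(batch_probs, entropy_regularization, weights=None):
    """Negative weighted mean entropy, scaled by the regularization coefficient.

    Minimizing this loss pushes entropy up, which encourages exploration.
    """
    if weights is None:
        weights = [1.0] * len(batch_probs)
    weighted = [categorical_entropy(p) * w for p, w in zip(batch_probs, weights)]
    return -entropy_regularization * sum(weighted) / len(weighted)


loss = entropy_regularization_loss([[0.5, 0.5], [1.0, 0.0]], entropy_regularization=0.01)
```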

get_loss

Compute the loss and create optimization op for one training epoch.

All tensors should have a single batch dimension.

Args
time_steps A minibatch of TimeStep tuples.
actions A minibatch of actions.
act_log_probs A minibatch of action log probabilities (under the sampling policy).
returns A minibatch of per-timestep returns.
normalized_advantages A minibatch of normalized per-timestep advantages.
action_distribution_parameters Parameters of data-collecting action distribution. Needed for KL computation.
weights Optional scalar or element-wise (per-batch-entry) importance weights. Includes a mask for invalid timesteps.
train_step A train_step variable to increment for each train step. Typically the global_step.
debug_summaries True if debug summaries should be created.
old_value_predictions (Optional) The saved value predictions, used for calculating the value estimation loss when value clipping is performed.
training Whether this loss is being used for training.

Returns
A tf_agent.LossInfo named tuple with the total_loss and all intermediate losses in the extra field contained in a PPOLossInfo named tuple.
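
The constructor arguments suggest how the individual terms are balanced. The sketch below assumes the total loss is a plain sum of the terms, with only the value term scaled by value_pred_loss_coef; LossTerms and total_loss are hypothetical names, and the actual bookkeeping inside get_loss may differ (for example, in where the coefficient is applied).

```python
from typing import NamedTuple


class LossTerms(NamedTuple):
    """Hypothetical container for the per-term losses of one epoch."""

    policy_gradient_loss: float
    value_estimation_loss: float  # Unscaled value error.
    entropy_regularization_loss: float
    l2_regularization_loss: float
    kl_penalty_loss: float


def total_loss(terms, value_pred_loss_coef=0.5):
    """Sum the terms, weighting the value loss by value_pred_loss_coef."""
    return (
        terms.policy_gradient_loss
        + value_pred_loss_coef * terms.value_estimation_loss
        + terms.entropy_regularization_loss
        + terms.l2_regularization_loss
        + terms.kl_penalty_loss
    )


terms = LossTerms(1.0, 2.0, -0.1, 0.0, 0.0)
loss = total_loss(terms)
```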

initialize

Initializes the agent.

Returns
An operation that can be used to initialize the agent.

Raises
RuntimeError If the class was not initialized properly (super.__init__ was not called).

kl_cutoff_loss

kl_penalty_loss

Compute a loss that penalizes policy steps with high KL.

Based on KL divergence from old (data-collection) policy to new (updated) policy.

All tensors should have a single batch dimension.

Args
time_steps TimeStep tuples with observations for each timestep. Used for computing new action distributions.
action_distribution_parameters Action distribution params of the data collection policy, used for reconstructing old action distributions.
current_policy_distribution The policy distribution, evaluated on all time_steps.
weights Optional scalar or element-wise (per-batch-entry) importance weights. Includes a mask for invalid timesteps.
debug_summaries True if debug summaries should be created.

Returns
kl_penalty_loss The sum of a squared penalty for KL over a constant threshold, plus an adaptive penalty that encourages updates toward a target KL divergence.
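
A minimal pure-Python sketch of this loss, following the descriptions of adaptive_kl_target, kl_cutoff_factor and kl_cutoff_coef: an adaptive term beta * mean KL, plus kl_cutoff_coef times the squared amount by which the mean KL exceeds kl_cutoff_factor * adaptive_kl_target. The categorical KL helper assumes discrete actions, and all function names are hypothetical.

```python
import math


def categorical_kl(old_probs, new_probs):
    """KL(old || new) for two discrete distributions over the same actions.

    Assumes new_probs is nonzero wherever old_probs is nonzero.
    """
    return sum(p * math.log(p / q) for p, q in zip(old_probs, new_probs) if p > 0.0)


def kl_penalty_loss(kl_divergences, adaptive_kl_beta, adaptive_kl_target,
                    kl_cutoff_factor, kl_cutoff_coef=0.0):
    """Adaptive KL penalty plus an optional squared penalty above a cutoff."""
    mean_kl = sum(kl_divergences) / len(kl_divergences)
    adaptive_loss = adaptive_kl_beta * mean_kl
    # Only the part of the mean KL above the cutoff is penalized.
    kl_cutoff = kl_cutoff_factor * adaptive_kl_target
    kl_over_cutoff = max(mean_kl - kl_cutoff, 0.0)
    cutoff_loss = kl_cutoff_coef * kl_over_cutoff ** 2
    return adaptive_loss + cutoff_loss


kls = [categorical_kl([0.5, 0.5], [0.6, 0.4]), categorical_kl([0.5, 0.5], [0.5, 0.5])]
loss = kl_penalty_loss(kls, adaptive_kl_beta=1.0, adaptive_kl_target=0.01,
                       kl_cutoff_factor=2.0)
```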

l2_regularization_loss

loss

Gets loss from the agent.

If the user calls this from _train, it must be in a tf.GradientTape scope in order to apply gradients to trainable variables. If intermediate gradient steps are needed, _loss and _train will return different values since _loss only supports updating all gradients at once after all losses have been calculated.

Args
experience A batch of experience data in the form of a Trajectory. The structure of experience must match that of self.training_data_spec. All tensors in experience must be shaped [batch, time, ...] where time must be equal to self.train_sequence_length if that property is not None.
weights (optional). A Tensor, either 0-D or shaped [batch], containing weights to be used when calculating the total train loss. Weights are typically multiplied elementwise against the per-batch loss, but the implementation is up to the Agent.
training Explicit argument to pass to loss. This typically affects network computation paths like dropout and batch normalization.
**kwargs Any additional data as args to loss.

Returns
A LossInfo loss tuple containing loss and info tensors.

Raises
RuntimeError If the class was not initialized properly (super.__init__ was not called).

policy_gradient_loss

Create tensor for policy gradient loss.

All tensors should have a single batch dimension.

Args
time_steps TimeSteps with observations for each timestep.
actions Tensor of actions for timesteps, aligned on index.
sample_action_log_probs Tensor of log probabilities of each sampled action under the data-collection policy.
advantages Tensor of advantage estimate for each timestep, aligned on index. Works better when advantage estimates are normalized.
current_policy_distribution The policy distribution, evaluated on all time_steps.
weights Optional scalar or element-wise (per-batch-entry) importance weights. Includes a mask for invalid timesteps.
debug_summaries True if debug summaries should be created.

Returns
policy_gradient_loss A tensor that will contain policy gradient loss for the on-policy experience.
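
The importance_ratio_clipping argument is the epsilon of the clipped surrogate objective from (Schulman, 2017). The following unbatched pure-Python sketch shows that objective, assuming the log probabilities of the taken actions are already available; this policy_gradient_loss is a hypothetical stand-in, not the method above, and it falls back to the unclipped objective when the clipping value is not positive.

```python
import math


def policy_gradient_loss(new_log_probs, old_log_probs, advantages,
                         importance_ratio_clipping=0.2, weights=None):
    """Negative clipped surrogate objective, averaged over the batch.

    With importance_ratio_clipping <= 0 the plain importance-weighted
    objective is used instead.
    """
    if weights is None:
        weights = [1.0] * len(advantages)
    eps = importance_ratio_clipping
    total = 0.0
    for new_lp, old_lp, adv, w in zip(new_log_probs, old_log_probs, advantages, weights):
        # Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space.
        ratio = math.exp(new_lp - old_lp)
        objective = ratio * adv
        if eps > 0.0:
            clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
            # Pessimistic bound: keep the smaller of the two objectives.
            objective = min(objective, clipped_ratio * adv)
        total += w * objective
    # Negate because optimizers minimize, while PPO maximizes the objective.
    return -total / len(advantages)
```

Taking the minimum of the clipped and unclipped terms removes any gain from moving the ratio above 1 + epsilon for a positive advantage, or below 1 - epsilon for a negative one.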

post_process_policy

Post process policies after training.

The policies of some agents require expensive post-processing after training before they can be used. For example, a recommender agent might require rebuilding an index of actions. For such agents, this method returns a post-processed version of the policy. The post-processing may either update the existing policies in place or create a new policy, depending on the agent. The default implementation, for agents that do not override this method, returns agent.policy.

Returns
The post processed policy.

preprocess_sequence

Defines preprocess_sequence function to be fed into replay buffers.

This defines how we preprocess the collected data before training. Defaults to pass through for most agents. Structure of experience must match that of self.collect_data_spec.

Args
experience a Trajectory shaped [batch, time, ...] or [time, ...] which represents the collected experience data.

Returns
A post processed Trajectory with the same shape as the input.

train

Trains the agent.

Args
experience A batch of experience data in the form of a Trajectory. The structure of experience must match that of self.training_data_spec. All tensors in experience must be shaped [batch, time, ...] where time must be equal to self.train_sequence_length if that property is not None.
weights (optional). A Tensor, either 0-D or shaped [batch], containing weights to be used when calculating the total train loss. Weights are typically multiplied elementwise against the per-batch loss, but the implementation is up to the Agent.
**kwargs Any additional data to pass to the subclass.

Returns
A LossInfo loss tuple containing loss and info tensors.

  • In eager mode, the loss values are first calculated, then a train step is performed before they are returned.
  • In graph mode, executing any or all of the loss tensors will first calculate the loss value(s), then perform a train step, and return the pre-train-step LossInfo.

Raises
RuntimeError If the class was not initialized properly (super.__init__ was not called).

update_adaptive_kl_beta

Create update op for adaptive KL penalty coefficient.

Args
kl_divergence KL divergence of old policy to new policy for all timesteps.

Returns
update_op An op which runs the update for the adaptive kl penalty term.
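
Using the bounds given for adaptive_kl_tolerance, the update can be sketched as follows: beta grows when the mean KL is above (1 + tolerance) * target and shrinks when it is below (1 - tolerance) * target. This documentation does not specify the multiplier, so update_factor=1.5 and the min_beta floor are illustrative assumptions, and the function is a hypothetical pure-Python stand-in.

```python
def update_adaptive_kl_beta(beta, mean_kl, adaptive_kl_target,
                            adaptive_kl_tolerance=0.5, update_factor=1.5,
                            min_beta=1e-16):
    """Return the new adaptive KL coefficient after one update.

    update_factor and min_beta are hypothetical knobs for this sketch.
    """
    if mean_kl > (1.0 + adaptive_kl_tolerance) * adaptive_kl_target:
        beta *= update_factor  # Policy moved too far: penalize KL more.
    elif mean_kl < (1.0 - adaptive_kl_tolerance) * adaptive_kl_target:
        beta /= update_factor  # Policy barely moved: relax the penalty.
    return max(beta, min_beta)  # Keep beta strictly positive.
```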

update_observation_normalizer

update_reward_normalizer

value_estimation_loss

Computes the value estimation loss for actor-critic training.

All tensors should have a single batch dimension.

Args
time_steps A batch of timesteps.
returns Per-timestep returns for value function to predict. (Should come from TD-lambda computation.)
weights Optional scalar or element-wise (per-batch-entry) importance weights. Includes a mask for invalid timesteps.
old_value_predictions (Optional) The saved value predictions from policy_info, required when self._value_clipping > 0.
debug_summaries True if debug summaries should be created.
training Whether this loss is going to be used for training.

Returns
value_estimation_loss A scalar value_estimation_loss loss.

Raises
ValueError If old_value_predictions was not passed in, but value clipping was performed.
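
When value_clipping is positive, the new value prediction is kept within value_clipping of the saved one, and the loss uses the larger of the clipped and unclipped squared errors. This pure-Python sketch assumes that max-of-errors convention (common in PPO implementations such as OpenAI Baselines) and leaves out any value_pred_loss_coef scaling; the function name and argument order are hypothetical. Like the method above, it raises ValueError when clipping is requested without old predictions.

```python
def value_estimation_loss(returns, value_preds, old_value_predictions=None,
                          value_clipping=0.0, weights=None):
    """Weighted mean squared value error, optionally with value clipping."""
    if value_clipping > 0.0 and old_value_predictions is None:
        raise ValueError("old_value_predictions is required when value_clipping > 0.")
    if weights is None:
        weights = [1.0] * len(returns)
    total = 0.0
    for i, (ret, pred, w) in enumerate(zip(returns, value_preds, weights)):
        error = (ret - pred) ** 2
        if value_clipping > 0.0:
            old = old_value_predictions[i]
            # Keep the new prediction within value_clipping of the old one.
            clipped_pred = old + min(max(pred - old, -value_clipping), value_clipping)
            # Use the larger (more pessimistic) of the two squared errors.
            error = max(error, (ret - clipped_pred) ** 2)
        total += w * error
    return total / len(returns)
```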