tf_agents.utils.eager_utils.create_train_step

Creates a train_step that evaluates the gradients and returns the loss.
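A minimal eager-mode sketch of typical usage, assuming TF2's default eager execution; the variable `weight`, the `loss_fn`, and the learning rate are illustrative, not part of the API:

```python
import tensorflow as tf

from tf_agents.utils import eager_utils

# Resource variables are required; this is the TF2 default, so the call
# below is a no-op there and simply documents the requirement.
tf.compat.v1.enable_resource_variables()

# Illustrative trainable variable and loss; the minimum is at weight == 3.0.
weight = tf.Variable(0.5)

def loss_fn():
  return tf.square(weight - 3.0)

optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)

# In eager mode, `loss` must be a callable; the returned value is itself
# a callable that computes the loss, applies gradients, and returns the loss.
train_step = eager_utils.create_train_step(loss_fn, optimizer)

for _ in range(100):
  loss_value = train_step()  # One optimization step per call.
```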

Args:
  loss: A (possibly nested tuple of) Tensor or function representing the loss.
  optimizer: A tf.Optimizer to use for computing the gradients.
  global_step: A Tensor representing the global step variable. If left as
    _USE_GLOBAL_STEP, then tf.train.get_or_create_global_step() is used.
  total_loss_fn: Function to call on the loss value to access the final item
    to minimize.
  update_ops: An optional list of updates to execute. If update_ops is None,
    then the update ops are set to the contents of the
    tf.GraphKeys.UPDATE_OPS collection. If update_ops is not None, but it
    doesn't contain all of the update ops in tf.GraphKeys.UPDATE_OPS, a
    warning will be displayed.
  variables_to_train: An optional list of variables to train. If None, it
    defaults to all tf.trainable_variables().
  transform_grads_fn: A function which takes a single argument, a list of
    gradient-to-variable pairs (tuples), performs any requested gradient
    updates, such as gradient clipping or multipliers, and returns the
    updated list (see the sketch after this list).
  summarize_gradients: Whether or not to add summaries for each gradient.
  gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
  aggregation_method: Specifies the method used to combine gradient terms.
    Valid values are defined in the class AggregationMethod.
  check_numerics: Whether or not to apply check_numerics.
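As an illustration of transform_grads_fn, here is a sketch of a hypothetical clipping function; the name `clip_grads_fn` and the clip norm are assumptions, not part of the API. Only the (gradient, variable) pairing format comes from the description above:

```python
import tensorflow as tf

# Hypothetical transform_grads_fn: clips gradients by global norm before the
# optimizer applies them. It receives and returns (gradient, variable) pairs.
def clip_grads_fn(grads_and_vars):
  grads, variables = zip(*grads_and_vars)
  clipped_grads, _ = tf.clip_by_global_norm(list(grads), clip_norm=1.0)
  return list(zip(clipped_grads, variables))

# Passed via: create_train_step(..., transform_grads_fn=clip_grads_fn)
```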

Returns:
  In graph mode: A (possibly nested tuple of) Tensor that, when evaluated,
    calculates the current loss, computes the gradients, applies the
    optimizer, and returns the current loss.
  In eager mode: A lambda function that, when called, calculates the loss,
    then computes and applies the gradients, and returns the original loss
    values.
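A graph-mode sketch of consuming the returned value, assuming TF1-style session execution; `weight` and `loss_tensor` are illustrative names:

```python
import tensorflow as tf

from tf_agents.utils import eager_utils

tf.compat.v1.disable_eager_execution()

# use_resource=True documents the resource-variable requirement (the TF2
# default); non-resource variables raise RuntimeError, as noted below.
weight = tf.compat.v1.get_variable(
    'weight', initializer=0.5, use_resource=True)
loss_tensor = tf.square(weight - 3.0)
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)

# In graph mode the returned value is a Tensor; running it performs one step.
train_step = eager_utils.create_train_step(loss_tensor, optimizer)

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  for _ in range(100):
    loss_value = sess.run(train_step)
```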

Raises:
  ValueError: If loss is not callable.
  RuntimeError: If resource variables are not enabled.