Creates an Operation that evaluates the gradients and returns the loss.

Args:
total_loss: A Tensor representing the total loss.
optimizer: A tf.Optimizer to use for computing the gradients.
global_step: A Tensor representing the global step variable. If left as _USE_GLOBAL_STEP, tf.contrib.framework.global_step() is used.
update_ops: An optional list of updates to execute. If update_ops is None, the update ops are set to the contents of the tf.GraphKeys.UPDATE_OPS collection. If update_ops is not None but does not contain all of the update ops in tf.GraphKeys.UPDATE_OPS, a warning is displayed.
variables_to_train: An optional list of variables to train. If None, it defaults to all tf.compat.v1.trainable_variables().
transform_grads_fn: A function that takes a single argument, a list of (gradient, variable) pairs, applies any requested gradient transformations such as clipping or multipliers, and returns the updated list.
summarize_gradients: Whether or not to add summaries for each gradient.
gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class AggregationMethod.
colocate_gradients_with_ops: Whether or not to try colocating the gradients with the ops that generated them.
check_numerics: Whether or not to apply check_numerics.
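The global_step default is a sentinel (_USE_GLOBAL_STEP) rather than None. The usual reason for this Python pattern is that it lets a function tell "argument omitted" apart from "argument passed explicitly", including an explicit None. The following is a minimal pure-Python sketch of the pattern only. The names _USE_DEFAULT_STEP and resolve_global_step are hypothetical, and default_step merely stands in for the tf.contrib.framework.global_step() lookup. What an explicit None means to the real function is not stated here.

```python
from typing import Any, Optional

# Module-private sentinel, analogous in spirit to _USE_GLOBAL_STEP (hypothetical name).
_USE_DEFAULT_STEP = object()


def resolve_global_step(global_step: Any = _USE_DEFAULT_STEP,
                        default_step: Optional[int] = 0) -> Any:
    """Hypothetical helper: resolve an omitted global_step to a default."""
    if global_step is _USE_DEFAULT_STEP:  # identity check: only the sentinel matches
        return default_step               # stand-in for looking up the global step
    return global_step                    # any explicit value, even None, is kept


print(resolve_global_step())      # omitted -> default
print(resolve_global_step(None))  # explicit None is preserved
```

Comparing with `is` against a private `object()` is what makes this work. No caller can accidentally pass that object, so every ordinary value, None included, remains available as a meaningful explicit argument.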
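The update_ops rule above has two branches. If update_ops is None, it defaults to the UPDATE_OPS collection. If it is an explicit list that omits entries of that collection, a warning is emitted and the list is still used. The sketch below is a pure-Python illustration of that rule under stated assumptions: ops are represented by name strings, the collection is passed in as a plain list, and Python's `warnings` module stands in for whatever logging the real implementation uses. `resolve_update_ops` is a hypothetical name, not part of the TensorFlow API.

```python
import warnings
from typing import List, Optional, Sequence


def resolve_update_ops(update_ops: Optional[Sequence[str]],
                       collection: Sequence[str]) -> List[str]:
    """Hypothetical helper mirroring the documented update_ops defaulting rule."""
    if update_ops is None:
        # Default: use everything registered in the (stand-in) UPDATE_OPS collection.
        return list(collection)
    missing = set(collection) - set(update_ops)
    if missing:
        # Explicit list that omits collection entries: warn, but keep the caller's list.
        warnings.warn("update_ops does not contain all UPDATE_OPS entries; missing: %s"
                      % sorted(missing))
    return list(update_ops)


collection = ["bn/moving_mean", "bn/moving_variance"]
print(resolve_update_ops(None, collection))
```

A typical case is batch normalization, whose moving-average updates live in UPDATE_OPS. Passing a hand-picked list that forgets them is exactly the situation the warning is meant to flag.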
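A transform_grads_fn receives the list of (gradient, variable) pairs and returns a modified list. The sketch below shows the two kinds of transformation the description mentions: clipping by global norm and per-variable multipliers. It is written in plain Python so it runs without TensorFlow. As assumptions, a gradient is a list of floats (or None), and a variable is represented by its name. Both factory functions are hypothetical. In real graph code the pairs hold Tensors, and a function such as tf.clip_by_global_norm would normally do the norm computation.

```python
import math
from typing import Callable, Dict, List, Optional, Sequence, Tuple

# Stand-ins (assumptions): a gradient is a list of floats or None; a variable is its name.
GradVar = Tuple[Optional[List[float]], str]
TransformFn = Callable[[Sequence[GradVar]], List[GradVar]]


def clip_by_global_norm_fn(clip_norm: float) -> TransformFn:
    """Hypothetical factory for a transform_grads_fn that clips by global norm."""

    def transform(grads_and_vars: Sequence[GradVar]) -> List[GradVar]:
        # Global norm across every element of every non-None gradient.
        global_norm = math.sqrt(sum(x * x
                                    for g, _ in grads_and_vars if g is not None
                                    for x in g))
        scale = 1.0 if global_norm <= clip_norm else clip_norm / global_norm
        # Rescale each gradient; None gradients and variables pass through unchanged.
        return [(None if g is None else [x * scale for x in g], v)
                for g, v in grads_and_vars]

    return transform


def multiply_gradients_fn(multipliers: Dict[str, float]) -> TransformFn:
    """Hypothetical factory for per-variable gradient multipliers (default 1.0)."""

    def transform(grads_and_vars: Sequence[GradVar]) -> List[GradVar]:
        return [(None if g is None else [x * multipliers.get(v, 1.0) for x in g], v)
                for g, v in grads_and_vars]

    return transform


pairs = [([3.0, 4.0], "w"), ([6.0, 8.0], "b"), (None, "frozen")]
clipped = clip_by_global_norm_fn(5.0)(pairs)        # global norm is now 5.0
scaled = multiply_gradients_fn({"b": 0.5})(pairs)   # only "b" is halved
```

Clipping by the global norm rescales all gradients by the same factor, so it preserves the direction of the overall update. Clipping each gradient separately would not.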

Returns:
A Tensor that, when evaluated, computes the gradients and returns the total loss value.
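To make the return value concrete, here is a pure-Python analogy with no TensorFlow involved. A zero-argument callable stands in for the returned Tensor. "Evaluating" it checks the loss when check_numerics is on, computes gradients for variables_to_train, optionally passes them through transform_grads_fn, applies a plain gradient-descent step, and returns the loss. Several things here are assumptions. The function name create_toy_train_op is hypothetical. Fixed-learning-rate gradient descent stands in for the optimizer. Gradients come from a caller-supplied grad_fn rather than automatic differentiation. The sketch returns the loss computed before the update.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

GradsAndVars = List[Tuple[float, str]]  # (gradient, variable name) pairs


def create_toy_train_op(loss_fn: Callable[[Dict[str, float]], float],
                        grad_fn: Callable[[Dict[str, float]], Dict[str, float]],
                        params: Dict[str, float],
                        learning_rate: float = 0.1,
                        variables_to_train: Optional[List[str]] = None,
                        transform_grads_fn: Optional[
                            Callable[[GradsAndVars], GradsAndVars]] = None,
                        check_numerics: bool = True) -> Callable[[], float]:
    """Hypothetical analogy: returns a callable standing in for the train-op Tensor."""
    if variables_to_train is None:
        variables_to_train = list(params)  # default: every variable is trainable

    def train_op() -> float:
        loss = loss_fn(params)
        if check_numerics and not math.isfinite(loss):
            raise FloatingPointError("loss is inf or nan")
        grads = grad_fn(params)
        grads_and_vars = [(grads[v], v) for v in variables_to_train]
        if transform_grads_fn is not None:
            grads_and_vars = transform_grads_fn(grads_and_vars)
        for g, v in grads_and_vars:
            params[v] -= learning_rate * g  # plain gradient-descent update
        return loss  # loss computed before this step's update

    return train_op


# Usage: minimize (w - 3)^2 + b^2 while training only "w".
params = {"w": 0.0, "b": 1.0}
step = create_toy_train_op(lambda p: (p["w"] - 3.0) ** 2 + p["b"] ** 2,
                           lambda p: {"w": 2.0 * (p["w"] - 3.0), "b": 2.0 * p["b"]},
                           params, learning_rate=0.1, variables_to_train=["w"])
first_loss = step()  # (0 - 3)^2 + 1^2 = 10.0; afterwards w is about 0.6
```

Calling `step()` repeatedly plays the role of repeatedly running the returned Tensor in a session. Each call performs one update and reports the loss for that step.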