tf.contrib.training.create_train_op
Creates an Operation that evaluates the gradients, applies them, and returns the loss.
tf.contrib.training.create_train_op(
total_loss, optimizer, global_step=_USE_GLOBAL_STEP, update_ops=None,
variables_to_train=None, transform_grads_fn=None, summarize_gradients=False,
gate_gradients=tf_optimizer.Optimizer.GATE_OP, aggregation_method=None,
colocate_gradients_with_ops=False, check_numerics=True
)
Args:
  total_loss: A Tensor representing the total loss.
  optimizer: A tf.Optimizer to use for computing the gradients.
  global_step: A Tensor representing the global step variable. If left as
    _USE_GLOBAL_STEP, then tf.contrib.framework.global_step() is used.
  update_ops: An optional list of updates to execute. If update_ops is None,
    the update ops are set to the contents of the tf.GraphKeys.UPDATE_OPS
    collection. If update_ops is not None but does not contain all of the
    update ops in tf.GraphKeys.UPDATE_OPS, a warning is displayed.
  variables_to_train: An optional list of variables to train. If None, it
    defaults to all tf.compat.v1.trainable_variables().
  transform_grads_fn: A function that takes a single argument, a list of
    (gradient, variable) pairs (tuples), applies any requested gradient
    transformations, such as gradient clipping or multipliers, and returns
    the updated list.
  summarize_gradients: Whether or not to add summaries for each gradient.
  gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
  aggregation_method: Specifies the method used to combine gradient terms.
    Valid values are defined in the class AggregationMethod.
  colocate_gradients_with_ops: Whether or not to try colocating the gradients
    with the ops that generated them.
  check_numerics: Whether or not to apply check_numerics to the total loss,
    so that evaluating the returned tensor fails if the loss is Inf or NaN.
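The defaulting and warning rule for update_ops can be sketched in plain Python as follows. The function name resolve_update_ops, the use of strings in place of ops, and the use of Python's warnings module (rather than TensorFlow's own logging) are assumptions made for illustration; the sketch mirrors the documented behavior, not TensorFlow's source.

```python
import warnings
from typing import List, Optional, Sequence


def resolve_update_ops(
    update_ops: Optional[Sequence[str]], update_ops_collection: Sequence[str]
) -> List[str]:
    """Hypothetical sketch of how the update_ops argument is resolved.

    update_ops_collection stands in for the contents of the
    tf.GraphKeys.UPDATE_OPS collection; ops are represented by name strings.
    """
    if update_ops is None:
        # Default: run everything registered in the collection.
        return list(update_ops_collection)
    missing = set(update_ops_collection) - set(update_ops)
    if missing:
        # An explicit list that skips collected ops is allowed, but flagged.
        warnings.warn(
            "update_ops does not contain every op in UPDATE_OPS: %s" % sorted(missing)
        )
    return list(update_ops)


if __name__ == "__main__":
    collection = ["bn/update_moving_mean", "bn/update_moving_variance"]
    print(resolve_update_ops(None, collection))
    print(resolve_update_ops(["bn/update_moving_mean"], collection))  # warns
```

In TF 1.x graphs, ops such as the moving-average updates created by batch normalization layers are typically registered in tf.GraphKeys.UPDATE_OPS, which is why an explicit list that omits them is worth a warning.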
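To illustrate the transform_grads_fn contract (a list of (gradient, variable) pairs in, an updated list out), the sketch below builds a function that rescales all gradients so their global L2 norm does not exceed a threshold. It is a pure-Python stand-in, assuming gradients are plain lists of floats (or None) and variables are just name strings; in real use the elements would be TensorFlow tensors and variables and TensorFlow's own clipping ops would do the work. The factory name make_clip_by_global_norm is hypothetical.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

# Hypothetical stand-ins: a gradient is a list of floats (or None when the
# variable does not affect the loss); a variable is identified by its name.
GradVarPair = Tuple[Optional[List[float]], str]


def make_clip_by_global_norm(
    max_norm: float,
) -> Callable[[Sequence[GradVarPair]], List[GradVarPair]]:
    """Returns a transform_grads_fn-style callable (hypothetical helper)."""

    def transform(grads_and_vars: Sequence[GradVarPair]) -> List[GradVarPair]:
        # Global L2 norm over every gradient element, skipping None gradients.
        global_norm = math.sqrt(
            sum(g * g for grad, _ in grads_and_vars if grad is not None for g in grad)
        )
        # Scale all gradients by one shared factor, only if the norm is too large.
        scale = 1.0 if global_norm <= max_norm else max_norm / global_norm
        return [
            (None if grad is None else [g * scale for g in grad], var)
            for grad, var in grads_and_vars
        ]

    return transform


if __name__ == "__main__":
    clip = make_clip_by_global_norm(5.0)
    print(clip([([3.0, 4.0], "w"), ([0.0, 12.0], "b"), (None, "unused")]))
```

Scaling every gradient by the same factor, rather than clipping each one independently, keeps the direction of the combined update unchanged and only shortens its length.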
Returns:
  A Tensor that, when evaluated, computes and applies the gradients and
  returns the total loss value.
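Conceptually, each evaluation of the returned tensor performs one training step. The plain-Python sketch below walks through that sequence for a two-parameter linear model trained with plain gradient descent: compute the loss and its gradients, keep only the variables being trained, optionally pass the (gradient, variable) pairs through a transform, apply the update, increment the global step, and return the loss. Everything here (TrainState, train_step, the model, the learning rate, and FloatingPointError standing in for TensorFlow's numerics error) is hypothetical and for illustration only; it is not how TensorFlow implements the op. Returning the loss measured before the update, and aborting before any variable changes when the numerics check fails, are simplifications of this sketch.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

# A (gradient, variable name) pair; plain floats stand in for tensors.
GradsAndVars = List[Tuple[float, str]]


@dataclass
class TrainState:
    # Hypothetical stand-in for graph state: scalar variables and a global step.
    variables: Dict[str, float] = field(default_factory=lambda: {"w": 0.0, "b": 0.0})
    global_step: int = 0


def train_step(
    state: TrainState,
    data: List[Tuple[float, float]],
    learning_rate: float = 0.1,
    variables_to_train: Optional[List[str]] = None,
    transform_grads_fn: Optional[Callable[[GradsAndVars], GradsAndVars]] = None,
    check_numerics: bool = True,
) -> float:
    """Hypothetical sketch of one evaluation of the train op (not TF's code)."""
    w, b = state.variables["w"], state.variables["b"]
    residuals = [w * x + b - y for x, y in data]
    # Mean squared error of the model y_hat = w * x + b, before any update.
    loss = sum(r * r for r in residuals) / len(data)
    if check_numerics and not math.isfinite(loss):
        # In this sketch the step aborts before any variable is updated.
        raise FloatingPointError("loss is inf or nan")
    # Analytic gradients of the mean squared error.
    grads_and_vars: GradsAndVars = [
        (sum(2.0 * r * x for r, (x, _) in zip(residuals, data)) / len(data), "w"),
        (sum(2.0 * r for r in residuals) / len(data), "b"),
    ]
    # variables_to_train=None means "train every variable".
    trainable = set(state.variables if variables_to_train is None else variables_to_train)
    grads_and_vars = [(g, v) for g, v in grads_and_vars if v in trainable]
    if transform_grads_fn is not None:
        grads_and_vars = transform_grads_fn(grads_and_vars)
    # Apply plain gradient descent, then advance the global step.
    for g, v in grads_and_vars:
        state.variables[v] -= learning_rate * g
    state.global_step += 1
    return loss


if __name__ == "__main__":
    data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]
    state = TrainState()
    for _ in range(5):
        print(state.global_step, train_step(state, data))
```

With the data in the usage block (points on y = 2x + 1), the first call returns 21.0, the mean squared error at w = b = 0, and later calls return smaller values as the parameters move toward w = 2, b = 1.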
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.
Last updated 2020-10-01 UTC.