One training step.
tfp.experimental.nn.util.make_fit_op(
    loss_fn,
    optimizer,
    trainable_variables,
    grad_summary_fn=None,
    tf_function=True,
    xla_compile=True
)
Args |
loss_fn
|
Python callable which returns a pair: the loss (a tf.Tensor) and
a second result, other, such that
tf.nest.map_structure(tf.convert_to_tensor, other) will succeed.
|
optimizer
|
tf.optimizers.Optimizer-like instance which has members
gradient and apply_gradients.
|
trainable_variables
|
tf.nest.flatten-able structure of tf.Variable
instances.
|
grad_summary_fn
|
Python callable which takes a trainable_variables-like
structure of tf.Tensors representing the gradient of the result of
loss_fn with respect to trainable_variables. For example,
lambda grads: tf.nest.map_structure(
lambda x: 0. if x is None else tf.norm(x), grads).
Default value: None (i.e., no summarization is made).
|
tf_function
|
bool representing whether the resulting function should be
tf.function decorated.
Default value: True.
|
xla_compile
|
bool representing whether XLA compilation should be
performed. (This argument is ignored if the function is executed eagerly.)
Default value: True.
|
Returns |
fit_op
|
A Python callable whose args are forwarded to loss_fn and
which, when called, updates trainable_variables per the logic of
optimizer.apply_gradients.
|
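The contract described above (a factory that returns a callable, forwards its arguments to loss_fn, and updates the variables on each call) can be illustrated with a minimal pure-Python sketch. This is not the TensorFlow Probability implementation: make_fit_op_sketch, params, learning_rate, and eps are hypothetical names, a list of floats stands in for trainable_variables, plain gradient descent stands in for optimizer.apply_gradients, and central finite differences stand in for automatic differentiation. Returning the gradient summary as a third value is likewise an assumption of the sketch.

```python
def make_fit_op_sketch(loss_fn, params, learning_rate=0.1,
                       grad_summary_fn=None, eps=1e-6):
    """Return a callable that performs one gradient-descent step.

    Hypothetical stand-in: `params` is a mutable list of floats playing the
    role of trainable_variables, and plain gradient descent with
    `learning_rate` plays the role of optimizer.apply_gradients.
    """

    def fit_op(*args, **kwargs):
        # All arguments are forwarded to loss_fn, which reads `params`.
        loss, other = loss_fn(*args, **kwargs)

        # Central finite differences stand in for automatic differentiation.
        grads = []
        for i in range(len(params)):
            original = params[i]
            params[i] = original + eps
            loss_plus, _ = loss_fn(*args, **kwargs)
            params[i] = original - eps
            loss_minus, _ = loss_fn(*args, **kwargs)
            params[i] = original  # restore before the next coordinate
            grads.append((loss_plus - loss_minus) / (2.0 * eps))

        # Update the parameters in place.
        for i, grad in enumerate(grads):
            params[i] -= learning_rate * grad

        if grad_summary_fn is not None:
            # Assumption of this sketch: the summary is a third return value.
            return loss, other, grad_summary_fn(grads)
        return loss, other

    return fit_op


# Usage: fit y = w * x + b to data generated with w = 2, b = 1.
params = [0.0, 0.0]  # [w, b]


def loss_fn(xs, ys):
    w, b = params
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals) / len(residuals)
    # The second result is an arbitrary structure of number-like values.
    return loss, {"max_abs_residual": max(abs(r) for r in residuals)}


fit_op = make_fit_op_sketch(loss_fn, params, learning_rate=0.1)
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]

first_loss, _ = fit_op(xs, ys)
for _ in range(500):
    last_loss, other = fit_op(xs, ys)
```

As with the real fit_op, the loss returned by each call is the one computed before that call's update, so it lags the parameters by one step.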