Namedtuple of quantities that may be traced from tfp.math.minimize.
tfp.substrates.jax.math.MinimizeTraceableQuantities(
step,
loss,
gradients,
parameters,
has_converged,
convergence_criterion_state,
optimizer_state,
seed
)
These are (in order):
step: int Tensor index (starting from zero) of the current optimization step.
loss: float Tensor value returned from the user-provided loss_fn.
gradients: list of Tensor gradients of loss with respect to the parameters.
parameters: list of Tensor values of parameters being optimized. This corresponds to the trainable_variables passed to minimize, or the init passed to minimize_stateless.
has_converged: boolean Tensor of the same shape as the value returned by loss_fn, with True values corresponding to loss entries that have converged according to the user-provided convergence criterion. If no convergence criterion was specified, this is None.
convergence_criterion_state: structure of Tensors containing any auxiliary state (e.g., moving averages of loss or other quantities) maintained by the user-provided convergence criterion.
optimizer_state: structure of Tensors containing optional state from a user-provided pure optimizer.
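As a rough illustration of how these fields are typically consumed, the sketch below builds a stand-in namedtuple with the same field names and order as the signature above and passes it to a small tracing callback. The names FakeTraceableQuantities and trace_fn are hypothetical, plain Python numbers stand in for Tensors, and nothing here calls TFP itself; in real use the tuple would be constructed by the optimization loop rather than by hand.

```python
import collections

# Hypothetical stand-in mirroring the documented field order. This is NOT the
# TFP class; it only imitates its shape so the access pattern can be shown.
FakeTraceableQuantities = collections.namedtuple(
    "FakeTraceableQuantities",
    ["step", "loss", "gradients", "parameters", "has_converged",
     "convergence_criterion_state", "optimizer_state", "seed"])


def trace_fn(tq):
    """Selects which quantities to record for one optimization step."""
    return {"step": tq.step, "loss": tq.loss, "parameters": tq.parameters}


# One scalar parameter x with loss (x - 5)**2 and gradient 2 * (x - 5).
x = 1.0
tq = FakeTraceableQuantities(
    step=0,
    loss=(x - 5.0) ** 2,
    gradients=[2.0 * (x - 5.0)],
    parameters=[x],
    has_converged=None,  # no convergence criterion in this sketch
    convergence_criterion_state=None,
    optimizer_state=None,
    seed=None)

# Fields can be read by name inside the callback...
traced = trace_fn(tq)

# ...or unpacked positionally, following the documented order.
step, loss, *rest = tq
```

Accessing fields by name keeps a callback readable and independent of the tuple's length, whereas positional unpacking depends on the exact field order.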