tfp.substrates.numpy.math.minimize_stateless

Minimize a loss expressed as a pure function of its parameters.

Args

loss_fn Python callable with signature loss = loss_fn(*init, seed=None). The loss function may optionally take a seed keyword argument, used to specify a per-iteration seed for stochastic loss functions (a stateless Tensor seed will be passed; see tfp.random.sanitize_seed). A minimal sketch of a stochastic loss appears after this argument list.
init Tuple of Tensor initial parameter values (or nested structures of Tensor values) passed to the loss function.
num_steps Python int maximum number of steps to run the optimizer.
optimizer Pure functional optimizer to use. This may be an optax.GradientTransformation instance (in JAX), or any similar object that implements the methods optimizer_state = optimizer.init(parameters) and updates, optimizer_state = optimizer.update(grads, optimizer_state, parameters). A minimal sketch of this protocol appears after this argument list.
convergence_criterion Optional instance of tfp.optimizer.convergence_criteria.ConvergenceCriterion representing a criterion for detecting convergence. If None, the optimization will run for num_steps steps, otherwise, it will run for at most num_steps steps, as determined by the provided criterion. Default value: None.
batch_convergence_reduce_fn Python callable of signature has_converged = batch_convergence_reduce_fn(batch_has_converged) whose input is a Tensor of boolean values of the same shape as the loss returned by loss_fn, and output is a scalar boolean Tensor. This determines the behavior of batched optimization loops when loss_fn's return value is non-scalar. For example, tf.reduce_all will stop the optimization once all members of the batch have converged, tf.reduce_any once any member has converged, lambda x: tf.reduce_mean(tf.cast(x, tf.float32)) > 0.5 once more than half have converged, etc. Default value: tf.reduce_all.
trace_fn Python callable with signature traced_values = trace_fn( traceable_quantities), where the argument is an instance of tfp.math.MinimizeTraceableQuantities and the returned traced_values may be a Tensor or nested structure of Tensors. The traced values are stacked across steps and returned. The default trace_fn simply returns the loss. In general, trace functions may also examine the gradients, values of parameters, the state propagated by the specified convergence_criterion, if any (if no convergence criterion is specified, this will be None), as well as any other quantities captured in the closure of trace_fn, for example, statistics of a variational distribution. Default value: lambda traceable_quantities: traceable_quantities.loss.
return_full_length_trace Python bool indicating whether to return a trace of the full length num_steps, even if a convergence criterion stopped the optimization early, by tiling the value(s) traced at the final optimization step. This enables use in contexts such as XLA that require shapes to be known statically. Default value: True.
jit_compile If True, compiles the minimization loop using XLA. XLA performs compiler optimizations, such as fusion, and attempts to emit more efficient code. This may drastically improve performance. See the docs for tf.function. (In JAX, this will apply jax.jit). Default value: False.
seed PRNG seed for stochastic losses; see tfp.random.sanitize_seed. Default value: None.
name Python str name prefixed to ops created by this function. Default value: 'minimize_stateless'.
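
As a concrete illustration of the loss_fn contract, a stochastic loss simply accepts the extra seed keyword argument and uses it to drive its randomness. The noisy quadratic below is a minimal sketch (the name noisy_loss_fn and the Normal perturbation are illustrative, not part of the library), assuming the JAX backend used in the examples:

from tensorflow_probability.substrates import jax as tfp  # JAX backend assumed.
tfd = tfp.distributions

def noisy_loss_fn(x, seed=None):
  # `seed` is the per-step stateless seed supplied by the optimization loop;
  # here it perturbs an otherwise deterministic quadratic loss.
  noise = tfd.Normal(0., 1.).sample(seed=seed)
  return (x - 5.)**2 + 0.1 * noise

Passing a top-level seed to minimize_stateless (for example seed=42) then supplies a fresh per-iteration seed to noisy_loss_fn.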
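
Similarly, the optimizer argument is not limited to optax: any object exposing the init/update methods described above can be passed. The class below is a minimal sketch of that protocol implementing plain gradient descent (it is hypothetical, not part of the library, and assumes the JAX backend):

import jax

class GradientDescent:
  """Minimal pure-functional optimizer: state-free gradient descent."""

  def __init__(self, learning_rate):
    self.learning_rate = learning_rate

  def init(self, parameters):
    # Plain gradient descent carries no optimizer state.
    return ()

  def update(self, grads, optimizer_state, parameters):
    # Following the optax convention, updates are deltas to be added to the
    # parameters; here, the negative scaled gradients.
    updates = jax.tree_util.tree_map(
        lambda g: -self.learning_rate * g, grads)
    return updates, optimizer_state

An instance such as GradientDescent(0.1) can then be passed as the optimizer argument in place of, e.g., optax.sgd(0.1).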

Returns

final_parameters Tuple of final parameter values, with the same structure and Tensor shapes as init.
trace Tensor or nested structure of Tensors, according to the return type of trace_fn. Each Tensor has an added leading dimension stacking the trajectory of the traced values over the course of the optimization. The size of this dimension is equal to num_steps if a convergence criterion was not specified and/or return_full_length_trace=True, and otherwise it is equal to the number of optimization steps taken.

Examples

To minimize the scalar function (x - 5)**2:

import optax  # Assume JAX backend.
from tensorflow_probability.substrates import jax as tfp  # TFP on the JAX substrate.

loss_fn = lambda x: (x - 5.)**2
final_x, losses = tfp.math.minimize_stateless(
  loss_fn,
  init=0.,
  num_steps=100,
  optimizer=optax.adam(0.1))
print("optimized value is {} with loss {}".format(final_x, losses[-1]))

We can attempt to automatically detect convergence and stop the optimization by passing an instance of tfp.optimizer.convergence_criteria.ConvergenceCriterion. For example, to stop the optimization once a moving average of the per-step decrease in loss drops below 0.01:

_, losses = tfp.math.minimize_stateless(
  loss_fn,
  init=0.,
  num_steps=1000,
  optimizer=optax.adam(0.1),
  convergence_criterion=(
    tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))

Here num_steps=1000 defines an upper bound: the optimization will be stopped after 1000 steps even if no convergence is detected.
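
Because return_full_length_trace defaults to True, the losses returned above still have length 1000 even when the criterion stops the optimization earlier; the value traced at the final step is tiled out to num_steps. To trace only the steps actually taken, pass return_full_length_trace=False. A sketch using the same loss and criterion:

_, losses = tfp.math.minimize_stateless(
  loss_fn,
  init=0.,
  num_steps=1000,
  optimizer=optax.adam(0.1),
  convergence_criterion=(
    tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)),
  return_full_length_trace=False)
print(losses.shape)  # => number of steps actually taken (at most 1000).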

In some cases, we may want to track additional context inside the optimization. We can do this by defining a custom trace_fn. This accepts a tfp.math.MinimizeTraceableQuantities tuple and returns a structure of values to trace; these may include the loss, gradients, parameter values, or any auxiliary state maintained by the convergence criterion (if any).

trace_fn = lambda traceable_quantities: {
  'loss': traceable_quantities.loss,
  'x': traceable_quantities.parameters}
_, trace = tfp.math.minimize_stateless(
  loss_fn,
  init=0.,
  num_steps=100,
  optimizer=optax.adam(0.1),
  trace_fn=trace_fn)
print(trace['loss'].shape,   # => [100]
      trace['x'].shape)      # => [100]

When optimizing a batch of losses, some batch members will converge before others. The optimization will continue until the condition defined by the batch_convergence_reduce_fn becomes True. During these additional steps, converged elements will continue to be updated and may become unconverged. The convergence status of batch members can be diagnosed by tracing has_converged:

import numpy as np  # Used for the batched initialization and the analysis below.

batch_size = 10
trace_fn = lambda traceable_quantities: {
  'loss': traceable_quantities.loss,
  'has_converged': traceable_quantities.has_converged}
_, trace = tfp.math.minimize_stateless(
  loss_fn,
  init=np.zeros([batch_size], dtype=np.float32),
  num_steps=100,
  optimizer=optax.adam(0.1),
  trace_fn=trace_fn,
  convergence_criterion=(
    tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))

for i in range(batch_size):
  print('Batch element {} final state is {}converged.'
        ' It first converged at step {}.'.format(
            i, '' if trace['has_converged'][-1, i] else 'not ',
            np.argmax(trace['has_converged'][:, i])))