tfp.substrates.jax.optimizer.convergence_criteria.LossNotDecreasing

Simple convergence criterion based on lack of decrease in loss values.

Inherits From: ConvergenceCriterion

This rule tracks an exponentially-weighted moving average of the decrease in loss values between successive steps, and stops when that average drops below a threshold.

decrease_in_loss[t] = loss[t-1] - loss[t]
average_decrease_in_loss[t] = (
  (window_size - 1) * average_decrease_in_loss[t - 1] +
   decrease_in_loss[t]) / window_size
has_converged = (average_decrease_in_loss < threshold)
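
The recurrence above can be sketched in plain Python. This is only an illustration of the arithmetic, not the library's implementation; the function name `update_average_decrease`, the example loss trace, and the starting average of zero are assumptions made for the sketch.

```python
def update_average_decrease(prev_average, prev_loss, loss, window_size=10):
    """One step of the exponentially-weighted moving average of loss decrease."""
    decrease_in_loss = prev_loss - loss
    return ((window_size - 1) * prev_average + decrease_in_loss) / window_size


# Hypothetical loss trace that flattens out over time.
losses = [10.0, 6.0, 4.0, 3.0, 2.5, 2.4, 2.39, 2.39]

average = 0.0  # assumed starting value for this sketch
averages = []
for prev_loss, loss in zip(losses, losses[1:]):
    average = update_average_decrease(average, prev_loss, loss, window_size=4)
    averages.append(average)
```

As the loss plateaus, the per-step decrease shrinks and the running average follows it down with a lag that grows with window_size.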

The convergence threshold can be set directly as atol, or as a fraction of the average loss decrease across the first window_size steps of the optimization: threshold = rtol * average_decrease_in_loss[window_size]. If both atol and rtol are specified, the maximum of the two thresholds is used (equivalently, the optimization stops if either of the two conditions is met).
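
A minimal sketch of how the two tolerances might combine into one threshold, following the description above. The helper name `convergence_threshold` is hypothetical, and raising an error when neither tolerance is given is an assumption of this sketch rather than documented behavior.

```python
def convergence_threshold(average_initial_decrease, atol=None, rtol=None):
    """Combine absolute and relative tolerances into a single threshold."""
    candidates = []
    if atol is not None:
        candidates.append(atol)
    if rtol is not None:
        # Relative tolerance scales the average decrease over the first window.
        candidates.append(rtol * average_initial_decrease)
    if not candidates:
        # Assumed behavior for this sketch only.
        raise ValueError("Specify at least one of atol or rtol.")
    # Taking the maximum means the run stops if either condition is met.
    return max(candidates)


threshold = convergence_threshold(2.0, atol=0.01, rtol=0.1)
has_converged = 0.05 < threshold
```

Because the larger threshold wins, an average decrease that falls below either tolerance also falls below the combined threshold.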

The state propagated across training steps is state[t] = LossNotDecreasingState(loss[t], average_decrease_in_loss[t], average_decrease_in_loss[window_size]).

Args
atol float Tensor absolute tolerance. Convergence is assumed whenever (an exponentially-weighted moving average of) the decrease in loss values from one step to the next is less than atol. If both atol and rtol are specified, then convergence is assumed if either of the criteria is met.
rtol float Tensor relative tolerance. Convergence is assumed whenever (an exponentially-weighted moving average of) the decrease in loss values from one step to the next is less than rtol * average_initial_decrease_in_loss, where average_initial_decrease_in_loss is the exponentially-weighted moving average of the decrease in loss over the first window_size steps of the optimization. If both atol and rtol are specified, then convergence is assumed if either of the criteria is met.
window_size int Tensor effective window size for the moving average decrease in loss. The moving average is computed as moving_average[t] = decrease_in_loss[t] + decay * (moving_average[t-1] - decrease_in_loss[t]) where decay = 1. - 1. / window_size. Default value: 10.
min_num_steps int Tensor minimum number of steps before convergence. The criterion will not return has_converged=True until step >= min_num_steps. This should generally be a larger value than window_size. Default value: 20.
name optional Python str name prefixed to ops created by this class.
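
The window_size entry writes the moving average in terms of decay, while the overview writes it in terms of window_size directly. The following sketch checks numerically that the two forms agree; the function names and sample values are hypothetical and chosen only for illustration.

```python
import math


def ema_decay_form(prev_average, decrease_in_loss, window_size):
    """Moving average written with decay = 1 - 1 / window_size."""
    decay = 1.0 - 1.0 / window_size
    return decrease_in_loss + decay * (prev_average - decrease_in_loss)


def ema_window_form(prev_average, decrease_in_loss, window_size):
    """Moving average written as a weighted mean over window_size."""
    return ((window_size - 1) * prev_average + decrease_in_loss) / window_size


# Hypothetical (previous average, new decrease, window size) triples.
samples = [(1.0, 0.5, 10), (0.2, -0.1, 4), (3.0, 3.0, 1)]
forms_agree = all(
    math.isclose(ema_decay_form(a, d, w), ema_window_form(a, d, w),
                 rel_tol=1e-12, abs_tol=1e-12)
    for a, d, w in samples
)
```

Expanding the decay form gives (1 - decay) * decrease_in_loss + decay * moving_average[t-1], which is the weighted mean with weights 1 / window_size and (window_size - 1) / window_size.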

Attributes
atol
dtype
min_num_steps
name
rtol
window_size

Methods

bootstrap

Returns a structure of Tensors for the rule's state at step 0.

The shape of the Tensors specifying loss, grads, and parameters may optionally be prefixed by one or more batch dimension(s).

Args
loss float Tensor initial value of loss being optimized.
grads list of float Tensor gradients of loss wrt parameters.
parameters list of float Tensor initial values of parameters being optimized.

Returns
initial_auxiliary_state (Structure of) Tensor(s) representing the initial auxiliary state carried forward by this criterion.
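
As a rough illustration of what a step-0 state could look like, the sketch below builds a three-field structure mirroring the state described earlier. The type `SketchState`, its field names, and the choice to start both averages at zero are assumptions of this sketch, not the library's definitions.

```python
from collections import namedtuple

# Hypothetical stand-in for the criterion's auxiliary state.
SketchState = namedtuple(
    "SketchState",
    ["previous_loss", "average_decrease_in_loss",
     "average_initial_decrease_in_loss"],
)


def bootstrap_sketch(loss, grads=None, parameters=None):
    """Build the step-0 state; grads and parameters are unused here."""
    return SketchState(
        previous_loss=loss,
        average_decrease_in_loss=0.0,          # assumed initial value
        average_initial_decrease_in_loss=0.0,  # assumed initial value
    )


state0 = bootstrap_sketch(loss=12.5)
```

Only the loss feeds this criterion's state; grads and parameters appear in the signature because other criteria sharing the same interface may need them.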

one_step

Updates tracked quantities for a new step, and determines if converged.

The shape of the Tensors specifying loss, grads, and parameters may optionally be prefixed by one or more batch dimension(s). In this case, the returned value has_converged has the shape obtained by broadcasting the batch shapes of whichever of those quantities this criterion uses against the quantities that define the criterion (min_num_steps, etc.).

Args
step integer Tensor index of the current step, where step >= 1 (on step 0, bootstrap should be called instead).
loss float Tensor value of loss at the current step.
grads list of float Tensor gradients of loss wrt parameters.
parameters list of float Tensor current values of parameters being optimized.
auxiliary_state the (structure of) Tensor(s) containing state carried forward from the previous step.

Returns
has_converged boolean Tensor indicating whether the optimization has converged.
updated_auxiliary_state (Structure of) Tensor(s) representing updated quantities tracked by the convergence criterion. This should match the structure of the value returned by bootstrap.
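
Putting the pieces together, a scalar one_step might look roughly like the sketch below. It is a plain-Python approximation under stated assumptions, not the library's code: `SketchState`, `one_step_sketch`, and `loss_at` are hypothetical names, and freezing the reference average once step exceeds window_size is an assumed reading of "the first window_size steps".

```python
from collections import namedtuple

# Hypothetical stand-in for the criterion's auxiliary state.
SketchState = namedtuple(
    "SketchState",
    ["previous_loss", "average_decrease_in_loss",
     "average_initial_decrease_in_loss"],
)


def one_step_sketch(step, loss, state, atol=None, rtol=None,
                    window_size=10, min_num_steps=20):
    """Update the moving average and test for convergence (scalar sketch)."""
    decrease = state.previous_loss - loss
    average = ((window_size - 1) * state.average_decrease_in_loss
               + decrease) / window_size
    # Assumption: the reference average stops updating after the first window.
    initial = (average if step <= window_size
               else state.average_initial_decrease_in_loss)
    thresholds = []
    if atol is not None:
        thresholds.append(atol)
    if rtol is not None:
        thresholds.append(rtol * initial)
    if not thresholds:
        raise ValueError("Specify at least one of atol or rtol.")
    # Convergence is never reported before min_num_steps.
    has_converged = step >= min_num_steps and average < max(thresholds)
    return has_converged, SketchState(loss, average, initial)


def loss_at(t):
    """Hypothetical loss curve decaying geometrically toward 1.0."""
    return 1.0 + 0.5 ** t


state = SketchState(loss_at(0), 0.0, 0.0)  # assumed step-0 state
converged_at = None
for step in range(1, 200):
    has_converged, state = one_step_sketch(
        step, loss_at(step), state, atol=1e-3, window_size=5, min_num_steps=10)
    if has_converged:
        converged_at = step
        break
```

The returned state has the same three fields as the state passed in, so it can be fed straight back into the next call.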