tfp.substrates.jax.optimizer.convergence_criteria.SuccessiveGradientsAreUncorrelated

Convergence criterion based on inner products between successive gradients.

Inherits From: ConvergenceCriterion

Let g[t] be the gradient vector at step t, and g[t-1] the previous gradient. Their inner product:

grad_inner_product[t] = sum_i(g[t, i] * g[t - 1, i])

measures the correlation between successive optimization steps. We expect it to be positive while the optimization is making progress; conversely, it can be shown to be negative in expectation at the stationary distribution of constant-step-size SGD ([Pflug, 1990][2]).
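
To make the sign claim concrete, here is a small plain-Python simulation. It is an illustration, not TFP code. It runs constant-step-size SGD on the one-dimensional quadratic loss f(x) = x**2 / 2 with additive Gaussian gradient noise. The toy loss, the step size lr, the noise level sigma, and the helper name sgd_inner_products are all assumptions chosen for illustration. For this model a short calculation, given in the comments, yields the stationary expectation E[g[t] * g[t-1]] = -lr * sigma**2 / (2 - lr). That value is negative for any stable step size 0 < lr < 2.

```python
import random

# Toy model (assumption): g[t] = x[t] + noise[t], noise ~ N(0, sigma**2),
# and x[t+1] = x[t] - lr * g[t] = (1 - lr) * x[t] - lr * noise[t].
# At stationarity Var[x] = lr * sigma**2 / (2 - lr), so
# E[g[t] * g[t-1]] = (1 - lr) * Var[x] - lr * sigma**2 = -lr * sigma**2 / (2 - lr).


def sgd_inner_products(x0, lr, sigma, num_steps, seed=0):
    """Runs constant-step-size SGD on f(x) = x**2 / 2 with noisy gradients.

    Returns the successive-gradient inner products g[t] * g[t-1]
    (plain scalars here, because x is one-dimensional).
    """
    rng = random.Random(seed)
    x = x0
    prev_g = None
    products = []
    for _ in range(num_steps):
        g = x + rng.gauss(0.0, sigma)  # true gradient x plus Gaussian noise
        if prev_g is not None:
            products.append(g * prev_g)
        prev_g = g
        x -= lr * g  # constant step size
    return products


lr, sigma = 0.1, 1.0

# Far from the optimum: successive gradients point the same way.
early = sgd_inner_products(x0=100.0, lr=lr, sigma=sigma, num_steps=10)

# Near stationarity: drop a burn-in period, then average the inner products.
late = sgd_inner_products(x0=0.0, lr=lr, sigma=sigma, num_steps=200_000)[1_000:]
empirical = sum(late) / len(late)
predicted = -lr * sigma**2 / (2 - lr)

print(min(early) > 0)
print(empirical, predicted)
```

The early inner products are large and positive. The long-run average settles near the small negative predicted value, and that negative drift is what the criterion relies on.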

This criterion detects convergence when an exponentially weighted moving average of grad_inner_product becomes negative. Intuitively, this happens when the gradients over roughly the most recent window_size steps have shown no consistent direction.
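
The detection rule can be sketched in a few lines of plain Python. This is a sketch under assumptions, not TFP's implementation. It applies the moving-average update documented for window_size, uses the documented defaults window_size=10 and min_num_steps=20, and reports convergence only once step >= min_num_steps. Starting the moving average at 0.0 is an assumption, and the function name first_convergence_step is hypothetical.

```python
def first_convergence_step(inner_products, window_size=10, min_num_steps=20):
    """Returns the first step at which the criterion would fire, or None.

    inner_products[t - 1] holds grad_inner_product[t] for steps t = 1, 2, ...
    The moving average starts at 0.0 (an assumption of this sketch).
    """
    decay = 1.0 - 1.0 / window_size
    moving_average = 0.0
    for step, ip in enumerate(inner_products, start=1):
        # moving_average[t] = ip[t] + decay * (moving_average[t-1] - ip[t])
        moving_average = ip + decay * (moving_average - ip)
        if step >= min_num_steps and moving_average < 0.0:
            return step
    return None


# 25 "progress" steps, then steps whose gradients flip direction each time.
print(first_convergence_step([1.0] * 25 + [-1.0] * 10))  # 32
# Negative from the start, but held back until min_num_steps.
print(first_convergence_step([-1.0] * 30))                # 20
# Consistently correlated gradients never trigger the rule.
print(first_convergence_step([1.0] * 50))                 # None
```

After a long run of positive inner products, the average needs several negative ones before it crosses zero. That lag is the smoothing that window_size controls.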

Theoretical analysis shows that with no decay (window_size=np.inf), this rule stops in finite time almost surely, for constant-step-size SGD under standard assumptions ([Pflug, 1990][2]; [Chee and Toulis, 2017][1]). In practice, it is often more efficient to use a decaying moving average.
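
As an illustration of the decaying variant in action, the following plain-Python sketch (not TFP code) runs constant-step-size SGD on the one-dimensional quadratic f(x) = x**2 / 2 with Gaussian gradient noise. It starts far from the optimum and stops at the first step t >= min_num_steps at which the decayed moving average of g[t] * g[t-1] is negative. The toy loss, noise model, step size, max_steps cap, zero initial moving average, and the helper name sgd_until_uncorrelated are assumptions for illustration.

```python
import random


def sgd_until_uncorrelated(x0, lr=0.1, sigma=1.0, window_size=10,
                           min_num_steps=20, max_steps=5_000, seed=0):
    """Constant-step SGD on f(x) = x**2 / 2 with Gaussian gradient noise.

    Stops at the first step t >= min_num_steps at which the exponentially
    weighted moving average of g[t] * g[t-1] is negative. Returns
    (stopping_step, x), or (None, x) if max_steps is reached first.
    """
    rng = random.Random(seed)
    decay = 1.0 - 1.0 / window_size
    x = x0
    prev_g = x + rng.gauss(0.0, sigma)  # gradient at step 0
    x -= lr * prev_g
    moving_average = 0.0  # assumed initial value
    for step in range(1, max_steps + 1):
        g = x + rng.gauss(0.0, sigma)
        inner_product = g * prev_g
        moving_average = inner_product + decay * (moving_average - inner_product)
        if step >= min_num_steps and moving_average < 0.0:
            return step, x
        prev_g = g
        x -= lr * g
    return None, x


step, x = sgd_until_uncorrelated(x0=100.0)
print(step, x)  # stops once x has drifted into the noise-dominated region
```

Because the moving average forgets old inner products at rate decay, a larger window_size typically smooths out noise-driven sign flips at the cost of reacting more slowly once progress stalls.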

Batch semantics: because this criterion does not depend on the loss, vector-valued losses will not produce vector-valued convergence indicators. Instead, the returned has_converged is always scalar, and is computed from the inner product summed across gradients from all variables being optimized.
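
The reduction described above can be sketched in plain Python. Gradients for every optimized variable contribute to one sum, so the result is a single number no matter how many variables there are. Representing each gradient Tensor as a flat list of floats is a simplification, and the function name grad_inner_product is hypothetical.

```python
def grad_inner_product(grads, previous_grads):
    """Sum of elementwise products over all variables' gradients.

    grads and previous_grads hold one entry per optimized variable; each
    entry is a flat list of floats standing in for a gradient Tensor.
    """
    return sum(
        g_i * prev_i
        for g, prev in zip(grads, previous_grads)
        for g_i, prev_i in zip(g, prev)
    )


# Two variables, e.g. a weight vector of length 2 and a scalar bias.
grads = [[1.0, 2.0], [3.0]]
previous_grads = [[0.5, -1.0], [2.0]]
print(grad_inner_product(grads, previous_grads))  # 0.5 - 2.0 + 6.0 = 4.5
```

With JAX arrays, each per-variable term would typically be computed with jax.numpy.vdot, which flattens its inputs. That substitution is an assumption of this sketch rather than a description of TFP's internals.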

References

[1] Jerry Chee and Panos Toulis. Convergence diagnostics for stochastic gradient descent with constant step size. arXiv preprint arXiv:1710.06382, 2017. https://arxiv.org/abs/1710.06382

[2] Georg Ch. Pflug. Non-asymptotic confidence bounds for stochastic approximation algorithms with constant step size. Monatshefte für Mathematik, 110(3-4), pp. 297-314, 1990.

Args
window_size int Tensor effective window size for the moving average. The moving average inner product is computed as moving_average[t] = grad_inner_product[t] + decay * (moving_average[t - 1] - grad_inner_product[t]) where decay = 1. - 1. / window_size. The non-decaying (decay = 1.) setting can therefore be recovered by passing window_size=np.inf. Default value: 10.
min_num_steps int Tensor minimum number of steps before convergence. The criterion will not return has_converged=True until step >= min_num_steps. This should generally be a larger value than window_size. Default value: 20.
name optional Python str name prefixed to ops created by this class.
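
The relationship between window_size and decay can be checked numerically with the plain-Python sketch below. The helper name ema_update is hypothetical. The sketch applies the documented update once with a finite window_size and compares the result with the algebraically equivalent form decay * moving_average[t - 1] + (1 - decay) * grad_inner_product[t]. It also shows why window_size acts as an effective window. An inner product observed k steps ago carries weight (1 - decay) * decay**k. After window_size steps, that weight has shrunk by a factor of decay**window_size, which approaches exp(-1) as window_size grows.

```python
import math


def ema_update(moving_average, inner_product, window_size):
    """One step of the documented update, with decay = 1 - 1 / window_size."""
    decay = 1.0 - 1.0 / window_size
    return inner_product + decay * (moving_average - inner_product)


window_size = 10
decay = 1.0 - 1.0 / window_size  # 0.9 for the default window_size

# Same value as the "weighted average" form of the update.
a = ema_update(2.0, -3.0, window_size)
b = decay * 2.0 + (1.0 - decay) * (-3.0)
print(a, math.isclose(a, b))

# How much an old inner product's weight shrinks after window_size steps.
print(decay ** window_size, math.exp(-1))
```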

Attributes

min_num_steps

name

window_size

Methods

bootstrap

Returns a structure of Tensors for the rule's state at step 0.

The shape of the Tensors specifying loss, grads, and parameters may optionally be prefixed by one or more batch dimension(s).

Args
loss float Tensor initial value of loss being optimized.
grads list of float Tensor gradients of loss with respect to parameters.
parameters list of float Tensor initial values of parameters being optimized.

Returns
initial_auxiliary_state (Structure of) Tensor(s) representing the initial auxiliary state carried forward by this criterion.
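
As a rough illustration of what this criterion must carry between steps, here is a plain-Python stand-in for bootstrap. It has to remember the current gradients, because the next inner product needs them, and it needs a starting value for the moving average. The dictionary keys and the zero starting value are assumptions of this sketch. The structure TFP actually returns may differ.

```python
def bootstrap(loss, grads, parameters):
    """Hypothetical stand-in for bootstrap: returns the step-0 state.

    loss and parameters are accepted to mirror the documented signature but
    are unused, since this criterion depends only on gradients.
    """
    del loss, parameters  # unused by a gradient-only criterion
    return {
        "previous_grads": [list(g) for g in grads],  # copy of g[0]
        "moving_average": 0.0,                       # assumed start value
    }


state = bootstrap(loss=1.25,
                  grads=[[0.5, -1.0], [2.0]],
                  parameters=[[0.0, 0.0], [0.0]])
print(state)
```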

one_step

Updates tracked quantities for a new step, and determines if converged.

The shape of the Tensors specifying loss, grads, and parameters may optionally be prefixed by one or more batch dimension(s). In this case, the returned value has_converged will have shape equal to the broadcast batch shape of whichever of those quantities is used by this convergence criterion and of the quantities defining the convergence criterion (min_num_steps, etc.).

Args
step integer Tensor index of the current step, where step >= 1 (on step 0, bootstrap should be called instead).
loss float Tensor value of loss at the current step.
grads list of float Tensor gradients of loss with respect to parameters.
parameters list of float Tensor current values of parameters being optimized.
auxiliary_state the (structure of) Tensor(s) containing state carried forward from the previous step.

Returns
has_converged boolean Tensor indicating whether the optimization has converged.
updated_auxiliary_state (Structure of) Tensor(s) representing updated quantities tracked by the convergence criterion. This should match the structure of the value returned by bootstrap.
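
The contract above can be sketched with a plain-Python stand-in for one_step that works on plain-Python data. It forms the inner product with the stored previous gradients, updates the moving average, and gates on min_num_steps. It then returns a state with the same keys it received, matching the structure a bootstrap-style initializer would produce. For brevity this sketch passes window_size and min_num_steps as keyword arguments, whereas the documented class takes them in its constructor. The dictionary keys are likewise assumptions.

```python
def one_step(step, loss, grads, parameters, auxiliary_state,
             window_size=10, min_num_steps=20):
    """Hypothetical stand-in for one_step on plain-Python data.

    Returns (has_converged, updated_auxiliary_state); the updated state has
    the same keys as auxiliary_state.
    """
    del loss, parameters  # unused by a gradient-only criterion
    if step < 1:
        raise ValueError("one_step expects step >= 1; use bootstrap at step 0")
    # Inner product summed across all variables' gradients.
    inner_product = sum(
        g_i * p_i
        for g, p in zip(grads, auxiliary_state["previous_grads"])
        for g_i, p_i in zip(g, p)
    )
    decay = 1.0 - 1.0 / window_size
    prev_avg = auxiliary_state["moving_average"]
    moving_average = inner_product + decay * (prev_avg - inner_product)
    has_converged = step >= min_num_steps and moving_average < 0.0
    updated_state = {
        "previous_grads": [list(g) for g in grads],  # becomes g[t-1] next step
        "moving_average": moving_average,
    }
    return has_converged, updated_state


state = {"previous_grads": [[1.0, 0.0]], "moving_average": -0.5}
converged, new_state = one_step(step=25, loss=0.0, grads=[[-1.0, 0.0]],
                                parameters=[[0.0, 0.0]], auxiliary_state=state)
print(converged, new_state)  # True; moving_average is about -0.55
early_converged, _ = one_step(step=5, loss=0.0, grads=[[-1.0, 0.0]],
                              parameters=[[0.0, 0.0]], auxiliary_state=state)
print(early_converged)  # False: step < min_num_steps
```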