

Computes a modified Cholesky decomposition for a batch of square matrices.

Given a symmetric matrix A, this function attempts to give a factorization A + E = LL^T where L is lower triangular, LL^T is positive definite, and E is small in some suitable sense. This is useful for nearly positive definite symmetric matrices that are otherwise numerically difficult to Cholesky factor.

In particular, this function first attempts a Cholesky decomposition of the input matrix. If that decomposition fails, exponentially increasing jitter is added to the diagonal of the matrix until either a Cholesky decomposition succeeds or the maximum specified number of iterations is reached.
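The retry loop described above can be sketched in plain NumPy for a single (unbatched) matrix. This is a minimal illustration under stated assumptions, not this function's implementation: the name `retrying_cholesky_sketch`, the `growth` parameter, and the factor-of-10 schedule are hypothetical, since the text only says the jitter increases exponentially.

```python
import numpy as np


def retrying_cholesky_sketch(matrix, jitter=1e-10, max_iters=5, growth=10.0):
    """Hypothetical single-matrix sketch of a retrying Cholesky.

    Returns (factor, shift), where `shift` is the last diagonal value tried.
    On failure, the lower triangle of `factor` is filled with NaN.
    """
    matrix = np.asarray(matrix, dtype=np.float64)
    n = matrix.shape[-1]
    shift = 0.0
    for attempt in range(max_iters + 1):
        # Attempt 0 uses the unmodified matrix; later attempts add
        # jitter, jitter * growth, jitter * growth**2, ... to the diagonal.
        # The growth factor of 10 is an assumption made for this sketch.
        shift = 0.0 if attempt == 0 else jitter * growth ** (attempt - 1)
        try:
            factor = np.linalg.cholesky(matrix + shift * np.eye(n))
            return factor, shift
        except np.linalg.LinAlgError:
            # NumPy signals a failed factorization by raising; retry.
            continue
    # Every attempt failed: mark the lower triangle with NaN.
    return np.tril(np.full((n, n), np.nan)), shift
```

For a symmetric matrix with a slightly negative eigenvalue, a loop like this returns a factor `L` with `L @ L.T` equal to the input plus `shift` times the identity; for a strongly indefinite matrix it exhausts its retries and returns NaN entries.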

This function is similar in spirit to a true modified Cholesky factorization ([1], [2]). However, it does not use pivoting or other strategies to ensure stability, so it may not work well for, e.g., ill-conditioned matrices. Further, this function may perform multiple Cholesky factorizations, while a true modified Cholesky factorization can be done with only slightly more work than a single decomposition.


[1]: Nicholas Higham. What is a modified Cholesky factorization?

[2]: Sheung Hun Cheng and Nicholas Higham, A Modified Cholesky Algorithm Based on a Symmetric Indefinite Factorization, SIAM J. Matrix Anal. Appl. 19(4), 1097–1110, 1998.

matrix A batch of symmetric square matrices, with shape [..., n, n].
jitter Initial jitter to add to the diagonal. Default: 1e-6, unless matrix.dtype is float64, in which case the default is 1e-10.
max_iters Maximum number of times to retry the Cholesky decomposition with larger diagonal jitter. Default: 5.
name Python str name prefixed to Ops created by this function. Default value: 'retrying_cholesky'.

triangular_factor A Tensor with shape [..., n, n]. The lower triangular Cholesky factor, modified as above. If the Cholesky decomposition failed for a batch member, then all lower triangular entries returned for that batch member will be NaN.
diagonal_shift A Tensor with shape [...]. diagonal_shift[i] is the value added to the diagonal of matrix[i] in computing triangular_factor[i].
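Because failed batch members are reported through NaN entries rather than an exception, callers may want to check the returned factor before using it. The following NumPy sketch shows one way to do that; the `triangular_factor` array here is made-up illustrative data with the documented shape [..., n, n], not output produced by this function.

```python
import numpy as np

# Illustrative stand-in for a returned factor with batch shape [2] and n = 2.
# The second batch member mimics a failed decomposition (NaN lower triangle).
triangular_factor = np.array([
    [[1.0, 0.0], [0.5, 1.0]],
    [[np.nan, 0.0], [np.nan, np.nan]],
])

# A batch member failed if any entry of its factor is NaN.
failed = np.isnan(triangular_factor).any(axis=(-2, -1))

# Indices of batch members whose factors are safe to use downstream.
succeeded_indices = np.flatnonzero(~failed)
```

Reducing over the last two axes keeps the batch shape, so the resulting mask lines up elementwise with diagonal_shift.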