tfp.substrates.jax.math.low_rank_cholesky

Computes a low-rank approximation to the Cholesky decomposition.

This routine is similar to pivoted_cholesky, but works under JAX, at the cost of being slightly less numerically stable.
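The following is a minimal sketch of how such a routine can be written for JAX. It assumes the usual greedy pivoted partial Cholesky scheme and is not TFP's implementation. The name greedy_low_rank_cholesky is hypothetical. The sketch always allocates an (m, max_rank) factor, so the loop has a static shape and fits jax.lax.fori_loop. It handles a single unbatched matrix and omits the trace tolerances.

```python
import jax
import jax.numpy as jnp


def greedy_low_rank_cholesky(matrix, max_rank):
    """Hypothetical sketch of a greedy pivoted partial Cholesky (not TFP's code).

    Returns (lr, residual_diag), where lr has the fixed shape (m, max_rank)
    and residual_diag is the diagonal of matrix - lr @ lr.T.
    """
    m = matrix.shape[-1]
    lr0 = jnp.zeros((m, max_rank), matrix.dtype)
    diag0 = jnp.diagonal(matrix)  # the residual diagonal starts as diag(matrix)

    def body(i, state):
        lr, residual_diag = state
        # Pivot on the largest remaining residual diagonal entry.
        p = jnp.argmax(residual_diag)
        # Column p of the current residual, matrix[:, p] - lr @ lr[p, :].
        # Columns i and later of lr are still zero, so they add nothing.
        col = matrix[:, p] - lr @ lr[p, :]
        new_col = col / jnp.sqrt(residual_diag[p])
        lr = lr.at[:, i].set(new_col)
        # Subtract the new rank-1 term's diagonal; clip rounding noise below 0.
        residual_diag = jnp.maximum(residual_diag - new_col**2, 0.0)
        return lr, residual_diag

    return jax.lax.fori_loop(0, max_rank, body, (lr0, diag0))


key = jax.random.PRNGKey(0)
bm = jax.random.normal(key, (6, 6))
a = bm @ bm.T + 6.0 * jnp.eye(6)  # symmetric positive definite test matrix

lr_full, d_full = greedy_low_rank_cholesky(a, 6)
# max_rank fixes the output shape, so it must be static under jit.
lr3, d3 = jax.jit(greedy_low_rank_cholesky, static_argnums=1)(a, 3)
```

Dividing by the square root of a small pivot is where this greedy approach can lose accuracy. This is one plausible reading of "slightly less numerically stable", though the exact cause in TFP is not stated here.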

Args:
matrix: Floating-point Tensor batch of symmetric, positive definite matrices, or a tf.linalg.LinearOperator.
max_rank: Scalar int Tensor, the rank at which to truncate the approximation.
trace_atol: Scalar floating-point Tensor (same dtype as matrix). If trace_atol > 0 and trace(matrix - LR * LR^t) < trace_atol, the output LR matrix is allowed to be of rank less than max_rank.
trace_rtol: Scalar floating-point Tensor (same dtype as matrix). If trace_rtol > 0 and trace(matrix - LR * LR^t) < trace_rtol * trace(matrix), the output LR matrix is allowed to be of rank less than max_rank.
name: Optional name for the op.
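Both tolerances compare against trace(matrix - LR * LR^t), which equals the sum of the residual's diagonal entries. The check is therefore cheap. The hypothetical helper tolerance_met below is not a TFP function. It sketches the rule exactly as the two argument descriptions state it and is followed by a worked example on a diagonal matrix.

```python
import jax.numpy as jnp


def tolerance_met(residual_diag, orig_trace, trace_atol=0.0, trace_rtol=0.0):
    """Hypothetical helper mirroring the trace_atol / trace_rtol rules."""
    # trace(matrix - LR @ LR^T) is the sum of the residual diagonal, so the
    # full m x m residual never has to be formed.
    residual_trace = jnp.sum(residual_diag, axis=-1)
    atol_ok = (jnp.asarray(trace_atol) > 0) & (residual_trace < trace_atol)
    rtol_ok = (jnp.asarray(trace_rtol) > 0) & (
        residual_trace < trace_rtol * orig_trace)
    return atol_ok | rtol_ok


# Worked example: a diagonal matrix after one rank-1 step that removes
# its largest entry.
matrix = jnp.diag(jnp.array([10.0, 1.0, 0.1, 0.01]))
lr = jnp.zeros((4, 1)).at[0, 0].set(jnp.sqrt(10.0))
residual_diag = jnp.diagonal(matrix - lr @ lr.T)  # approximately [0, 1, 0.1, 0.01]
orig_trace = jnp.trace(matrix)                     # approximately 11.11

met_rtol = tolerance_met(residual_diag, orig_trace, trace_rtol=0.5)  # True: 1.11 < 5.555
met_atol = tolerance_met(residual_diag, orig_trace, trace_atol=0.2)  # False: 1.11 >= 0.2
```

With both tolerances left at 0, neither condition can hold. In that case the approximation is not allowed to stop before max_rank.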

Returns:
A triplet (LR, r, residual_diag):
LR: a matrix such that LR * LR^t is approximately the input matrix. If matrix has shape (b1, ..., bn, m, m), then LR has shape (b1, ..., bn, m, r), where r <= max_rank.
r: the rank of LR. If r < max_rank, then at least one of the tolerance conditions described for trace_atol and trace_rtol was met: trace(matrix - LR * LR^t) < trace_atol (with trace_atol > 0), or trace(matrix - LR * LR^t) < trace_rtol * trace(matrix) (with trace_rtol > 0).
residual_diag: the diagonal entries of matrix - LR * LR^t. This is returned because, together with LR, it is useful for preconditioning the input matrix.
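The docstring does not say which preconditioner it has in mind. One natural assumption is P = LR * LR^t + diag(residual_diag). By construction, this P matches the input matrix exactly on its diagonal. The sketch below builds that P and passes P^{-1} to jax.scipy.sparse.linalg.cg as the preconditioner M. The helper name low_rank_plus_diag_preconditioner is hypothetical. To keep the example self-contained, LR is a stand-in taken from a truncated eigendecomposition rather than the output of low_rank_cholesky.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.sparse.linalg import cg


def low_rank_plus_diag_preconditioner(lr, residual_diag):
    """Hypothetical helper: P = lr @ lr.T + diag(residual_diag).

    diag(P) equals diag(matrix) because residual_diag is the diagonal of
    matrix - lr @ lr.T. Returns P and a function that applies P^{-1}.
    """
    p = lr @ lr.T + jnp.diag(residual_diag)
    factor = cho_factor(p, lower=True)  # assumes P is positive definite
    return p, lambda v: cho_solve(factor, v)


# Example symmetric positive definite system.
b = jax.random.normal(jax.random.PRNGKey(0), (8, 8))
a = b @ b.T + 8.0 * jnp.eye(8)
rhs = jnp.ones(8)

# Stand-in for low_rank_cholesky's output: a rank-3 factor built from the top
# eigenpairs (NOT a pivoted Cholesky factor), plus its residual diagonal.
w, v = jnp.linalg.eigh(a)  # eigenvalues in ascending order
lr = v[:, -3:] * jnp.sqrt(w[-3:])
residual_diag = jnp.diagonal(a - lr @ lr.T)

p, apply_p_inv = low_rank_plus_diag_preconditioner(lr, residual_diag)
x, _ = cg(a, rhs, M=apply_p_inv)  # preconditioned conjugate gradient solve
```

P is factored once and reused for every application of P^{-1} inside the solver. This is the usual reason to pair a cheap low-rank factor with a diagonal correction.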