tfp.substrates.jax.math.hpsd_solve

Computes matrix^-1 rhs, where matrix is HPSD.

Given matrix and rhs, computes matrix^-1 rhs, i.e. the solution x of matrix x = rhs, where matrix is a Hermitian positive semi-definite (HPSD) matrix.
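For an HPSD matrix, the standard way to compute matrix^-1 rhs is to factor matrix = L L^H (Cholesky) and then do two triangular solves, never forming the inverse. The following is a minimal sketch of that technique in plain JAX, not the library's own code. The function name hpsd_solve_sketch is hypothetical, and the sketch assumes matrix is strictly positive definite (a singular matrix has no usable Cholesky factor) and that matrix and rhs share the same batch dimensions.

```python
import jax
import jax.numpy as jnp
from jax.lax.linalg import triangular_solve


def hpsd_solve_sketch(matrix, rhs, cholesky_matrix=None):
    """Hypothetical sketch: returns x with matrix @ x == rhs, via Cholesky.

    Assumes `matrix` is strictly positive definite so the factorization exists.
    Shapes: matrix [..., N, N], rhs [..., N, K] (same batch dims), result [..., N, K].
    """
    if cholesky_matrix is None:
        # Lower-triangular L with L @ L^H == matrix, batched over leading dims.
        cholesky_matrix = jnp.linalg.cholesky(matrix)
    # Forward substitution: solve L @ y = rhs for y.
    y = triangular_solve(cholesky_matrix, rhs, left_side=True, lower=True)
    # Back substitution: solve L^H @ x = y for x (conjugate transpose of L).
    return triangular_solve(cholesky_matrix, y, left_side=True, lower=True,
                            transpose_a=True, conjugate_a=True)


# Usage: a batch of two well-conditioned 4x4 symmetric positive definite matrices.
b = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4))
matrix = b @ jnp.swapaxes(b, -1, -2) + 4.0 * jnp.eye(4)
rhs = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 3))
x = hpsd_solve_sketch(matrix, rhs)
print(x.shape)
```

Two triangular solves cost O(N^2 K) once the O(N^3) factorization is done, and they are more numerically stable than multiplying by an explicit inverse.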

Args:
  matrix: Floating-point Tensor of shape [..., N, N] representing a Hermitian positive semi-definite matrix.
  rhs: Floating-point Tensor of shape [..., N, K].
  cholesky_matrix: (Optional) Floating-point Tensor of shape [..., N, N] representing a Cholesky factor of matrix.

Returns:
  hpsd_solve: Tensor of shape [..., N, K].
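When the same matrix must be solved against several right-hand sides, factoring it once and reusing the factor avoids repeating the O(N^3) Cholesky decomposition. Presumably this is what the optional cholesky_matrix argument enables, though that is an inference, not something stated above. The sketch below illustrates the idea with jax.scipy.linalg.cho_factor and cho_solve on a single unbatched matrix rather than with hpsd_solve itself; the matrix and right-hand sides are made-up values.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Made-up 3x3 symmetric positive definite matrix (diagonally dominant, so PD).
matrix = jnp.array([[4.0, 1.0, 0.5],
                    [1.0, 3.0, 0.2],
                    [0.5, 0.2, 2.0]])

# Factor once: returns (factor, lower) for reuse by cho_solve.
c_and_lower = cho_factor(matrix, lower=True)

# Two different right-hand sides of shape [N, K] with N = 3.
rhs_a = jnp.array([[1.0], [2.0], [3.0]])                    # K = 1
rhs_b = jnp.array([[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]])     # K = 2

# Each solve reuses the factor; only triangular solves are performed.
x_a = cho_solve(c_and_lower, rhs_a)
x_b = cho_solve(c_and_lower, rhs_b)
print(x_a.shape, x_b.shape)
```

Each result has the same [N, K] shape as its right-hand side, matching the [..., N, K] return shape described above.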