tfp.substrates.jax.math.lbeta

Returns log(Beta(x, y)).

This is mathematically equivalent to lgamma(x) + lgamma(y) - lgamma(x + y), but this method is more accurate for arguments above 8.

The naive computation loses accuracy because of catastrophic cancellation between the lgamma terms: each term is large, while their combination can be much smaller. This method avoids the cancellation by explicitly decomposing each lgamma into its Stirling approximation plus an explicit log_gamma_correction, and then cancelling the large Stirling terms analytically (see [1], Section IV).
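The analytic cancellation can be sketched in plain Python. This is a minimal, hypothetical illustration, not TFP's implementation: the helper names (`stirling_remainder`, `lbeta_naive`, `lbeta_large`, `lbeta_exact_int_x`) are made up here, and `stirling_remainder` uses the truncated Stirling series 1/(12z) - 1/(360z^3) + 1/(1260z^5) - 1/(1680z^7) as an assumed stand-in for log_gamma_correction. With a = min(x, y) and b = max(x, y), the Stirling parts of the three lgammas simplify by hand to 0.5 log(2π) - 0.5 log(b) + (a - 0.5) log(a / (a + b)) - b log1p(a / b), so no two large numbers are subtracted at run time.

```python
import math

# Coefficients of the Stirling series for the remainder
#   lgamma(z) - [(z - 0.5) * log(z) - z + 0.5 * log(2 * pi)]
#   ~ 1/(12 z) - 1/(360 z^3) + 1/(1260 z^5) - 1/(1680 z^7) + ...
# Truncating here leaves an error below about 1e-11 for z >= 8.
_STIRLING_COEFFS = (1.0 / 12.0, -1.0 / 360.0, 1.0 / 1260.0, -1.0 / 1680.0)


def stirling_remainder(z):
    """Hypothetical stand-in for log_gamma_correction, valid for z >= 8."""
    inv_z = 1.0 / z
    inv_z2 = inv_z * inv_z
    total = 0.0
    # Horner's rule in 1/z^2, then one final factor of 1/z.
    for c in reversed(_STIRLING_COEFFS):
        total = total * inv_z2 + c
    return total * inv_z


def lbeta_naive(x, y):
    """Direct formula; suffers cancellation when lgamma(x + y) is large."""
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)


def lbeta_large(x, y):
    """Sketch of log(Beta(x, y)) for x, y >= 8, large terms cancelled analytically."""
    a, b = min(x, y), max(x, y)
    if a < 8.0:
        raise ValueError("this sketch only covers x, y >= 8")
    # Stirling parts of lgamma(a) + lgamma(b) - lgamma(a + b), simplified by hand:
    # the -a - b + (a + b) terms vanish and the large multiples of log(b)
    # collapse to -0.5 * log(b).
    stirling = (0.5 * math.log(2.0 * math.pi)
                - 0.5 * math.log(b)
                + (a - 0.5) * math.log(a / (a + b))
                - b * math.log1p(a / b))
    # The remainders are small, so their difference loses little precision.
    correction = (stirling_remainder(a) + stirling_remainder(b)
                  - stirling_remainder(a + b))
    return stirling + correction


def lbeta_exact_int_x(n, y):
    """Reference for integer n: Beta(n, y) = (n - 1)! / (y (y + 1) ... (y + n - 1))."""
    return math.lgamma(n) - sum(math.log(y + k) for k in range(n))


# Compare both formulas against the integer-argument reference.
x, y = 10.0, 1e10
ref = lbeta_exact_int_x(10, y)
print("naive error: ", abs(lbeta_naive(x, y) - ref))
print("sketch error:", abs(lbeta_large(x, y) - ref))
```

At x = 10, y = 1e10 the naive float64 result is computed as the difference of two lgamma values near 2.2e11, so its absolute error is on the order of their rounding error, while the sketch only combines moderate-sized terms.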

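The computed gradients are the same as those of the naive forward computation, namely d/dx = digamma(x) - digamma(x + y) and d/dy = digamma(y) - digamma(x + y). This is deliberate: (i) digamma grows much more slowly than lgamma, so the cancellation in the gradient is far less severe, and (ii) the naive gradient is simpler and faster than a more accurate alternative.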
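In JAX, pairing a stable forward pass with the naive gradient maps naturally onto `jax.custom_jvp`. The sketch below is an assumption about how one might do this, not how TFP implements lbeta: `lbeta_sketch` and `_correction` are hypothetical, the forward pass only covers x, y >= 8, and the tangent rule applies the naive derivative digamma(x) - digamma(x + y) (and symmetrically for y).

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln


def _correction(z):
    # Truncated Stirling remainder 1/(12z) - 1/(360z^3) + 1/(1260z^5) - 1/(1680z^7);
    # an assumed stand-in for log_gamma_correction, adequate for z >= 8.
    w = 1.0 / (z * z)
    return (1.0 / 12.0 + w * (-1.0 / 360.0 + w * (1.0 / 1260.0 - w / 1680.0))) / z


@jax.custom_jvp
def lbeta_sketch(x, y):
    """Stable forward pass for x, y >= 8 (hypothetical helper)."""
    a, b = jnp.minimum(x, y), jnp.maximum(x, y)
    # Stirling terms with the large parts cancelled analytically.
    stirling = (0.5 * jnp.log(2.0 * jnp.pi) - 0.5 * jnp.log(b)
                + (a - 0.5) * jnp.log(a / (a + b)) - b * jnp.log1p(a / b))
    return stirling + _correction(a) + _correction(b) - _correction(a + b)


@lbeta_sketch.defjvp
def _lbeta_sketch_jvp(primals, tangents):
    x, y = primals
    dx, dy = tangents
    # Tangents come from the naive formula lgamma(x) + lgamma(y) - lgamma(x + y),
    # not from differentiating the stable forward pass above.
    dig_xy = digamma(x + y)
    tangent_out = (digamma(x) - dig_xy) * dx + (digamma(y) - dig_xy) * dy
    return lbeta_sketch(x, y), tangent_out


x, y = jnp.float32(10.0), jnp.float32(20.0)
value = lbeta_sketch(x, y)
gx, gy = jax.grad(lbeta_sketch, argnums=(0, 1))(x, y)
print(value, gx, gy)
```

Because `jax.grad` only sees the custom tangent rule, the min/max branching and Stirling algebra in the forward pass never enter the gradient computation.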

References:

[1] DiDonato and Morris, "Significant Digit Computation of the Incomplete Beta Function Ratios", 1988. Technical report NSWC TR 88-365, Naval Surface Warfare Center (K33), Dahlgren, VA 22448-5000. Section IV, Auxiliary Functions. https://apps.dtic.mil/dtic/tr/fulltext/u2/a210118.pdf

Args:

x: Floating-point Tensor.
y: Floating-point Tensor.
name: Optional Python str naming the operation.

Returns:

lbeta: Tensor of elementwise log(Beta(x, y)).