Returns lgamma(y) - lgamma(x + y), accurately.
tfp.substrates.jax.math.log_gamma_difference(
x, y, name=None
)
This is more accurate than subtracting lgammas directly because lgamma grows as `x log(x) - x + o(x)`, and thus subtracting the value of lgamma for two close, large arguments incurs catastrophic cancellation.
When `y >= 8`, the method is to partition lgamma into the Stirling approximation and the correction `log_gamma_correction`, symbolically cancel the former, and compute and subtract the latter.
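The cancellation strategy described above can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the helper names `log_gamma_correction` and `log_gamma_difference_sketch` below are hypothetical stand-ins, the correction is approximated here by the first three terms of the Stirling series (an assumption made for illustration), and only scalar floats with `0 <= x <= y` and `y >= 8` are handled.

```python
import math


def log_gamma_correction(z):
    """Approximate lgamma(z) - [(z - 0.5) * log(z) - z + 0.5 * log(2 * pi)].

    Hypothetical stand-in: first three terms of the Stirling series,
    1/(12 z) - 1/(360 z^3) + 1/(1260 z^5), reasonable for z >= 8.
    """
    inv = 1.0 / z
    inv2 = inv * inv
    return inv * (1.0 / 12.0 - inv2 * (1.0 / 360.0 - inv2 / 1260.0))


def log_gamma_difference_sketch(x, y):
    """Sketch of lgamma(y) - lgamma(x + y) for 0 <= x <= y and y >= 8."""
    # Stirling parts of lgamma(y) and lgamma(x + y), combined on paper:
    #   (y - 0.5) * log(y) - y - [(x + y - 0.5) * log(x + y) - (x + y)]
    #     = -(x + y - 0.5) * log1p(x / y) - x * log(y) + x
    # so the two huge, nearly equal values are never formed in floating point.
    cancelled_stirling = -(x + y - 0.5) * math.log1p(x / y) - x * math.log(y) + x
    # The corrections are small, so subtracting them directly is harmless.
    correction = log_gamma_correction(y) - log_gamma_correction(x + y)
    return cancelled_stirling + correction


x, y = 1.0, 1e10
# Gamma(y + 1) = y * Gamma(y), so the exact difference for x = 1 is -log(y).
exact = -math.log(y)
naive = math.lgamma(y) - math.lgamma(x + y)
accurate = log_gamma_difference_sketch(x, y)
print("naive error:   ", abs(naive - exact))
print("accurate error:", abs(accurate - exact))
```

The choice `x = 1` is only a convenience: it gives a closed-form reference value against which both results can be compared. The naive subtraction works with values near `2e11`, where adjacent doubles are spaced far more coarsely than the roughly `-23` result requires.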
Args | |
---|---|
`x` | Floating-point Tensor. `x` should be non-negative, and elementwise no more than `y`. |
`y` | Floating-point Tensor. `y` should be positive. |
`name` | Optional Python `str` naming the operation. |
Returns | |
---|---|
`lgamma_diff` | Floating-point Tensor, the difference `lgamma(y) - lgamma(x + y)`, computed elementwise. |