tfp.substrates.jax.math.log_sub_exp

Compute log(exp(max(x, y)) - exp(min(x, y))) in a numerically stable way.
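Factoring out the larger term shows why this form is stable. This is the standard identity; it is assumed, not confirmed from the TFP source, that the library uses an equivalent rewrite. With a = max(x, y) and b = min(x, y):

$$
\log\left(e^{a} - e^{b}\right) = \log\left(e^{a}\left(1 - e^{-(a-b)}\right)\right) = a + \log\left(1 - e^{-(a-b)}\right)
$$

Because a - b >= 0, the factor e^{-(a-b)} lies in (0, 1], so no exponential can overflow. The remaining log(1 - e^{-d}) term is usually evaluated with expm1 when d is small and with log1p when d is large, to keep precision in both regimes.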

Pass return_sign=True unless it is known that x >= y. When y > x, exp(x) - exp(y) is negative and cannot be represented in log-space, so its sign must be returned separately.
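For example, take x = log 3 and y = log 5, so that y > x:

$$
e^{x} - e^{y} = 3 - 5 = -2, \qquad \text{logsubexp} = \log(5 - 3) = \log 2, \qquad \text{sign} = -1
$$

The signed difference is recovered as sign * exp(logsubexp) = -1 * 2 = -2.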

Args:
x: Float Tensor broadcastable with y.
y: Float Tensor broadcastable with x.
return_sign: Whether to also return a second output, sign, giving the sign of exp(x) - exp(y). Unnecessary if it is known that x >= y.
name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'log_sub_exp').

Returns:
logsubexp: Float Tensor of log(exp(max(x, y)) - exp(min(x, y))).
sign: Float Tensor of +1/-1 indicating the sign of exp(x) - exp(y). Only returned when return_sign=True.
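A pure-Python scalar sketch of the same computation, for illustration only. The names log_sub_exp_ref and _log1mexp are hypothetical, the code uses the standard math module rather than JAX, and it ignores broadcasting and the name argument. It is an assumption, not taken from the TFP source, that the library computes an equivalent larger + log(1 - exp(-(larger - smaller))) rewrite; the convention of reporting sign = +1 when x == y is likewise this sketch's choice.

```python
import math


def _log1mexp(d: float) -> float:
    """Hypothetical helper: return log(1 - exp(-d)) for d >= 0.

    Uses log(-expm1(-d)) for small d and log1p(-exp(-d)) otherwise,
    switching at d = log(2) to keep precision in both regimes.
    """
    if d == 0.0:
        return -math.inf  # log(1 - 1) = log(0)
    if d <= math.log(2.0):
        return math.log(-math.expm1(-d))
    return math.log1p(-math.exp(-d))


def log_sub_exp_ref(x: float, y: float, return_sign: bool = False):
    """Hypothetical scalar sketch of log(exp(max(x, y)) - exp(min(x, y)))."""
    larger, smaller = max(x, y), min(x, y)
    # Factor out exp(larger): exp() is never taken of a large positive number.
    result = larger + _log1mexp(larger - smaller)
    if not return_sign:
        return result
    # Sign of exp(x) - exp(y); this sketch reports +1 when x == y.
    sign = -1.0 if x < y else 1.0
    return result, sign


# A naive math.log(math.exp(1000.0) - math.exp(999.0)) raises OverflowError.
print(log_sub_exp_ref(1000.0, 999.0))  # about 999.54
```

Usage with a sign flip: log_sub_exp_ref(math.log(3.0), math.log(5.0), return_sign=True) returns (log 2, -1.0), matching the worked example of exp(x) - exp(y) = -2.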