tfp.substrates.jax.math.reduce_kahan_sum

Reduces the input tensor along the given axis using Kahan summation.

Returns both the total and the correction term as a namedtuple; together they represent the sum in higher precision as total - correction.
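The loop below is a minimal pure-Python sketch of textbook Kahan summation that follows the same convention (the exact sum is approximately total - correction). The helper `kahan_sum`, the stand-in `Kahan` namedtuple and the plain-float inputs are illustrative assumptions. This is not the TFP implementation, which works on tensors along `axis`.

```python
import math
from collections import namedtuple

# Hypothetical stand-in for the Kahan namedtuple returned by reduce_kahan_sum.
Kahan = namedtuple("Kahan", ["total", "correction"])


def kahan_sum(values):
    """Sketch of Kahan summation; the exact sum is ~ total - correction."""
    total = 0.0
    correction = 0.0
    for x in values:
        y = x - correction            # re-inject the error lost on the previous step
        t = total + y                 # low-order bits of y may be rounded away here
        correction = (t - total) - y  # (what was actually added) - (what we meant to add)
        total = t
    return Kahan(total, correction)


xs = [1.0] + [1e-16] * 10
naive = 0.0
for x in xs:
    naive += x                        # each 1e-16 is below half an ulp of 1.0

k = kahan_sum(xs)
print(naive)           # 1.0: every small term was rounded away
print(k.total)         # close to 1 + 1e-15
print(math.fsum(xs))   # correctly rounded reference sum
print(k.correction)
```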

A practical use case is computing the difference of two sums of large magnitude that are expected to be nearly equal. Rather than subtracting the totals alone, take the difference as (s0.total - s1.total) - (s0.correction - s1.correction), where s0 and s1 are the two Kahan results; this retains more precision in the difference.
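As a sketch of this use case, the snippet below applies the same hypothetical pure-Python `kahan_sum` helper (again not TFP code) to two made-up sums that agree in their leading bits. `math.fsum` of the small terms serves as the reference for the true difference.

```python
import math
from collections import namedtuple

Kahan = namedtuple("Kahan", ["total", "correction"])  # hypothetical stand-in


def kahan_sum(values):
    # Hypothetical helper: textbook Kahan summation, exact sum ~ total - correction.
    total, correction = 0.0, 0.0
    for x in values:
        y = x - correction
        t = total + y
        correction = (t - total) - y
        total = t
    return Kahan(total, correction)


small = [1e-16] * 10
s0 = kahan_sum([1.0] + small)   # true value ~ 1 + 1e-15
s1 = kahan_sum([1.0])           # exactly 1

totals_only = s0.total - s1.total
compensated = (s0.total - s1.total) - (s0.correction - s1.correction)
reference = math.fsum(small)    # true difference, ~1e-15

print(totals_only)   # off by about 11%: totals can only differ by multiples of ulp(1.0)
print(compensated)   # agrees with the reference to near full double precision
print(reference)
```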

Note that total holds all the high-order bits of the sum, so the correction can safely be neglected when no further extended-precision computation is required.
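The check below illustrates this for one made-up input, using the same hypothetical pure-Python `kahan_sum` sketch rather than TFP's implementation: the correction is smaller than half a unit in the last place (ulp) of the total, so applying it at ordinary float precision leaves the total unchanged. `math.ulp` requires Python 3.9 or later.

```python
import math
from collections import namedtuple

Kahan = namedtuple("Kahan", ["total", "correction"])  # hypothetical stand-in


def kahan_sum(values):
    # Hypothetical helper: textbook Kahan summation, exact sum ~ total - correction.
    total, correction = 0.0, 0.0
    for x in values:
        y = x - correction
        t = total + y
        correction = (t - total) - y
        total = t
    return Kahan(total, correction)


k = kahan_sum([1.0] + [1e-16] * 10)

# The correction is below half an ulp of the total ...
print(abs(k.correction) <= math.ulp(k.total) / 2)  # True
# ... so subtracting it at working precision rounds straight back to the total.
print(k.total - k.correction == k.total)           # True
```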

Args:
input_tensor: The tensor to sum.
axis: One of None, a Python int, or a sequence of Python ints. The axes to be reduced; None is taken as "reduce all axes".
keepdims: Python bool indicating whether to return a tensor with singleton dimensions in the reduced axes (True), or to squeeze those axes out (default, False).
name: Optional name for ops in scope.

Returns:
reduced: A Kahan(total, correction) namedtuple.
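To illustrate the axis and keepdims semantics without depending on TFP or JAX, the following is a NumPy sketch. The function `reduce_kahan_sum_np` is hypothetical: it assumes the shape behavior described above and is not the code behind tfp.substrates.jax.math.reduce_kahan_sum. It moves the reduced axes to the front, flattens them, and runs an elementwise Kahan update for every output position.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for the Kahan namedtuple returned by the TFP function.
Kahan = namedtuple("Kahan", ["total", "correction"])


def reduce_kahan_sum_np(x, axis=None, keepdims=False):
    """NumPy sketch of a Kahan reduction with axis/keepdims semantics."""
    x = np.asarray(x)
    if axis is None:
        axes = tuple(range(x.ndim))          # None means "reduce all axes"
    elif isinstance(axis, int):
        axes = (axis % x.ndim,)
    else:
        axes = tuple(a % x.ndim for a in axis)
    kept = tuple(a for a in range(x.ndim) if a not in axes)

    # Move the reduced axes to the front and flatten them into one leading axis.
    moved = np.transpose(x, axes + kept)
    kept_shape = moved.shape[len(axes):]
    flat = moved.reshape((-1,) + kept_shape)

    total = np.zeros(kept_shape, dtype=x.dtype)
    correction = np.zeros(kept_shape, dtype=x.dtype)
    for row in flat:                         # elementwise Kahan update per output
        y = row - correction
        t = total + y
        correction = (t - total) - y
        total = t

    if keepdims:                             # keep reduced axes as size-1 dims
        shape = tuple(1 if a in axes else n for a, n in enumerate(x.shape))
        total = np.reshape(total, shape)
        correction = np.reshape(correction, shape)
    return Kahan(total, correction)


x = np.full((2, 3, 4), 0.1, dtype=np.float32)
print(reduce_kahan_sum_np(x, axis=1).total.shape)                      # (2, 4)
print(reduce_kahan_sum_np(x, axis=[0, 2], keepdims=True).total.shape)  # (1, 3, 1)
print(np.shape(reduce_kahan_sum_np(x).total))                          # ()
```

The Python loop over the flattened reduced axis keeps the sketch short but is slow for long axes.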