tfp.substrates.jax.math.log_cumsum_exp

Computes log(cumsum(exp(x))).
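Written out for a 1-D input x = (x_0, ..., x_{n-1}), output element i is the log of the running sum of exponentials. The notation y_i and logaddexp below is introduced here for illustration. The second line gives the sequential recurrence. The third gives the standard stable way to add two log-space values, which avoids overflow in exp. For higher-rank x, the same computation is assumed to run independently along the chosen axis.

```latex
\begin{aligned}
y_i &= \log \sum_{j=0}^{i} \exp(x_j), && i = 0, \dots, n-1, \\
y_0 &= x_0, \quad y_i = \operatorname{logaddexp}(y_{i-1}, x_i), && i \ge 1, \\
\operatorname{logaddexp}(a, b) &= \max(a, b) + \log\left(1 + e^{-|a - b|}\right).
\end{aligned}
```

Because logaddexp is associative, the prefix values can also be computed with a parallel scan instead of the sequential recurrence.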

This is a pure-TF implementation of tf.math.cumulative_logsumexp. Unlike the built-in op, it supports XLA compilation. It uses the same algorithmic technique as the built-in op (a parallel prefix sum), so its numerics and asymptotic performance are similar. However, this implementation currently has higher overhead, so it is significantly slower on smaller inputs (n < 10000).
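To make the "parallel prefix sum" idea concrete, here is a minimal stdlib-only sketch. It is not TFP's source. It runs an inclusive scan whose combine step is a numerically stable logaddexp. The function names are hypothetical. The scan uses the Hillis-Steele scheme because it is short: each of the ceil(log2 n) rounds updates every element independently, which is what makes it parallelizable. However, it does O(n log n) total work, whereas a work-efficient (Blelloch-style) scan does O(n). A sequential reference is included for comparison.

```python
import math


def logaddexp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    if a == b:
        # Also covers a == b == +inf, where a - b would be nan.
        return a + math.log(2.0)
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))


def log_cumsum_exp_scan(x: list[float]) -> list[float]:
    """Inclusive prefix scan with logaddexp as the associative combine.

    Hillis-Steele: in each round, element i absorbs element i - offset,
    and offset doubles, so there are ceil(log2(n)) rounds. All updates
    within a round read only the previous round's values, so they could
    run in parallel.
    """
    y = list(x)
    offset = 1
    while offset < len(y):
        # Build the next round from a snapshot of the previous one.
        y = [logaddexp(y[i - offset], y[i]) if i >= offset else y[i]
             for i in range(len(y))]
        offset *= 2
    return y


def log_cumsum_exp_naive(x: list[float]) -> list[float]:
    """Sequential reference: y_i = logaddexp(y_{i-1}, x_i)."""
    out = []
    acc = -math.inf
    for v in x:
        acc = logaddexp(acc, v)
        out.append(acc)
    return out


if __name__ == "__main__":
    # exp(1000.0) overflows a float, but the log-space scan does not.
    print(log_cumsum_exp_scan([1000.0, 1000.0, 1000.0]))
```

In JAX, the same combine function could be passed to jax.lax.associative_scan (for example with jax.numpy.logaddexp as the combine). The plain-Python loop above only serves to show the round structure.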

Args:
  x: the Tensor to sum over.
  axis: int Tensor axis to sum over.
  name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'cumulative_logsumexp').

Returns:
  cumulative_logsumexp: Tensor of the same shape as x.
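The axis argument selects the dimension along which the running log-sum-exp accumulates, and the result keeps the input's shape. The following stdlib-only sketch illustrates those semantics on a rectangular 2-D nested Python list. log_cumsum_exp_2d and its helpers are hypothetical names and are not part of TFP. The sketch handles only axis values 0, 1, -1 and -2, and it assumes all rows have the same length.

```python
import math


def _logaddexp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    if a == b:
        return a + math.log(2.0)
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))


def _log_cumsum_exp_1d(row: list[float]) -> list[float]:
    """Running log-sum-exp along one 1-D sequence."""
    out = []
    acc = -math.inf
    for v in row:
        acc = _logaddexp(acc, v)
        out.append(acc)
    return out


def log_cumsum_exp_2d(x: list[list[float]], axis: int) -> list[list[float]]:
    """Hypothetical helper: accumulate along `axis`, keeping x's shape."""
    if axis in (1, -1):
        # Accumulate left to right within each row.
        return [_log_cumsum_exp_1d(row) for row in x]
    if axis in (0, -2):
        if not x or not x[0]:
            # Nothing to accumulate; return a same-shaped copy.
            return [list(row) for row in x]
        # Transpose, accumulate each former column, transpose back.
        cols = [list(c) for c in zip(*x)]
        scanned = [_log_cumsum_exp_1d(c) for c in cols]
        return [list(r) for r in zip(*scanned)]
    raise ValueError("axis must be 0, 1, -1, or -2 for a 2-D input")


if __name__ == "__main__":
    x = [[0.0, 0.0], [0.0, 0.0]]
    print(log_cumsum_exp_2d(x, axis=1))  # accumulates along each row
    print(log_cumsum_exp_2d(x, axis=0))  # accumulates down each column
```

Transposing for axis 0 keeps a single 1-D routine in charge of the numerics. A tensor library would instead pick the axis without copying data.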