tfp.experimental.substrates.jax.stats.log_average_probs

Computes log(average(to_probs(logits))) in a numerically stable manner.
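The stability comes from never leaving log space. Write p_i = to_probs(x_i) for the n samples being averaged and ℓ_i = log p_i. The standard log-sum-exp identity below rewrites the average so that every exponentiated term lies in (0, 1] and at least one equals 1. As a result, neither overflow nor a log(0) can occur. This is a sketch of the underlying idea, not a description of the library's exact internals:

```latex
\begin{aligned}
\ell_i &= \log p_i, \qquad m = \max_i \ell_i, \\
\log\Big(\frac{1}{n}\sum_{i=1}^{n} p_i\Big)
  &= \log\Big(\sum_{i=1}^{n} e^{\ell_i}\Big) - \log n
   = m + \log\sum_{i=1}^{n} e^{\ell_i - m} - \log n.
\end{aligned}
```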

The meaning of to_probs is controlled by the event_axis argument. When event_axis is None, to_probs = tf.math.sigmoid; otherwise, to_probs = lambda x: tf.math.softmax(x, axis=event_axis).
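A minimal JAX sketch of why the sigmoid (event_axis=None) branch needs care, assuming float32 inputs and a single sample axis. It calls jax.nn.log_sigmoid and jax.scipy.special.logsumexp directly rather than the library function. The helpers naive_log_avg_sigmoid and stable_log_avg_sigmoid are hypothetical names used only for illustration:

```python
import jax.numpy as jnp
from jax import nn
from jax.scipy.special import logsumexp


def naive_log_avg_sigmoid(logits, axis=0):
    # Literal log(mean(sigmoid(x))): sigmoid underflows to 0 for very negative x,
    # so the log of the mean becomes -inf.
    return jnp.log(jnp.mean(nn.sigmoid(logits), axis=axis))


def stable_log_avg_sigmoid(logits, axis=0):
    # log(sigmoid(x)) = -softplus(-x) stays finite for very negative x.
    log_probs = nn.log_sigmoid(logits)
    n = logits.shape[axis]
    # Stable log-mean-exp: logsumexp subtracts the max before exponentiating.
    return logsumexp(log_probs, axis=axis) - jnp.log(n)


logits = jnp.array([-1000.0, -1001.0], dtype=jnp.float32)
naive = naive_log_avg_sigmoid(logits)
stable = stable_log_avg_sigmoid(logits)
print(naive)   # -inf: both sigmoids underflow to 0
print(stable)  # about -1000.38
```

The two versions agree for moderate logits. They differ only once the probabilities themselves are too small to represent.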

sample_axis and event_axis should not share any axis (their intersection should be empty). This requirement is always verified when validate_args is True.
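To illustrate the requirement, the sketch below checks disjointness after mapping negative axes to their non-negative equivalents. That mapping matters because overlaps are easy to miss: on a rank-2 tensor, -1 and 1 name the same axis. axes_are_disjoint is a hypothetical plain-Python helper that assumes the tensor rank is known statically. It is not the library's validation code. Following the argument descriptions, sample_axis=None is read as "all axes hold samples" and event_axis=None as "no event axis" (Bernoulli logits):

```python
from typing import Optional, Sequence, Set, Union

Axis = Optional[Union[int, Sequence[int]]]


def _normalize(axis: Union[int, Sequence[int]], ndims: int) -> Set[int]:
    # Map an int or sequence of ints (possibly negative) to non-negative axes.
    axes = [axis] if isinstance(axis, int) else list(axis)
    out = set()
    for a in axes:
        if not -ndims <= a < ndims:
            raise ValueError(f"axis {a} out of range for rank {ndims}")
        out.add(a % ndims)
    return out


def axes_are_disjoint(ndims: int, sample_axis: Axis, event_axis: Axis) -> bool:
    # sample_axis=None: every axis holds samples.
    samples = set(range(ndims)) if sample_axis is None else _normalize(sample_axis, ndims)
    # event_axis=None: Bernoulli logits, so there is no event axis.
    events = set() if event_axis is None else _normalize(event_axis, ndims)
    return samples.isdisjoint(events)


print(axes_are_disjoint(3, 0, -1))        # True: axes {0} and {2}
print(axes_are_disjoint(2, 1, -1))        # False: -1 and 1 are the same axis at rank 2
print(axes_are_disjoint(3, None, None))   # True: no event axis to collide with
```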

Args:
  logits: A float Tensor representing logits.
  sample_axis: Scalar or vector Tensor designating the axis (or axes) holding samples, or None (meaning all axes hold samples). Default value: 0 (leftmost dimension).
  event_axis: Scalar or vector Tensor designating the axis (or axes) holding categorical logits. Default value: None (i.e., Bernoulli logits).
  keepdims: Boolean. Whether to keep the reduced sample axes as singleton dimensions. Default value: False (i.e., squeeze the reduced dimensions).
  validate_args: Python bool. When True, arguments are checked for validity, possibly at some cost to runtime performance. When False, invalid inputs may silently produce incorrect outputs. Default value: False (i.e., do not validate args).
  name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'log_average_probs').

Returns:
  log_avg_probs: The natural log of the average of the probabilities computed from logits.
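The following sketch shows how sample_axis, event_axis, and keepdims determine the output shape. It reproduces the documented semantics with plain JAX operations. This is an illustration under stated assumptions, not the library's implementation. log_average_probs_sketch is a hypothetical name, validate_args and name are omitted, and jax.nn.log_softmax stands in for the categorical branch:

```python
import math

import jax
import jax.numpy as jnp
from jax import nn
from jax.scipy.special import logsumexp


def log_average_probs_sketch(logits, sample_axis=0, event_axis=None, keepdims=False):
    """Hypothetical stand-in computing log(mean(to_probs(logits))) over sample_axis."""
    logits = jnp.asarray(logits)
    if event_axis is None:
        # Bernoulli logits: log(sigmoid(x)) without forming sigmoid(x) first.
        log_probs = nn.log_sigmoid(logits)
    else:
        # Categorical logits: log(softmax(x)) along the event axis (or axes).
        log_probs = nn.log_softmax(logits, axis=event_axis)
    # n = number of samples being averaged.
    if sample_axis is None:
        n = log_probs.size
    else:
        axes = (sample_axis,) if isinstance(sample_axis, int) else tuple(sample_axis)
        n = math.prod(log_probs.shape[a] for a in axes)
    # log(mean(p)) = logsumexp(log p) - log(n), a stable log-mean-exp.
    return logsumexp(log_probs, axis=sample_axis, keepdims=keepdims) - jnp.log(n)


# 4 samples (e.g. ensemble members), a batch of 2, and 3 classes.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 2, 3))

out = log_average_probs_sketch(logits, sample_axis=0, event_axis=-1)
print(out.shape)       # (2, 3): the sample axis is reduced away
out_keep = log_average_probs_sketch(logits, sample_axis=0, event_axis=-1, keepdims=True)
print(out_keep.shape)  # (1, 2, 3): the sample axis is kept as a singleton

# For moderate logits, the naive formula gives the same answer.
naive = jnp.log(jnp.mean(nn.softmax(logits, axis=-1), axis=0))
print(jnp.allclose(out, naive, atol=1e-5))
# Averaging distributions yields a distribution, so each row still sums to 1.
print(jnp.allclose(jnp.exp(out).sum(axis=-1), 1.0, atol=1e-5))
```

Keeping the result in log space lets it be passed directly to code that expects log-probabilities, without a round trip through exp and log.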