tfp.substrates.jax.stats.windowed_mean

Windowed estimates of mean.

Computes means among data in the Tensor x along the given windows:

result[i] = mean(x[low_indices[i]:high_indices[i]])

efficiently. To wit, if K is the size of low_indices and high_indices, and N is the size of x along the given axis, the computation takes O(K + N) work, O(log(N)) depth (the length of the longest series of operations that are performed sequentially), and only O(1) kernel invocations.
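
To make the half-open window semantics and the O(K + N) bound concrete, here is a NumPy sketch. The helpers `windowed_mean_reference` and `windowed_mean_prefix_sum` are hypothetical names, not part of TFP. The prefix-sum trick shown is one standard way to meet the stated work bound; it is not presented as TFP's actual implementation. Windows are assumed non-empty (low < high).

```python
import numpy as np

def windowed_mean_reference(x, low_indices, high_indices):
    """Naive O(K * N) reference: the mean of x[low:high] for each window."""
    x = np.asarray(x, dtype=float)
    low, high = np.broadcast_arrays(np.atleast_1d(low_indices),
                                    np.atleast_1d(high_indices))
    return np.array([x[lo:hi].mean() for lo, hi in zip(low, high)])

def windowed_mean_prefix_sum(x, low_indices, high_indices):
    """O(K + N) windowed means by differencing a zero-padded cumulative sum."""
    x = np.asarray(x, dtype=float)
    low, high = np.broadcast_arrays(np.atleast_1d(low_indices),
                                    np.atleast_1d(high_indices))
    # cum[j] == sum(x[:j]) for j = 0..N, so any index in [0, N] is valid.
    cum = np.concatenate([[0.0], np.cumsum(x)])  # O(N) work
    # Each window sum is a difference of two prefix sums: O(K) work.
    return (cum[high] - cum[low]) / (high - low)

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
low, high = [0, 1, 1, 2, 2], [1, 2, 3, 4, 5]
print(windowed_mean_prefix_sum(x, low, high))
# values: 1.0, 2.0, 2.5, 3.5, 4.0
print(np.allclose(windowed_mean_reference(x, low, high),
                  windowed_mean_prefix_sum(x, low, high)))  # True
```

NumPy's `cumsum` runs sequentially, so this sketch matches the work bound but not the O(log(N)) depth bound; a parallel prefix sum (scan) is what achieves logarithmic depth. Differencing prefix sums can also lose floating-point precision on long series with large values, since two large partial sums are subtracted.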

This function can be useful for assessing the behavior over time of trailing-window estimators from some iterative process, such as the last half of an MCMC chain.

Suppose x has shape Bx + [N] + E, where Bx has rank equal to axis, and low_indices and high_indices broadcast to shape [M]. Then each element of low_indices and high_indices must lie between 0 and N inclusive, and the output has shape Bx + [M] + E. Batch shape in the indices is not currently supported.
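
The shape rule can be illustrated with a NumPy sketch that computes windowed means along an arbitrary axis. `windowed_mean_along_axis` is a hypothetical helper, not a TFP function; it assumes 1-D (non-batched) indices and non-empty windows.

```python
import numpy as np

def windowed_mean_along_axis(x, low_indices, high_indices, axis=0):
    """Sketch: windowed means along `axis` of x with shape Bx + [N] + E.

    low_indices and high_indices must broadcast to a 1-D shape [M];
    the result has shape Bx + [M] + E.
    """
    x = np.asarray(x, dtype=float)
    low, high = np.broadcast_arrays(np.atleast_1d(low_indices),
                                    np.atleast_1d(high_indices))
    # Prepend a slice of zeros along `axis`, giving N + 1 prefix sums there,
    # so entry j along `axis` is the sum of the first j samples and j may equal N.
    zeros_shape = list(x.shape)
    zeros_shape[axis] = 1
    cum = np.concatenate([np.zeros(zeros_shape), np.cumsum(x, axis=axis)],
                         axis=axis)
    # np.take along `axis` swaps the N + 1 entries for M gathered entries.
    sums = np.take(cum, high, axis=axis) - np.take(cum, low, axis=axis)
    # Shape counts as [1, ..., 1, M, 1, ..., 1] so they broadcast at `axis`.
    counts_shape = [1] * x.ndim
    counts_shape[axis] = high.size
    return sums / (high - low).reshape(counts_shape)

x = np.arange(60, dtype=float).reshape(2, 10, 3)  # Bx = [2], N = 10, E = [3]
low = np.array([0, 2, 5, 9])
high = np.array([1, 6, 10, 10])                   # M = 4; an index of N is allowed
out = windowed_mean_along_axis(x, low, high, axis=1)
print(out.shape)  # (2, 4, 3)
```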

The default windows are [0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ..., that is, window i is [(i + 1) // 2, i + 1). This corresponds to analyzing x as though it were streaming, for example successive states of an MCMC sampler, and we were interested in the mean of the last half of the data at each point.
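
The default windows follow from taking high_indices as 1..N and low_indices as high_indices // 2. A NumPy sketch (the helper `default_windows` is hypothetical, not part of TFP) that reproduces the listed windows and then computes trailing-half means of a toy "chain" with made-up values:

```python
import numpy as np

def default_windows(n):
    """Default windows per the description: high = 1..N, low = high // 2."""
    high = np.arange(1, n + 1)
    low = high // 2
    return low, high

low, high = default_windows(5)
print(list(zip(low.tolist(), high.tolist())))
# [(0, 1), (1, 2), (1, 3), (2, 4), (2, 5)]

# Trailing-half means of a toy chain (illustrative values), one per step.
chain = np.array([10.0, 0.0, 2.0, 4.0, 6.0])
cum = np.concatenate([[0.0], np.cumsum(chain)])  # cum[j] == sum(chain[:j])
trailing_half_means = (cum[high] - cum[low]) / (high - low)
# windows [10], [0], [0, 2], [2, 4], [2, 4, 6] -> 10.0, 0.0, 1.0, 3.0, 4.0
```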

Args:
x: A numeric Tensor holding N samples along the given axis, whose windowed means are desired.
low_indices: An integer Tensor defining the lower boundary (inclusive) of each window. Default: elementwise half of high_indices, i.e., high_indices // 2.
high_indices: An integer Tensor defining the upper boundary (exclusive) of each window. Must be broadcast-compatible with low_indices. Default: tf.range(1, N+1), i.e., N windows that each end in the corresponding datum from x (inclusive).
axis: Scalar Tensor designating the axis holding samples. This is the axis of x along which we take windows, and therefore the axis that low_indices and high_indices index into. Other axes are treated in batch. Default value: 0 (leftmost dimension).
name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'windowed_mean').
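
Because low_indices and high_indices only need to be broadcast-compatible, common window shapes take a single line to specify. The NumPy sketch below uses a hypothetical helper `windowed_mean_1d` (not TFP code), assuming 1-D x and non-empty windows; the index arrays built here play the roles of low_indices and high_indices.

```python
import numpy as np

def windowed_mean_1d(x, low_indices, high_indices):
    """1-D sketch: broadcast the index arguments, then difference prefix sums."""
    x = np.asarray(x, dtype=float)
    # Mirrors the broadcast-compatibility requirement on low/high indices.
    low, high = np.broadcast_arrays(np.atleast_1d(low_indices),
                                    np.atleast_1d(high_indices))
    cum = np.concatenate([[0.0], np.cumsum(x)])
    return (cum[high] - cum[low]) / (high - low)

x = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])
n = x.shape[0]

# Scalar low index broadcast against N upper bounds: running (expanding) means.
running = windowed_mean_1d(x, 0, np.arange(1, n + 1))
# running: 3.0, 2.0, 2.667, 2.25, 2.8, 3.833

# Fixed-width trailing windows of length 3: [0, 3), [1, 4), [2, 5), [3, 6).
high = np.arange(3, n + 1)
moving = windowed_mean_1d(x, high - 3, high)
# moving: 2.667, 2.0, 3.333, 5.0
```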

Returns:
means: A numeric Tensor holding the windowed means of x along the axis dimension, with shape Bx + [M] + E.