
tf_agents.utils.tensor_normalizer.parallel_variance_calculation

Calculates the sufficient statistics (average and second moment) of the union of two sets.

For better precision when the sets differ in size, A should be the smaller set and B the larger one.

For more details, see the parallel algorithm of Chan et al. at: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
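For reference, the merge step of Chan et al.'s parallel algorithm (as given on the linked page) is written below with this function's argument names mapped to symbols: n_a ↔ n_A, avg_a ↔ \bar{x}_A, m2_a ↔ M_{2,A}, and likewise for B. Here M_2 is the sum of squared deviations from the mean, which is what this page calls the second moment (compare the reduce_sum of squared differences in the snippet further down). Kahan carries are omitted from these formulas.

```latex
\begin{aligned}
n_{AB} &= n_A + n_B, \qquad \delta = \bar{x}_B - \bar{x}_A,\\
\bar{x}_{AB} &= \bar{x}_A + \delta\,\frac{n_B}{n_{AB}},\\
M_{2,AB} &= M_{2,A} + M_{2,B} + \delta^2\,\frac{n_A\, n_B}{n_{AB}},\\
\sigma^2_{AB} &= \frac{M_{2,AB}}{n_{AB}}, \qquad s^2_{AB} = \frac{M_{2,AB}}{n_{AB} - 1}.
\end{aligned}
```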

For numerical stability, the second moments are accumulated with Kahan summation (kahan_summation).
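Kahan (compensated) summation keeps a running carry that captures the low-order bits lost when a small value is added to a large total. The sketch below is a plain-Python illustration of one compensated step; `kahan_add` is a hypothetical helper and its argument order is an assumption, not the signature of tf_agents' `kahan_summation`. Its `carry` plays the same role as the m2_b_c argument listed under Args.

```python
def kahan_add(total, carry, value):
    """Add `value` to `total` with Kahan compensation (hypothetical helper).

    `carry` holds the rounding error from earlier additions; the true sum is
    approximately `total - carry`. Returns the updated (total, carry).
    """
    y = value - carry        # correct the new term by the previously lost error
    t = total + y            # low-order bits of y may be rounded away here
    carry = (t - total) - y  # algebraically zero; numerically, the lost part
    return t, carry


# Add 1e-16 to 1.0 a hundred thousand times (exact answer: 1.0 + 1e-11).
naive = 1.0
total, carry = 1.0, 0.0
for _ in range(100_000):
    naive += 1e-16  # each add rounds back to 1.0 in double precision
    total, carry = kahan_add(total, carry, 1e-16)

print(naive, total)
```

The naive sum never moves off 1.0, while the compensated total ends up within a few ulps of 1.0 + 1e-11.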

Takes in the sufficient statistics for sets A and B and calculates the variance and sufficient statistics for the union of A and B.
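As a minimal sketch of that contract, assuming plain Python floats and ignoring the Kahan carry, the merge below combines the statistics of two sets and checks them against statistics computed directly from the union. `combine_stats` and `direct_stats` are hypothetical names, not tf_agents functions.

```python
def combine_stats(n_a, avg_a, m2_a, n_b, avg_b, m2_b):
    """Merge (count, average, second moment) of two disjoint sets (Chan et al.)."""
    n_ab = n_a + n_b
    delta = avg_b - avg_a
    avg_ab = avg_a + delta * n_b / n_ab  # weighted mean of the two averages
    # Both partial second moments plus a correction for the shift in means.
    m2_ab = m2_a + m2_b + delta * delta * n_a * n_b / n_ab
    return n_ab, avg_ab, m2_ab


def direct_stats(xs):
    """(count, average, sum of squared deviations) computed from all of xs."""
    n = len(xs)
    avg = sum(xs) / n
    return n, avg, sum((x - avg) ** 2 for x in xs)


a = [1.0, 2.0, 4.0]
b = [3.0, 5.0, 7.0, 11.0, 13.0]
n_ab, avg_ab, m2_ab = combine_stats(*direct_stats(a), *direct_stats(b))
var_ab = m2_ab / n_ab               # variance of the union of A and B
sample_var_ab = m2_ab / (n_ab - 1)  # sample variance of the union
print(n_ab, avg_ab, var_ab, sample_var_ab)
```

Merging only needs the three summary numbers per set, so the raw data of A and B never has to be revisited.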

If, e.g., B is a single observation x_b, use n_b = 1, avg_b = x_b, m2_b = 0, and m2_b_c = 0.
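The same size-1 recipe supports streaming updates. In the sketch below (plain Python floats; `merge_with_carry` and `kahan_add` are hypothetical and only mirror the documented argument list, they are not the tf_agents implementation), the roles are swapped: because m2_b_c carries B's compensation and the precision note asks for B to be the larger set, the running totals are passed as B and each new observation as A (n_a = 1, avg_a = x, m2_a = 0).

```python
def kahan_add(total, carry, value):
    """One Kahan-compensated addition; returns the updated (total, carry)."""
    y = value - carry
    t = total + y
    return t, (t - total) - y


def merge_with_carry(n_a, avg_a, m2_a, n_b, avg_b, m2_b, m2_b_c):
    """Sketch mirroring the documented arguments.

    Returns (n_ab, avg_ab, m2_ab, m2_ab_c).
    """
    n_ab = n_a + n_b
    delta = avg_b - avg_a
    avg_ab = avg_a + delta * n_b / n_ab
    # Fold A's second moment and the mean-shift term into B's running sum,
    # carrying the Kahan compensation along.
    m2_ab, m2_ab_c = kahan_add(m2_b, m2_b_c, m2_a)
    m2_ab, m2_ab_c = kahan_add(m2_ab, m2_ab_c, delta * delta * n_a * n_b / n_ab)
    return n_ab, avg_ab, m2_ab, m2_ab_c


# Stream observations one at a time; the running state starts empty.
data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
n, avg, m2, m2_c = 0, 0.0, 0.0, 0.0
for x in data:
    n, avg, m2, m2_c = merge_with_carry(1, x, 0.0, n, avg, m2, m2_c)

print(n, avg, m2 / n, m2 / (n - 1))
```

Starting from an empty state (n_b = 0) works because the first merge simply copies the observation into the running average.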

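To compute n_a, avg_a, and m2_a from a tensor x of shape [n_a, ...], use:

import tensorflow as tf

# x is a float tensor of shape [n_a, ...]; statistics are reduced over axis 0.
n_a = tf.shape(x)[0]
avg_a = tf.math.reduce_mean(x, axis=[0])
m2_a = tf.math.reduce_sum(tf.math.squared_difference(x, avg_a), axis=[0])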
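The TensorFlow snippet reduces over axis 0, so avg_a and m2_a hold one value per feature, and the merge applies to each feature independently. A minimal pure-Python sketch of that per-feature behaviour, assuming a batch given as a list of equal-length rows (`batch_stats` and `merge_batches` are hypothetical helpers, not tf_agents functions):

```python
def batch_stats(rows):
    """Pure-Python analogue of reducing an [n, d] batch over axis 0.

    Returns (n, avg, m2) where avg and m2 are per-feature lists of length d.
    """
    n = len(rows)
    d = len(rows[0])
    avg = [sum(r[j] for r in rows) / n for j in range(d)]
    m2 = [sum((r[j] - avg[j]) ** 2 for r in rows) for j in range(d)]
    return n, avg, m2


def merge_batches(stats_a, stats_b):
    """Apply the Chan et al. merge to each feature separately (no Kahan carry)."""
    n_a, avg_a, m2_a = stats_a
    n_b, avg_b, m2_b = stats_b
    n_ab = n_a + n_b
    avg_ab, m2_ab = [], []
    for ma, mb, sa, sb in zip(avg_a, avg_b, m2_a, m2_b):
        delta = mb - ma
        avg_ab.append(ma + delta * n_b / n_ab)
        m2_ab.append(sa + sb + delta * delta * n_a * n_b / n_ab)
    return n_ab, avg_ab, m2_ab


x_a = [[1.0, 10.0], [3.0, 30.0]]
x_b = [[2.0, 20.0], [6.0, 60.0], [8.0, 80.0]]
merged = merge_batches(batch_stats(x_a), batch_stats(x_b))
print(merged)
```

A batch holding a single row yields n = 1, the row itself as the average, and zero second moments, which matches the single-observation recipe above.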

Args:
n_a: Number of elements in A.
avg_a: The sample average of A.
m2_a: The sample second moment of A.
n_b: Number of elements in B.
avg_b: The sample average of B.
m2_b: The sample second moment of B.
m2_b_c: Carry for the accumulation of the sample second moment of B.

Returns:
A tuple (n_ab, avg_ab, m2_ab, m2_ab_c), where m2_ab_c is the carry for m2_ab. The variance of A|B is var_ab = m2_ab / n_ab, and the sample variance is sample_var_ab = m2_ab / (n_ab - 1).