Estimate variance using samples.
tfp.substrates.jax.stats.variance(
x, sample_axis=0, keepdims=False, name=None
)
Given N samples of a scalar-valued random variable X, the variance may be estimated as
Var[X] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(X_n - Xbar)}
Xbar := N^{-1} sum_{n=1}^N X_n
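The estimator above can be written out directly in a few lines. The following is a minimal NumPy sketch, assuming NumPy is available; `sample_variance` is a hypothetical helper for illustration, not part of TFP. It averages over `sample_axis`, centers the samples, and multiplies each centered value by its complex conjugate, which equals |X_n - Xbar|^2 and so also covers complex inputs.

```python
import numpy as np


def sample_variance(x, sample_axis=0, keepdims=False):
    """Divide-by-N sample variance following Var[X] above.

    Hypothetical illustration only; not the TFP implementation.
    """
    x = np.asarray(x)
    n = x.shape[sample_axis]
    # Xbar: mean over the sample axis, kept as a size-1 axis so it broadcasts.
    xbar = np.mean(x, axis=sample_axis, keepdims=True)
    centered = x - xbar
    # (X_n - Xbar) * Conj(X_n - Xbar) == |X_n - Xbar|^2, a real number.
    squared = (centered * np.conj(centered)).real
    # Sum over the sample axis and divide by N (not N - 1).
    return np.sum(squared, axis=sample_axis, keepdims=keepdims) / n


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 2, 3))
var = sample_variance(x, sample_axis=0)
print(var.shape)  # (2, 3)
print(np.allclose(var, np.var(x, axis=0)))  # True
```

Comparing against `numpy.var`, whose default `ddof=0` also divides by N, is a quick sanity check on the formula.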
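import jax
from tensorflow_probability.substrates import jax as tfp
x = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2, 3))
# var[i, j] is the sample variance of the (i, j) batch member of x, so var.shape == (2, 3).
var = tfp.stats.variance(x, sample_axis=0)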
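Notice that we divide by N (the NumPy default) rather than N - 1. This does not create NaN when N = 1, but it is slightly biased: for independent samples, the expected value of the estimate is (N - 1) / N times the true variance.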
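To see the trade-off concretely, the sketch below compares `numpy.var` with its default `ddof=0` (divide by N, the convention described above) against `ddof=1` (divide by N - 1). NumPy is used here only as an illustrative assumption; it is not what `tfp.stats.variance` runs internally.

```python
import warnings

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
n = x.size

biased = np.var(x)            # divide by N:     5 / 4 = 1.25
unbiased = np.var(x, ddof=1)  # divide by N - 1: 5 / 3, about 1.667
# Rescaling the divide-by-N estimate by N / (N - 1) recovers the unbiased one.
assert np.isclose(biased * n / (n - 1), unbiased)

# With a single sample, dividing by N gives 0.0 ...
single = np.array([7.0])
print(np.var(single))  # 0.0
# ... whereas dividing by N - 1 is 0 / 0, which NumPy reports as nan
# (together with a RuntimeWarning, silenced here).
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)
    print(np.var(single, ddof=1))  # nan
```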