Counts the number of occurrences of each value in an integer array `arr`.
```python
tfp.substrates.jax.stats.count_integers(
    arr,
    weights=None,
    minlength=None,
    maxlength=None,
    axis=None,
    dtype=tf.int32,
    name=None
)
```
Works like `tf.math.bincount`, but provides an `axis` kwarg that specifies the dimensions to reduce over. Writing `~axis` for the dimensions that are *not* reduced,

`~axis = [i for i in range(arr.ndim) if i not in axis]`,

this function returns a `Tensor` of shape `[K] + arr.shape[~axis]`, where `arr.shape[~axis]` means the sizes of the dimensions in `~axis`, in order.

If `minlength` and `maxlength` are not given, `K = tf.reduce_max(arr) + 1` if `arr` is non-empty, and 0 otherwise.
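To make the shape rule concrete, here is a minimal plain-Python sketch of the counting semantics described above. It is an illustration under stated assumptions, not the TFP implementation: `count_integers_ref` and its helpers are hypothetical names, `arr` is assumed to be a rectangular nested list of non-negative integers, and `weights`, `minlength`, `maxlength`, and `dtype` are deliberately ignored.

```python
from itertools import product


def _shape(x):
    """Shape of a rectangular nested list, e.g. [[1, 2, 3]] -> [1, 3]."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return shape


def _get(x, idx):
    """Read the element of nested list x at multi-index idx."""
    for i in idx:
        x = x[i]
    return x


def _zeros(shape):
    """Nested list of integer zeros with the given shape."""
    if not shape:
        return 0
    return [_zeros(shape[1:]) for _ in range(shape[0])]


def _add(nested, idx, amount):
    """In-place nested[idx] += amount for a non-empty multi-index idx."""
    for i in idx[:-1]:
        nested = nested[i]
    nested[idx[-1]] += amount


def count_integers_ref(arr, axis=None):
    """Hypothetical reference: counts with shape [K] + arr.shape[~axis].

    Assumes non-negative integer entries; ignores minlength/maxlength/dtype.
    """
    shape = _shape(arr)
    ndim = len(shape)
    if axis is None:
        axis = list(range(ndim))          # reduce over every dimension
    elif isinstance(axis, int):
        axis = [axis]
    axis = [a % ndim for a in axis]       # normalize negative axes
    not_axis = [i for i in range(ndim) if i not in axis]  # "~axis"

    indices = list(product(*(range(n) for n in shape)))
    values = [_get(arr, idx) for idx in indices]
    k = max(values) + 1 if values else 0  # K when minlength/maxlength unset

    out = _zeros([k] + [shape[i] for i in not_axis])
    for idx, v in zip(indices, values):
        kept = tuple(idx[i] for i in not_axis)
        _add(out, (v,) + kept, 1)         # one more hit in bin v
    return out
```

For `arr = [[0, 1, 1], [2, 1, 0]]`, `K = 3`. With `axis=None` the sketch returns `[2, 3, 1]`; with `axis=[0]` it returns a `[3, 3]` result `[[1, 0, 1], [0, 2, 1], [1, 0, 0]]`, where column `j` holds the counts for `arr[:, j]`.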
If `weights` is not `None`, then index `i` of the output stores the sum of the values in `weights` at each index where the corresponding value in `arr` is `i`.
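The weighted case can be sketched in one dimension as follows. This is a hedged plain-Python illustration, not the library code: `weighted_bincount_ref` is a hypothetical name, and `arr` is assumed to be a flat list of non-negative integers with a `weights` list of the same length.

```python
def weighted_bincount_ref(arr, weights):
    """Hypothetical 1-D reference: out[i] = sum of weights[j] where arr[j] == i."""
    if len(arr) != len(weights):
        raise ValueError("arr and weights must have the same length")
    k = max(arr) + 1 if arr else 0  # K = max(arr) + 1, or 0 for empty arr
    out = [0.0] * k
    for value, weight in zip(arr, weights):
        out[value] += weight        # accumulate the weight into bin `value`
    return out
```

For example, `weighted_bincount_ref([1, 1, 3], [0.5, 2.0, 1.0])` returns `[0.0, 2.5, 0.0, 1.0]`: bin 1 collects `0.5 + 2.0`, and bins 0 and 2 receive nothing.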
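| Returns |
|---|
| A `Tensor` of bin values with the same dtype as `weights`, or the given `dtype` if `weights` is `None`. Its shape is `[K] + arr.shape[~axis]`, so it is a vector when every dimension is reduced (for example, with `axis=None`). |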