Bin values into discrete intervals.
```python
tfp.substrates.jax.stats.find_bins(
    x,
    edges,
    extend_lower_interval=False,
    extend_upper_interval=False,
    dtype=None,
    name=None
)
```
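Given `edges = [c0, ..., cK]`, defining intervals
`I_0 = [c0, c1)`, `I_1 = [c1, c2)`, ..., `I_{K-1} = [c_{K-1}, cK]`,
this function returns `bins` such that `x[i]` lies within `I_{bins[i]}`. Every interval is closed on the left and open on the right, except the last one, which also includes its right edge `cK`.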
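To make the interval rule concrete, here is a minimal NumPy sketch of the same binning for 1-D `edges`. It is an illustration under stated assumptions, not the TFP implementation: the name `find_bins_sketch` is hypothetical, it always returns floats, it reports values outside `[c0, cK]` as `NaN` (mirroring the `nan` in the 1-D example), and it assumes, from the parameter names, that the `extend_*_interval` flags stretch the first and last intervals to cover values beyond `c0` and `cK`.

```python
import numpy as np


def find_bins_sketch(x, edges, extend_lower_interval=False,
                     extend_upper_interval=False):
    """Hypothetical NumPy sketch of the binning rule for 1-D `edges` (not TFP code)."""
    x = np.asarray(x, dtype=np.float64)
    edges = np.asarray(edges, dtype=np.float64)
    if edges.shape[0] < 2:
        raise ValueError("edges must contain at least 2 values.")
    num_bins = edges.shape[0] - 1
    # side="right" makes each interval closed on the left: x == c_k maps to bin k.
    bins = np.searchsorted(edges, x, side="right") - 1
    # Clipping folds x == cK into the last bin I_{K-1} (closed on the right) and
    # sends values below c0 / above cK to the first / last bin.
    bins = np.clip(bins, 0, num_bins - 1).astype(np.float64)
    # Assumption: outside [c0, cK], report NaN unless that side is extended.
    if not extend_lower_interval:
        bins = np.where(x < edges[0], np.nan, bins)
    if not extend_upper_interval:
        bins = np.where(x > edges[-1], np.nan, bins)
    return bins


# Same inputs as the 1-D example: yields 0, 1, 1, 1, nan.
print(find_bins_sketch([0., 5., 6., 10., 20.], [0., 5., 10.]))
```

`np.searchsorted(..., side="right")` is what makes each interval closed on the left; the clip is what makes the last one closed on the right.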
| Returns | |
|---|---|
| `bins` | `Tensor` with the same shape as `x` and dtype `dtype`. Has whole-number values: `bins[i] = k` means that `x[i]` falls into the `k`th bin, i.e., `edges[k] <= x[i] < edges[k + 1]` (or `x[i] <= edges[k + 1]` when `k` is the last bin). |
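The contract in the Returns table, including the closed right end of the last bin, can be checked mechanically. The sketch below is a hypothetical NumPy helper (`bins_satisfy_contract` is not part of TFP) for 1-D `x` and `edges`; it skips `NaN` entries, assuming they mark values outside `[c0, cK]` as in the 1-D example.

```python
import numpy as np


def bins_satisfy_contract(x, edges, bins):
    """Hypothetical check: every non-NaN bins[i] = k obeys the documented rule."""
    x = np.asarray(x, dtype=np.float64)
    edges = np.asarray(edges, dtype=np.float64)
    bins = np.asarray(bins, dtype=np.float64)
    last = edges.shape[0] - 2  # index of the last bin, K - 1
    for xi, b in zip(x, bins):
        if np.isnan(b):
            continue  # assumed out-of-range marker: no bin to check
        k = int(b)
        if b != k or not 0 <= k <= last:
            return False  # bins must be whole numbers in [0, K - 1]
        lower_ok = edges[k] <= xi
        # The last interval I_{K-1} is closed on the right, so x == cK is allowed there.
        upper_ok = xi <= edges[k + 1] if k == last else xi < edges[k + 1]
        if not (lower_ok and upper_ok):
            return False
    return True


x = [0., 5., 6., 10., 20.]
edges = [0., 5., 10.]
print(bins_satisfy_contract(x, edges, [0., 1., 1., 1., np.nan]))  # True
print(bins_satisfy_contract(x, edges, [0., 0., 1., 1., np.nan]))  # False: 5. is not < 5.
```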
| Raises | |
|---|---|
| `ValueError` | If `edges.shape[0]` is determined to be less than 2. |
Examples
Cut a 1-D array
```python
import tensorflow_probability as tfp

x = [0., 5., 6., 10., 20.]
edges = [0., 5., 10.]
tfp.stats.find_bins(x, edges)
# ==> [0., 1., 1., 1., nan]
```
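Cut `x` into its deciles
```python
import tensorflow as tf

# Stateless random ops require an explicit seed.
x = tf.random.stateless_uniform(shape=(100, 200), seed=(0, 42))
decile_edges = tfp.stats.quantiles(x, num_quantiles=10)
bins = tfp.stats.find_bins(x, edges=decile_edges)
bins.shape
# ==> (100, 200)
# Cast the boolean masks to float before averaging.
tf.reduce_mean(tf.cast(bins == 0., tf.float32))
# ==> approximately 0.1
tf.reduce_mean(tf.cast(bins == 1., tf.float32))
# ==> approximately 0.1
```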
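For readers without TensorFlow at hand, the decile example can be reproduced approximately in plain NumPy. This is a sketch, not TFP: `np.quantile` with NumPy's default linear interpolation stands in for `tfp.stats.quantiles`, whose default interpolation may differ, and the binning is inlined with `np.searchsorted`. Because every sample lies between the first and last decile edges, no out-of-range handling is needed.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(size=(100, 200))

# 11 edges delimit 10 decile bins, from the minimum to the maximum of x.
decile_edges = np.quantile(x, np.linspace(0.0, 1.0, 11))

# Half-open bins [c_k, c_{k+1}); the clip folds x == max(x) into the last bin.
bins = np.clip(np.searchsorted(decile_edges, x, side="right") - 1,
               0, len(decile_edges) - 2)

print(bins.shape)  # (100, 200)
fractions = np.array([np.mean(bins == k) for k in range(10)])
print(fractions)  # each entry is close to 0.1
```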