tfp.experimental.distributions.marginal_fns.ps.where(condition, x=None, y=None, name=None)
Returns the indices of non-zero elements, or multiplexes x and y. This operation has two modes:
1. Return the indices of non-zero elements: when only condition is provided, the result is an int64 tensor where each row is the index of a non-zero element of condition. The result's shape is [tf.math.count_nonzero(condition), tf.rank(condition)].
2. Multiplex x and y: when both x and y are provided, the result has the shape of x, y, and condition broadcast together. Each element of the result is taken from x where condition is non-zero, or from y where condition is zero.
1. Return the indices of non-zero elements
If x and y are not provided (both are None):
tf.where returns the indices of the non-zero elements of condition as a
2-D tensor with shape [n, d], where n is the number of non-zero elements
in condition (tf.math.count_nonzero(condition)) and d is the number of
axes of condition (tf.rank(condition)).
Indices are output in row-major order. In this mode, condition can have a
dtype of tf.bool or any numeric dtype. A complex number is considered
non-zero if either its real or its imaginary component is non-zero.
For example, if condition is the 1-axis bool tensor [True, False, False, True],
which has 2 True values, the result has a shape of [2, 1] and holds the
indices [[0], [3]].
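As a runnable illustration that does not require TensorFlow, the sketch below uses NumPy's np.argwhere as a stand-in (an analogy, not the tf.where API). Like the index mode described above, it returns one row per non-zero element, in row-major order, with shape [n, d]; note that NumPy returns its platform integer type rather than a guaranteed int64.

```python
import numpy as np

# 1-axis bool condition with 2 True values -> shape [2, 1].
idx_1d = np.argwhere(np.array([True, False, False, True]))
print(idx_1d.tolist())  # [[0], [3]]

# 2-axis integer condition with 3 non-zero values -> shape [3, 2].
idx_2d = np.argwhere(np.array([[1, 0, 0], [1, 0, 1]]))
print(idx_2d.tolist())  # [[0, 0], [1, 0], [1, 2]]

# Complex values count as non-zero if either component is non-zero.
idx_c = np.argwhere(np.array([0j, 1 + 0j, 1j, 1 + 1j]))
print(idx_c.tolist())  # [[1], [2], [3]]

# 3-axis float condition with 5 non-zero values -> shape [5, 3],
# i.e. [count_nonzero(cond), rank(cond)], rows in row-major order.
cond = np.array([[[0.1, 0], [0, 2.2], [3.5, 1e6]],
                 [[0, 0], [0, 0], [99, 0]]])
idx_3d = np.argwhere(cond)
print(idx_3d.shape)  # (5, 3)
```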
2. Multiplex x and y
If x and y are both provided (both are non-None), condition must have dtype
bool. It acts as a mask that chooses whether each element (or row) of the
output is taken from x (where condition is True) or from y (where it is False).
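The shape of the result is formed by broadcasting the shapes of condition,
x, and y together, following NumPy's broadcasting rules: if a tensor has
fewer axes than the others, length-1 axes are added to the left of its
shape, and axes of length 1 are stretched to match the corresponding axes
of the other tensors.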
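NumPy's np.broadcast_shapes (available in NumPy 1.20 and later) computes a broadcast shape by left-padding shorter shapes with length-1 axes and then stretching length-1 axes. The sketch below uses it to check result shapes for a few input combinations; it illustrates the rule only and is not part of the TensorFlow API.

```python
import numpy as np

# A length-1 axis is stretched: (4,), (4,), (1,) -> (4,).
s1 = np.broadcast_shapes((4,), (4,), (1,))
print(s1)  # (4,)

# A scalar (shape ()) is padded with length-1 axes, then stretched.
s2 = np.broadcast_shapes((2, 2), (2, 2), ())
print(s2)  # (2, 2)

# Fewer axes are padded on the left: (2,) is treated as (1, 2).
s3 = np.broadcast_shapes((2,), (3, 1))
print(s3)  # (3, 2)

# Lengths that differ and are both greater than 1 cannot be broadcast.
mismatch_raised = False
try:
    np.broadcast_shapes((3,), (4,))
except ValueError:
    mismatch_raised = True
print(mismatch_raised)  # True
```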
When all three inputs have the same shape, they are combined element-wise.
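The element-wise and broadcasting cases can be sketched with NumPy's three-argument np.where, which selects from x or y under the same broadcasting rules (used here as a stand-in so the example runs without TensorFlow; the values are taken from the original examples for tf.where).

```python
import numpy as np

cond = np.array([True, False, False, True])

# Same-shape inputs are combined element-wise.
same = np.where(cond, [1, 2, 3, 4], [100, 200, 300, 400])
print(same.tolist())  # [1, 200, 300, 4]

# A length-1 y is stretched to match the other inputs.
stretched = np.where(cond, [1, 2, 3, 4], [100])
print(stretched.tolist())  # [1, 100, 100, 4]

# A scalar condition returns all of x or all of y, with broadcasting applied.
all_x = np.where(True, [1, 2, 3, 4], 100)
all_y = np.where(False, [1, 2, 3, 4], 100)
print(all_x.tolist(), all_y.tolist())  # [1, 2, 3, 4] [100, 100, 100, 100]
```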
As a non-trivial example of broadcasting, suppose condition has shape [3]
(for example [True, False, True]), x has shape [3, 3], and y has shape [3, 1].
Broadcasting first expands the shape of condition to [1, 3], so the final
broadcast shape is [3, 3]. condition therefore selects whole columns from x
or y, and since y has only one column, every column taken from y is identical.
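The column-selecting example above can be reproduced with NumPy's np.where, which follows the same broadcasting rules (a stand-in for tf.where so the snippet runs without TensorFlow):

```python
import numpy as np

cond = np.array([True, False, True])              # shape [3] -> treated as [1, 3]
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])   # shape [3, 3]
y = np.array([[100], [200], [300]])               # shape [3, 1] -> stretched to [3, 3]

# Columns 0 and 2 come from x; column 1 comes from the single column of y.
out = np.where(cond, x, y)
print(out.tolist())  # [[1, 100, 3], [4, 200, 6], [7, 300, 9]]
```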
Note that if the gradient of either branch of a tf.where is NaN, the
gradient of the entire tf.where is NaN, even for elements where that branch
is not selected. This is because, for performance reasons, the gradient
calculation for tf.where combines the two branches.
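A minimal NumPy sketch of why one unused branch can poison the gradient. It assumes (for illustration only; this is not TensorFlow's gradient code) that the combined gradient masks the upstream gradient for each branch and then multiplies it by that branch's local derivative. The helper where_grad is hypothetical. For y = where(x < 1, 0, 1 / x) at x = 0, the unused branch contributes 0 * -inf, which is NaN.

```python
import numpy as np

def where_grad(cond, upstream, d_a_dx, d_b_dx):
    """Hypothetical gradient of where(cond, a(x), b(x)) w.r.t. x.

    Both branches are always included: each gets the upstream gradient
    where it is selected (0 elsewhere), times its local derivative.
    """
    grad_a = np.where(cond, upstream, 0.0) * d_a_dx
    grad_b = np.where(cond, 0.0, upstream) * d_b_dx
    return grad_a + grad_b

x = np.float32(0.0)
with np.errstate(divide="ignore", invalid="ignore"):
    # y = where(x < 1, 0, 1 / x): a(x) = 0 and b(x) = 1 / x.
    d_a = 0.0
    d_b = -1.0 / (x * x)  # d(1/x)/dx = -1/x**2, which is -inf at x = 0
    g = where_grad(x < 1.0, 1.0, d_a, d_b)
print(g)  # nan, because the masked-out branch gives 0 * -inf
```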
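A workaround is to use an inner tf.where to ensure the function has no
asymptote: replace dangerous inputs with safe inputs so that no value whose
gradient is NaN is ever computed. For example, instead of
y = tf.where(x < 1., 0., 1. / x), whose gradient is NaN at x = 0 even though
the 1. / x values are never used there, first compute
safe_x = tf.where(tf.equal(x, 0.), 1., x) and then use
y = tf.where(x < 1., 0., 1. / safe_x); the gradient at x = 0 is then 0.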
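Continuing the same hand-written NumPy model of the combined gradient (an illustrative assumption, not TensorFlow internals; where_grad is a hypothetical helper), the inner where replaces x = 0 with the safe value 1, so the derivative of the unused 1 / safe_x branch stays finite and the total gradient is 0 instead of NaN.

```python
import numpy as np

def where_grad(cond, upstream, d_a_dx, d_b_dx):
    """Hypothetical gradient of where(cond, a(x), b(x)) w.r.t. x."""
    grad_a = np.where(cond, upstream, 0.0) * d_a_dx
    grad_b = np.where(cond, 0.0, upstream) * d_b_dx
    return grad_a + grad_b

x = np.float32(0.0)

# Inner where: replace the dangerous input 0 with the safe value 1.
safe_x = np.where(x == 0.0, np.float32(1.0), x)
d_safe_dx = np.where(x == 0.0, 0.0, 1.0)  # the constant branch has derivative 0

# Outer where: y = where(x < 1, 0, 1 / safe_x).
# Chain rule through safe_x: d(1/safe_x)/dx = -1/safe_x**2 * d(safe_x)/dx (finite).
d_b = (-1.0 / (safe_x * safe_x)) * d_safe_dx
g = where_grad(x < 1.0, 1.0, 0.0, d_b)
print(g)  # 0.0
```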
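See also:
- tf.sparse: the indices returned by the first form of tf.where can be used in tf.sparse.SparseTensor objects.
- tf.gather_nd, tf.scatter_nd, and related ops: given the list of indices returned from tf.where, the scatter and gather family of ops can be used to fetch values from, or insert values at, those indices.
- tf.strings.length: tf.string is not an allowed dtype for condition; use the string length instead.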
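To show how an [n, d] index array drives gather- and scatter-style operations, the sketch below uses NumPy advanced indexing as an analogy for tf.gather_nd and tf.scatter_nd (it does not call those ops). Transposing the indices gives one index array per axis, which NumPy accepts as a tuple.

```python
import numpy as np

t = np.array([[10, 0, 0], [20, 0, 30]])
indices = np.argwhere(t != 0)  # like tf.where(condition): shape [n, rank]

# Gather (analogous to tf.gather_nd): one value per index row.
values = t[tuple(indices.T)]
print(values.tolist())  # [10, 20, 30]

# Scatter (analogous to tf.scatter_nd): write the values into a zeros array.
rebuilt = np.zeros_like(t)
rebuilt[tuple(indices.T)] = values
print(np.array_equal(rebuilt, t))  # True
```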
Args
condition
A tf.Tensor of dtype bool, or any numeric dtype. condition must have dtype
bool when x and y are provided.
x
If provided, a Tensor which is of the same type as y, and has a shape
broadcastable with condition and y.
y
If provided, a Tensor which is of the same type as x, and has a shape
broadcastable with condition and x.
name
A name for the operation (optional).
Returns
If x and y are provided: a Tensor with the same type as x and y, whose
shape is broadcast from the shapes of condition, x, and y.
Otherwise: an int64 Tensor with shape [tf.math.count_nonzero(condition),
tf.rank(condition)].
Raises
ValueError
When exactly one of x or y is non-None, or the shapes
are not all broadcastable.
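The Returns and Raises contract can be summarized in a small, hypothetical NumPy stand-in named where_sketch (not the real ps.where or tf.where). It raises ValueError when exactly one of x or y is given or when the shapes do not broadcast, returns int64 indices in index mode, and otherwise multiplexes. It does not enforce the bool dtype requirement for condition, since the error raised for that case is not documented here.

```python
import numpy as np

def where_sketch(condition, x=None, y=None):
    """Hypothetical NumPy stand-in mirroring the documented contract."""
    if (x is None) != (y is None):
        # Exactly one of x or y is non-None.
        raise ValueError("x and y must both be provided or both be None")
    condition = np.asarray(condition)
    if x is None:
        # Index mode: int64 indices with shape [count_nonzero, rank].
        return np.argwhere(condition).astype(np.int64)
    # Multiplex mode: raises ValueError if the shapes are not broadcastable.
    np.broadcast_shapes(condition.shape, np.shape(x), np.shape(y))
    return np.where(condition, x, y)

print(where_sketch([True, False, True]).tolist())      # [[0], [2]]
print(where_sketch([True, False], [1, 2], [10, 20]))   # [ 1 20]
```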
Last updated 2023-05-09 UTC.