Returns the indices of non-zero elements, or multiplexes x and y.
    tf.where(
        condition, x=None, y=None, name=None
    )
This operation has two modes:
- Return the indices of non-zero elements - When only condition is provided the result is an int64 tensor where each row is the index of a non-zero element of condition. The result's shape is [tf.math.count_nonzero(condition), tf.rank(condition)].
- Multiplex x and y - When both x and y are provided the result has the shape of x, y, and condition broadcast together. The result is taken from x where condition is non-zero or y where condition is zero.
1. Return the indices of non-zero elements
If x and y are not provided (both are None):
tf.where will return the indices of condition that are non-zero,
in the form of a 2-D tensor with shape [n, d], where n is the number of
non-zero elements in condition (tf.math.count_nonzero(condition)), and d is
the number of axes of condition (tf.rank(condition)).
Indices are output in row-major order. The condition can have a dtype of
tf.bool, or any numeric dtype.
Here condition is a 1-axis bool tensor with 2 True values. The result
has a shape of [2, 1].
    >>> tf.where([True, False, False, True]).numpy()
    array([[0],
           [3]])
Here condition is a 2-axis integer tensor, with 3 non-zero values. The
result has a shape of [3, 2].
    >>> tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
    array([[0, 0],
           [1, 0],
           [1, 2]])
Here condition is a 3-axis float tensor, with 5 non-zero values. The output
shape is [5, 3].
    >>> float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
    ...                 [[0, 0], [0, 0], [99, 0]]]
    >>> tf.where(float_tensor).numpy()
    array([[0, 0, 0],
           [0, 1, 1],
           [0, 2, 0],
           [0, 2, 1],
           [1, 2, 0]])
These are the same indices that tf.sparse.SparseTensor would use to
represent the condition tensor:
    >>> sparse = tf.sparse.from_dense(float_tensor)
    >>> sparse.indices.numpy()
    array([[0, 0, 0],
           [0, 1, 1],
           [0, 2, 0],
           [0, 2, 1],
           [1, 2, 0]])
A complex number is considered non-zero if either the real or imaginary component is non-zero:
    >>> tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
    array([[1],
           [2],
           [3]])
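The index mode described in this section can be pictured with a small plain-Python sketch. It is only an illustration of the documented behavior on nested lists, not how tf.where is implemented; the helper names shape_of and nonzero_indices are hypothetical, and the input is assumed to be rectangular.

```python
from itertools import product


def shape_of(nested):
    """Return the shape of a rectangular nested list (hypothetical helper)."""
    shape = []
    while isinstance(nested, (list, tuple)):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return tuple(shape)


def nonzero_indices(nested):
    """List the index of every non-zero element, in row-major order."""
    shape = shape_of(nested)
    indices = []
    # product() varies the last axis fastest, which is row-major order.
    for index in product(*(range(n) for n in shape)):
        value = nested
        for i in index:
            value = value[i]
        # True counts as non-zero; a complex value is non-zero if either
        # its real or imaginary component is non-zero.
        if value != 0:
            indices.append(list(index))
    return indices


print(nonzero_indices([True, False, False, True]))
print(nonzero_indices([[1, 0, 0], [1, 0, 1]]))
```

Each inner list plays the role of one row of the [n, d] result: n rows, one per non-zero element, each holding d axis positions.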
2. Multiplex x and y
If x and y are also provided (both have non-None values) the condition
tensor acts as a mask that chooses whether the corresponding
element / row in the output should be taken from x (if the element in
condition is True) or y (if it is False).
The shape of the result is formed by broadcasting together the shapes of
condition, x, and y.
When all three inputs have the same size, each is handled element-wise.
    >>> tf.where([True, False, False, True],
    ...          [1, 2, 3, 4],
    ...          [100, 200, 300, 400]).numpy()
    array([  1, 200, 300,   4], dtype=int32)
There are two main rules for broadcasting:
- If a tensor has fewer axes than the others, length-1 axes are added to the left of the shape.
- Axes with length-1 are stretched to match the corresponding axes of the other tensors.
A length-1 vector is stretched to match the other vectors:
    >>> tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
    array([  1, 100, 100,   4], dtype=int32)
A scalar is expanded to match the other arguments:
    >>> tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
    array([[  1, 100],
           [100,   4]], dtype=int32)
    >>> tf.where([[True, False], [False, True]], 1, 100).numpy()
    array([[  1, 100],
           [100,   1]], dtype=int32)
A scalar condition returns the complete x or y tensor, with
broadcasting applied.
    >>> tf.where(True, [1, 2, 3, 4], 100).numpy()
    array([1, 2, 3, 4], dtype=int32)
    >>> tf.where(False, [1, 2, 3, 4], 100).numpy()
    array([100, 100, 100, 100], dtype=int32)
For a non-trivial example of broadcasting, here condition has a shape of
[3], x has a shape of [3,3], and y has a shape of [3,1].
Broadcasting first expands the shape of condition to [1,3]. The final
broadcast shape is [3,3]. condition will select columns from x and y.
Since y only has one column, all columns from y will be identical.
    >>> tf.where([True, False, True],
    ...          x=[[1, 2, 3],
    ...             [4, 5, 6],
    ...             [7, 8, 9]],
    ...          y=[[100],
    ...             [200],
    ...             [300]]).numpy()
    array([[  1, 100,   3],
           [  4, 200,   6],
           [  7, 300,   9]], dtype=int32)
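The two broadcasting rules listed earlier can be sketched as a small shape calculation in plain Python. This is a rough illustration only; broadcast_shape is a hypothetical helper written for this page, not a TensorFlow function.

```python
def broadcast_shape(*shapes):
    """Combine shapes using the two broadcasting rules (illustrative helper)."""
    rank = max(len(shape) for shape in shapes)
    # Rule 1: shapes with fewer axes get length-1 axes added on the left.
    padded = [(1,) * (rank - len(shape)) + tuple(shape) for shape in shapes]
    result = []
    for axis_lengths in zip(*padded):
        # Rule 2: length-1 axes stretch to match the other tensors, so at
        # most one distinct length other than 1 may appear on each axis.
        non_one = {n for n in axis_lengths if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)


# condition has shape [3], x has shape [3, 3], and y has shape [3, 1].
print(broadcast_shape((3,), (3, 3), (3, 1)))
```

For the shapes in the example above, condition is first padded to [1, 3], and the combined shape comes out as [3, 3].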
Note that if the gradient of either branch of the tf.where generates
a NaN, then the gradient of the entire tf.where will be NaN. This is
because the gradient calculation for tf.where combines the two branches, for
performance reasons.
A workaround is to use an inner tf.where to ensure the function has
no asymptote, and to avoid computing a value whose gradient is NaN by
replacing dangerous inputs with safe inputs.
Instead of this,
    >>> x = tf.constant(0., dtype=tf.float32)
    >>> with tf.GradientTape() as tape:
    ...   tape.watch(x)
    ...   y = tf.where(x < 1., 0., 1. / x)
    >>> print(tape.gradient(y, x))
    tf.Tensor(nan, shape=(), dtype=float32)
Although the 1. / x values are never used, their gradient is NaN when
x = 0. Instead, guard the division with another tf.where:
    >>> x = tf.constant(0., dtype=tf.float32)
    >>> with tf.GradientTape() as tape:
    ...   tape.watch(x)
    ...   safe_x = tf.where(tf.equal(x, 0.), 1., x)
    ...   y = tf.where(x < 1., 0., 1. / safe_x)
    >>> print(tape.gradient(y, x))
    tf.Tensor(0.0, shape=(), dtype=float32)
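One way to picture why the guard helps is the floating-point arithmetic below. It is a plain-Python sketch, not TensorFlow code, and it rests on an assumption this page does not spell out: the unselected 1. / x branch receives an upstream gradient of zero, which is then multiplied by that branch's local derivative.

```python
import math

x = 0.0
upstream = 0.0  # assumed gradient reaching the unselected 1 / x branch

# Unguarded: the local derivative of 1 / x is -1 / x**2, which is -inf at
# x = 0 (written as -math.inf because Python raises on division by zero).
unsafe_local = -math.inf
unsafe_grad = upstream * unsafe_local  # zero times infinity is NaN

# Guarded: replace the dangerous input with a safe one before dividing.
safe_x = 1.0 if x == 0.0 else x
safe_local = -1.0 / safe_x ** 2
safe_grad = upstream * safe_local  # zero times a finite value stays zero

print(math.isnan(unsafe_grad))
print(safe_grad == 0.0)
```

Under that assumption, the zero does not cancel an infinite derivative, but it does cancel any finite one, which is why substituting a safe input is enough.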
See also:
- tf.sparse - The indices returned by the first form of tf.where can be useful in tf.sparse.SparseTensor objects.
- tf.gather_nd, tf.scatter_nd, and related ops - Given the list of indices returned from tf.where, the scatter and gather family of ops can be used to fetch values or insert values at those indices.
- tf.strings.length - tf.string is not an allowed dtype for the condition. Use the string length instead.
| Args | |
|---|---|
| condition | A tf.Tensor of dtype bool, or any numeric dtype. condition must have dtype bool when x and y are provided. |
| x | If provided, a Tensor which is of the same type as y, and has a shape broadcastable with condition and y. |
| y | If provided, a Tensor which is of the same type as x, and has a shape broadcastable with condition and x. |
| name | A name of the operation (optional). |
| Returns | |
|---|---|
| If x and y are provided: A Tensor with the same type as x and y, and shape that is broadcast from condition, x, and y. Otherwise, a Tensor with shape [tf.math.count_nonzero(condition), tf.rank(condition)]. |
| Raises | |
|---|---|
| ValueError | When exactly one of x or y is non-None, or the shapes are not all broadcastable. |