The params tensor may also have any shape. gather can select slices
along any axis, depending on the axis argument (which defaults to 0).
On a matrix, for instance, axis=0 gathers rows and axis=1 gathers columns.
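The row and column gathers described above can be sketched as follows. This is a minimal illustration assuming TensorFlow 2.x in eager mode; the matrix values and the variable names `rows` and `cols` are chosen here for readability.

```python
import tensorflow as tf

# A 4x3 matrix whose values encode their position: tens digit = row, units digit = column.
params = tf.constant([[0.0, 1.0, 2.0],
                      [10.0, 11.0, 12.0],
                      [20.0, 21.0, 22.0],
                      [30.0, 31.0, 32.0]])

# axis defaults to 0, so each index selects a whole row.
rows = tf.gather(params, indices=[3, 1])
print(rows.numpy())

# With axis=1, each index selects a whole column instead.
cols = tf.gather(params, indices=[2, 1], axis=1)
print(cols.numpy())
```

Note that the gathered slices appear in the order given by the indices, not in their original order in params.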
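More generally, the output shape is the shape of params with the indexed
axis replaced by the shape of indices. This is because each index takes a
slice from params and places it at the corresponding location in the output.
For example, given params of shape [5, 6, 7, 8], indices of shape [10, 11],
and result = tf.gather(params, indices, axis=2), which has shape
[5, 6, 10, 11, 8]:

    # For any location in indices
    a, b = 0, 1
    tf.reduce_all(
        # the corresponding slice of the result
        result[:, :, a, b, :] ==
        # is equal to the slice of `params` along `axis` at the index.
        params[:, :, indices[a, b], :]
    ).numpy()
    True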
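To make the shape rule concrete, here is a small self-contained sketch, assuming TensorFlow 2.x in eager mode. The helper `result_shape` is a plain-Python illustration of the rule for non-negative `axis` and `batch_dims=0`; it is not part of the TensorFlow API.

```python
import tensorflow as tf

def result_shape(p_shape, i_shape, axis=0):
    """Expected output shape of tf.gather for batch_dims=0.

    Illustrative helper: takes shapes as lists of ints and a non-negative axis.
    """
    return p_shape[:axis] + i_shape + p_shape[axis + 1:]

params = tf.random.normal(shape=(5, 6, 7, 8))
# Indices must lie in [0, params.shape[axis]), here [0, 7).
indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
result = tf.gather(params, indices, axis=2)

expected = result_shape([5, 6, 7, 8], [10, 11], axis=2)
print(expected, result.shape.as_list())

# Spot-check a few index locations: each slice of the result equals the
# slice of params taken along axis 2 at the corresponding index.
slices_match = all(
    bool(tf.reduce_all(result[:, :, a, b, :] == params[:, :, indices[a, b], :]))
    for a, b in [(0, 1), (3, 7), (9, 10)]
)
print(slices_match)
```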
Batching:
The batch_dims argument lets you gather different items from each element
of a batch.
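Using batch_dims=1 is equivalent to having an outer loop over the first
axis of params and indices. Higher values of batch_dims are equivalent to
multiple nested loops over the outer axes of params and indices.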
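The loop equivalence can be checked directly. The sketch below assumes TensorFlow 2.x in eager mode; `manually_batched_gather` and `batched_result_shape` are illustrative helpers written for this example, not TensorFlow functions.

```python
import tensorflow as tf

params = tf.constant([[0, 0, 1, 0, 2],
                      [3, 0, 0, 0, 4],
                      [0, 5, 0, 6, 0]])
indices = tf.constant([[2, 4],
                       [0, 4],
                       [1, 3]])

# Row k of indices is applied only to row k of params.
batched = tf.gather(params, indices, axis=1, batch_dims=1)

def manually_batched_gather(params, indices, axis):
    """Explicit outer loop over the single batch dimension (batch_dims=1)."""
    batch_dims = 1
    result = []
    for p, i in zip(params, indices):
        # Inside the loop the batch axis is gone, so the axis shifts down by one.
        result.append(tf.gather(p, i, axis=axis - batch_dims))
    return tf.stack(result)

looped = manually_batched_gather(params, indices, axis=1)

def batched_result_shape(p_shape, i_shape, axis=0, batch_dims=0):
    """Shape rule extended to batching: leading batch dims of indices are dropped."""
    return p_shape[:axis] + i_shape[batch_dims:] + p_shape[axis + 1:]

shape = batched_result_shape(params.shape.as_list(), indices.shape.as_list(),
                             axis=1, batch_dims=1)
print(batched.numpy(), looped.numpy(), shape)
```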
Batched gathering comes up naturally when you need to use the indices
returned by an operation like tf.argsort or tf.math.top_k, where the last
dimension of the indices indexes into the last dimension of the input, at the
corresponding location.
In this case you can use tf.gather(values, indices, batch_dims=-1).
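As an illustration of the batch_dims=-1 pattern, the following sketch (assuming TensorFlow 2.x in eager mode, with made-up input values) re-applies the indices from tf.argsort and tf.math.top_k to the tensor they were computed from.

```python
import tensorflow as tf

values = tf.constant([[3.0, 1.0, 2.0],
                      [9.0, 7.0, 8.0]])

# argsort returns, for each row, the positions that would sort that row.
order = tf.argsort(values, axis=-1)
# Applying those per-row indices with batch_dims=-1 reproduces a per-row sort.
sorted_values = tf.gather(values, order, batch_dims=-1)

# top_k returns both the values and the per-row indices of the k largest entries.
top_values, top_indices = tf.math.top_k(values, k=2)
# Gathering with the returned indices recovers the same values.
regathered = tf.gather(values, top_indices, batch_dims=-1)

print(sorted_values.numpy())
print(regathered.numpy())
```

The same pattern is useful for reordering a second tensor of the same shape, for example gathering labels with the indices obtained from sorting scores.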
See also:
- tf.Tensor.__getitem__: The direct tensor index operation (t[]), handles
  scalars and python-slices tensor[..., 7, 1:-1]
- tf.scatter: A collection of operations similar to __setitem__
  (t[i] = x)
- tf.gather_nd: An operation similar to tf.gather but gathers across
  multiple axes at once (it can gather elements of a matrix instead of rows
  or columns)
- tf.boolean_mask, tf.where: Binary indexing.
- tf.slice and tf.strided_slice: For lower level access to the
  implementation of __getitem__'s python-slice handling (t[1:-1:2])
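To show how tf.gather differs from its nearest relatives in the list above, here is a brief comparison on a small matrix. It is a sketch assuming TensorFlow 2.x in eager mode, with example values chosen for this illustration.

```python
import tensorflow as tf

matrix = tf.constant([[0, 1, 2],
                      [10, 11, 12],
                      [20, 21, 22]])

# Direct indexing handles scalars and Python slices.
element = matrix[2, 0]
sliced = matrix[1:, 0]

# tf.gather takes a tensor of indices along one axis: here, whole rows.
rows = tf.gather(matrix, [2, 0])

# tf.gather_nd indexes several axes at once: each [row, col] pair picks one element.
elements = tf.gather_nd(matrix, [[2, 0], [0, 1]])

print(element.numpy(), sliced.numpy(), rows.numpy(), elements.numpy())
```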
Args
params
The Tensor from which to gather values. Must be at least rank
axis + 1.
indices
The index Tensor. Must be one of the following types: int32,
int64. The values must be in range [0, params.shape[axis]).
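validate_indices
Deprecated, does nothing. Indices are always validated on
CPU, never validated on GPU.
Caution: On CPU, if an out of bound index is found, an error is raised.
On GPU, if an out of bound index is found, a 0 is stored in the
corresponding output value.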
axis
A Tensor. Must be one of the following types: int32, int64. The
axis in params to gather indices from. Must be greater than or equal
to batch_dims. Defaults to the first non-batch dimension. Supports
negative indexes.
batch_dims
An integer. The number of batch dimensions. Must be less
than or equal to rank(indices).
name
A name for the operation (optional).
Returns
A Tensor. Has the same type as params.
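The index range requirement and the negative axis support described in the arguments can be sketched as follows. This assumes TensorFlow 2.x in eager mode; the op is pinned to the CPU so that the out-of-range index is validated, and the example values and the `raised` flag are illustrative.

```python
import tensorflow as tf

params = tf.constant([10, 11, 12, 13, 14, 15])

# Valid indices lie in [0, params.shape[axis]), here [0, 6).
# On CPU, an out-of-range index raises an error; on GPU it would not be checked.
raised = False
with tf.device('/CPU:0'):
    try:
        tf.gather(params, [2, 6])
    except tf.errors.InvalidArgumentError:
        raised = True
print(raised)

# A negative axis counts from the last dimension: axis=-1 on a matrix means axis=1.
matrix = tf.constant([[0, 1, 2],
                      [10, 11, 12]])
last_axis = tf.gather(matrix, [2, 0], axis=-1)
same_axis = tf.gather(matrix, [2, 0], axis=1)
print(last_axis.numpy())
```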
Last updated 2023-10-06 UTC.