tf.compat.v1.gather

Gather slices from params along the axis given by axis, according to indices. (deprecated arguments)

indices must be an integer tensor of any dimension (often 1-D).

Tensor.__getitem__ works for scalars, tf.newaxis, and Python slices.

tf.gather extends indexing to handle tensors of indices.

In the simplest case it's identical to scalar indexing:

>>> import tensorflow as tf
>>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
>>> params[3].numpy()
b'p3'
>>> tf.gather(params, 3).numpy()
b'p3'

The most common case is to pass a single-axis tensor of indices. This can't be expressed as a Python slice because the indices are not sequential:

>>> indices = [2, 0, 2, 5]
>>> tf.gather(params, indices).numpy()
array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
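This case can be modeled as a list comprehension over scalar indexing. The sketch below is plain Python, not TensorFlow code, and the helper `gather_1d` is hypothetical. It also shows why a slice cannot do the same job: a slice steps through evenly spaced positions, while the indices above repeat and go backwards.

```python
# Hypothetical plain-Python model of gather on a 1-D params with 1-D indices.
def gather_1d(params, indices):
    # Each index selects one element; order and repeats are preserved.
    return [params[i] for i in indices]


params = ['p0', 'p1', 'p2', 'p3', 'p4', 'p5']
print(gather_1d(params, [2, 0, 2, 5]))  # ['p2', 'p0', 'p2', 'p5']

# A slice can only walk evenly spaced positions in one direction,
# so it cannot produce the repeated, out-of-order selection above.
print(params[0:6:2])  # ['p0', 'p2', 'p4']
```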

The indices can have any shape. When params has 1 axis, the output shape is equal to the shape of indices:

>>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
array([[b'p2', b'p0'],
       [b'p2', b'p5']], dtype=object)
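The same rule can be written in plain Python, assuming params is an ordinary 1-D list and indices is either a nested list or a single int. Recursing over indices makes the output copy their nesting, which is exactly the shape of indices. The helper `gather_nested` is hypothetical and only models the shape behavior, not tf.gather's implementation.

```python
# Hypothetical plain-Python model: gather from a 1-D params with indices of any nesting.
def gather_nested(params, indices):
    if isinstance(indices, list):
        # Recurse so the output has exactly the nesting (shape) of indices.
        return [gather_nested(params, i) for i in indices]
    # A scalar index picks a single element.
    return params[indices]


params = ['p0', 'p1', 'p2', 'p3', 'p4', 'p5']
print(gather_nested(params, [[2, 0], [2, 5]]))  # [['p2', 'p0'], ['p2', 'p5']]
print(gather_nested(params, 3))                 # p3
```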

params may also have any shape. gather can select slices along any axis, chosen by the axis argument (which defaults to 0). Below, it is used to gather first rows, then columns, from a matrix:

>>> params = tf.constant([[0, 1.0, 2.0],
...                       [10.0, 11.0, 12.0],
...                       [20.0, 21.0, 22.0],
...                       [30.0, 31.0, 32.0]])
>>> tf.gather(params, indices=[3,1]).numpy()
array([[30., 31., 32.],
       [10., 11., 12.]], dtype=float32)
>>> tf.gather(params, indices=[2,1], axis=1).numpy()
array([[ 2.,  1.],
       [12., 11.],
       [22., 21.],
       [32., 31.]], dtype=float32)
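The two calls above can be mirrored with nested Python lists, which may make the axis argument easier to read. With axis=0, each index picks a whole row. With axis=1, each index picks the same position out of every row, which gives a column. This is an illustrative sketch with plain lists, not how tf.gather is implemented.

```python
# Plain-Python sketch of gathering rows (axis=0) and columns (axis=1) of a matrix.
matrix = [[0.0, 1.0, 2.0],
          [10.0, 11.0, 12.0],
          [20.0, 21.0, 22.0],
          [30.0, 31.0, 32.0]]

# axis=0: each index picks a whole row.
rows = [matrix[i] for i in [3, 1]]

# axis=1: each index picks one entry from every row, i.e. a column.
cols = [[row[j] for j in [2, 1]] for row in matrix]

print(rows)  # [[30.0, 31.0, 32.0], [10.0, 11.0, 12.0]]
print(cols)  # [[2.0, 1.0], [12.0, 11.0], [22.0, 21.0], [32.0, 31.0]]
```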

More generally, the output has the same shape as params, with the indexed axis replaced by the shape of indices.

>>> def result_shape(p_shape, i_shape, axis=0):
...   return p_shape[:axis] + i_shape + p_shape[axis+1:]
>>> result_shape([1, 2, 3], [], axis=1)
[1, 3]
>>> result_shape([1, 2, 3], [7], axis=1)
[1, 7, 3]
>>> result_shape([1, 2, 3], [7, 5], axis=1)
[1, 7, 5, 3]
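The shape rule can be checked against a small reference model. The sketch below assumes nested Python lists stand in for tensors. `gather_ref` and `shape_of` are hypothetical helpers, and `result_shape` is repeated so the block runs on its own. `gather_ref` keeps the axes before `axis`, replaces that axis with the nesting of indices, and carries the trailing axes along unchanged.

```python
# Hypothetical plain-Python reference for gather on nested lists (no batch_dims).
def gather_ref(params, indices, axis=0):
    if axis > 0:
        # Leading axes are kept: recurse into every element.
        return [gather_ref(p, indices, axis - 1) for p in params]
    if isinstance(indices, list):
        # The gathered axis is replaced by the nesting of indices.
        return [gather_ref(params, i, 0) for i in indices]
    # A scalar index selects one slice; trailing axes come along unchanged.
    return params[indices]


def shape_of(x):
    # Shape of a rectangular nested list; non-list entries are treated as scalars.
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return shape


def result_shape(p_shape, i_shape, axis=0):
    return p_shape[:axis] + i_shape + p_shape[axis+1:]


# A 1x2x3 "tensor" whose entries record their own position as tuples.
params = [[[(a, b, c) for c in range(3)] for b in range(2)] for a in range(1)]

for indices in [1, [0], [[1, 0], [0, 0]]]:
    out = gather_ref(params, indices, axis=1)
    print(shape_of(out), result_shape([1, 2, 3], shape_of(indices), axis=1))
# [1, 3] [1, 3]
# [1, 1, 3] [1, 1, 3]
# [1, 2, 2, 3] [1, 2, 2, 3]
```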

Here are some examples:

>>> params.shape.as_list()
[4, 3]
>>> indices = tf.constant([[0, 2]])
>>> tf.gather(params, indices=indices, axis=0).shape.as_list()
[1, 2, 3]
>>> tf.gather(params, indices=indices, axis=1).shape.as_list()
[4, 1, 2]
>>> params = tf.random.normal(shape=(5, 6, 7, 8))
>>> indices = tf.random.uniform(shape=(10, 11), maxval=7, dtype=tf.int32)
>>> result = tf.gather(params, indices, axis=2)
>>> result.shape.as_list()
[5, 6, 10, 11, 8]

This is because each index takes a slice from params and places it at the corresponding location in the output. For the example above:

>>> # For any location in indices
>>> a, b = 0, 1
>>> tf.reduce_all(
...     # the corresponding slice of the result
...     result[:, :, a, b, :] ==
...     # is equal to the slice of `params` along `axis` at the index.
...     params[:, :, indices[a, b], :]
... ).numpy()
True
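The check above looks at a single location. A check over every location is easier to write with NumPy's `np.take`, which gathers along one axis using the same output-shape rule as tf.gather when batch_dims=0. It is used here only as an analogue, not as TensorFlow code, and the random shapes mirror the example above.

```python
import numpy as np

# NumPy analogue of the example above; np.take stands in for tf.gather (batch_dims=0).
rng = np.random.default_rng(0)
params = rng.normal(size=(5, 6, 7, 8))
indices = rng.integers(0, 7, size=(10, 11))  # upper bound 7 is exclusive, like maxval
result = np.take(params, indices, axis=2)
print(result.shape)  # (5, 6, 10, 11, 8)

# For every location (a, b) in indices, the output slice equals
# the params slice taken along axis 2 at position indices[a, b].
ok = all(
    np.array_equal(result[:, :, a, b, :], params[:, :, indices[a, b], :])
    for a in range(indices.shape[0])
    for b in range(indices.shape[1])
)
print(ok)  # True
```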

Batching:

The batch_dims argument lets you gather different items from each element of a batch.
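Per-batch gathering on 2-D inputs can be sketched in plain Python as follows. Batch element b gathers from its own row params[b] using its own index list indices[b]. The helper `batched_gather` is hypothetical, and the index values are chosen only for illustration. This models batch_dims=1 for 2-D params and 2-D indices and is not TensorFlow's implementation.

```python
# Hypothetical plain-Python model of batch_dims=1 for 2-D params and 2-D indices.
def batched_gather(params, indices):
    # Pair each batch row with its own index list, then gather within that row.
    return [[row[i] for i in idx] for row, idx in zip(params, indices)]


params = [[0, 0, 1, 0, 2],
          [3, 0, 0, 0, 4],
          [0, 5, 0, 6, 0]]
indices = [[2, 4],   # illustrative indices for batch element 0
           [0, 4],   # ... for batch element 1
           [1, 3]]   # ... for batch element 2
print(batched_gather(params, indices))  # [[1, 2], [3, 4], [5, 6]]
```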

Using batch_dims=1 is equivalent to having an outer loop over the first axis of params and indices:

>>> params = tf.constant([
...     [0, 0, 1, 0, 2],
...     [3, 0, 0, 0, 4],
...     [0, 5, 0, 6, 0]])