tfr.utils.gather_per_row
Gathers values from the input tensor based on per-row indices.
tfr.utils.gather_per_row(
inputs, indices
)
Example Usage:
scores = [[1., 3., 2.], [1., 2., 3.]]
indices = [[1, 2], [2, 1]]
tfr.utils.gather_per_row(scores, indices)
Returns [[3., 2.], [3., 2.]]
Args
inputs: (tf.Tensor) A tensor of shape [batch_size, list_size] or [batch_size, list_size, feature_dims].
indices: (tf.Tensor) A tensor of shape [batch_size, size] of positions to gather inputs from. Each index in row i selects a position along the list_size dimension of row i of inputs.

Returns
A tensor of values gathered from inputs, of shape [batch_size, size] or [batch_size, size, feature_dims], depending on whether the input was 2D or 3D.
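The shape behavior described above can be illustrated with a small pure-Python sketch. The helper `gather_per_row_reference` is hypothetical and is not part of `tfr`. It only mirrors the documented semantics on nested lists and is not the library's implementation. The 3D values are made up for illustration.

```python
def gather_per_row_reference(inputs, indices):
    """Hypothetical reference helper (not part of tfr): per-row gather on nested lists."""
    if len(inputs) != len(indices):
        raise ValueError("inputs and indices must share the same batch_size")
    # For each row, pick the entries at that row's own indices.
    # An entry is a scalar for 2D inputs or a feature vector for 3D inputs.
    return [
        [row[i] for i in row_indices]
        for row, row_indices in zip(inputs, indices)
    ]


indices = [[1, 2], [2, 1]]

# 2D case: [batch_size=2, list_size=3] -> [batch_size=2, size=2]
scores = [[1., 3., 2.], [1., 2., 3.]]
gathered_2d = gather_per_row_reference(scores, indices)
print(gathered_2d)  # [[3.0, 2.0], [3.0, 2.0]]

# 3D case: [batch_size=2, list_size=3, feature_dims=2] -> [2, 2, 2]
features = [
    [[1., 10.], [3., 30.], [2., 20.]],
    [[1., 10.], [2., 20.], [3., 30.]],
]
gathered_3d = gather_per_row_reference(features, indices)
print(gathered_3d)  # [[[3.0, 30.0], [2.0, 20.0]], [[3.0, 30.0], [2.0, 20.0]]]
```

In the 3D case each index selects a whole feature vector, so the trailing feature_dims dimension is carried through unchanged.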
Last updated 2023-08-18 UTC.