Computes the argmax where the allowed elements are given by a mask.
tf_agents.policies.utils.masked_argmax(
    input_tensor: tf_agents.typing.types.Tensor,
    mask: tf_agents.typing.types.Tensor,
    output_type: tf.DType = tf.int32
) -> tf_agents.typing.types.Tensor
If a row of mask contains all zeros, this method returns -1 for the corresponding row of input_tensor.
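As a rough illustration of these semantics, here is a NumPy sketch that mimics the documented behavior outside TensorFlow. It is an assumption-based reconstruction of the contract, not the TF-Agents implementation; `masked_argmax_reference` and its `output_dtype` parameter are hypothetical names chosen for this example. Disallowed entries are replaced by -inf before the row-wise argmax, and rows with no allowed entry are then overwritten with -1.

```python
import numpy as np


def masked_argmax_reference(input_array, mask, output_dtype=np.int32):
    """Hypothetical NumPy sketch of masked argmax (not TF-Agents code).

    input_array: rank-2 float array of shape [batch, num_elements].
    mask: 0-1 valued array with the same shape as input_array.
    Returns a rank-1 array of dtype output_dtype; rows whose mask is all
    zeros map to -1.
    """
    input_array = np.asarray(input_array, dtype=np.float64)
    mask = np.asarray(mask)
    if input_array.ndim != 2 or input_array.shape != mask.shape:
        raise ValueError("input_array must be rank 2 and match mask's shape")
    allowed = mask.astype(bool)
    # Disallowed entries become -inf, so any allowed finite entry beats them.
    masked = np.where(allowed, input_array, -np.inf)
    argmax = np.argmax(masked, axis=-1)
    # A row with no allowed entry has no valid argmax; report -1 instead.
    has_allowed = allowed.any(axis=-1)
    return np.where(has_allowed, argmax, -1).astype(output_dtype)


logits = [[1.0, 5.0, 3.0],
          [2.0, 0.5, 4.0],
          [7.0, 8.0, 9.0]]
mask = [[1, 0, 1],
        [1, 1, 0],
        [0, 0, 0]]
print(masked_argmax_reference(logits, mask).tolist())  # prints [2, 0, -1]
```

In the first row the largest value (5.0) is masked out, so the result is index 2; the last row is fully masked and yields -1.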
Args:
  input_tensor: Rank-2 Tensor of floats.
  mask: 0-1 valued Tensor with the same shape as input_tensor.
  output_type: Integer type of the output.
Returns:
  A rank-1 Tensor of type output_type containing the masked argmax of every row of input_tensor.