Computes the argmax where the allowed elements are given by a mask.
```python
tf_agents.policies.utils.masked_argmax(
    input_tensor: tf_agents.typing.types.Tensor,
    mask: tf_agents.typing.types.Tensor,
    output_type: tf.DType = tf.int32
) -> tf_agents.typing.types.Tensor
```
If a row of `mask` contains all zeros, this method returns -1 for the corresponding row of `input_tensor`.
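To make the row-wise semantics concrete, here is a minimal pure-Python sketch of the documented behavior. It is not the TF-Agents implementation and uses no TensorFlow. `masked_argmax_reference` is a hypothetical name, and resolving ties to the lowest index is an assumption of this sketch.

```python
from typing import List, Sequence


def masked_argmax_reference(
    input_rows: Sequence[Sequence[float]],
    mask: Sequence[Sequence[int]],
) -> List[int]:
    """Row-wise argmax restricted to positions where mask == 1.

    Hypothetical reference helper: returns -1 for rows whose mask is all
    zeros. Ties resolve to the lowest index (an assumption of this sketch).
    """
    result = []
    for values, row_mask in zip(input_rows, mask):
        best_index = -1  # stays -1 if no position in the row is allowed
        best_value = float("-inf")
        for i, (value, allowed) in enumerate(zip(values, row_mask)):
            # Skip disallowed positions; strict '>' keeps the first maximum.
            if allowed and (best_index == -1 or value > best_value):
                best_index = i
                best_value = value
        result.append(best_index)
    return result


logits = [[0.1, 0.9, 0.5],
          [2.0, 1.0, 3.0],
          [0.3, 0.2, 0.1]]
mask = [[1, 0, 1],
        [1, 1, 0],
        [0, 0, 0]]
print(masked_argmax_reference(logits, mask))  # [2, 0, -1]
```

In the first row the unmasked maximum (0.9 at index 1) is ignored because its mask entry is 0, and the last row has no allowed positions, so its result is -1.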
| Args | |
|---|---|
| `input_tensor` | Rank-2 `Tensor` of floats. |
| `mask` | 0-1 valued `Tensor` of the same shape as `input_tensor`. |
| `output_type` | Integer type of the output. |
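The argument contract above can be expressed as explicit checks. The following is a minimal sketch over plain nested lists, assuming rows stand in for the rank-2 tensors. `check_masked_argmax_args` is a hypothetical helper, not part of TF-Agents, and the library may validate its inputs differently or not at all.

```python
from typing import Sequence


def check_masked_argmax_args(
    input_rows: Sequence[Sequence[float]],
    mask: Sequence[Sequence[int]],
) -> None:
    """Hypothetical validator: raise ValueError if the documented contract is broken."""
    if len(input_rows) != len(mask):
        raise ValueError("input and mask must have the same number of rows")
    # Rank 2 with a fixed width: every row must have the same length.
    if len({len(row) for row in input_rows}) > 1:
        raise ValueError("input must be rectangular (rank 2)")
    for values, row_mask in zip(input_rows, mask):
        if len(values) != len(row_mask):
            raise ValueError("mask must have the same shape as input")
        # The mask is 0-1 valued: any other entry is rejected.
        if any(m not in (0, 1) for m in row_mask):
            raise ValueError("mask must be 0-1 valued")


check_masked_argmax_args([[0.1, 0.2], [0.3, 0.4]], [[1, 0], [0, 0]])  # passes silently
```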
| Returns | |
|---|---|
| A rank-1 `Tensor` of type `output_type` containing the masked argmax of every row of `input_tensor`. | |