Returns batched one-hot vectors.
tfa.seq2seq.hardmax(
    logits: tfa.types.TensorLike,
    name: Optional[str] = None
) -> tf.Tensor
The depth index containing the `1` is the index of the maximum logit value along the last axis of `logits`.
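The selection rule can be illustrated with a minimal sketch, assuming NumPy is available. `hardmax_sketch` is a hypothetical helper written for illustration, not part of TensorFlow Addons. It takes the argmax over the last axis and expands each winning index into a one-hot row of the same dtype as the input. In TensorFlow terms this corresponds roughly to `tf.one_hot(tf.argmax(logits, axis=-1), depth, dtype=logits.dtype)`. Note that NumPy's `argmax` resolves ties to the first maximal index, while tie-breaking in the TensorFlow version may differ.

```python
import numpy as np


def hardmax_sketch(logits):
    """Hypothetical NumPy illustration of hardmax (not the tfa API).

    Marks the largest logit along the last axis with 1 and every other
    position with 0; works for any batch rank.
    """
    logits = np.asarray(logits)
    depth = logits.shape[-1]
    # Index of the largest logit along the last axis; NumPy returns the
    # first occurrence when several entries tie for the maximum.
    winners = np.argmax(logits, axis=-1)
    # Indexing an identity matrix turns each index into a one-hot row,
    # appending a trailing axis of size `depth`; the dtype follows the input.
    return np.eye(depth, dtype=logits.dtype)[winners]


batch = np.array([[0.1, 2.0, -1.0],
                  [3.0, 0.5, 0.5]], dtype=np.float32)
print(hardmax_sketch(batch))
# [[0. 1. 0.]
#  [1. 0. 0.]]
```

Because the one-hot axis replaces the reduced last axis, the output has the same shape as the input, so a `[batch, time, depth]` tensor yields a `[batch, time, depth]` result.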
Args | |
---|---|
`logits` | A batch tensor of logit values. |
`name` | Name to use when creating ops. |
Returns | |
---|---|
A batched one-hot tensor with the same shape and dtype as `logits`. |