tfa.seq2seq.gather_tree

Calculates the full beams from the per-step ids and parent beam ids.

For a given beam, all values after the time step containing the first decoded end_token are filled in with end_token.
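The backtracking and end_token filling described above can be sketched in pure Python. This is a minimal reference sketch, not the TensorFlow Addons kernel: it assumes nested lists indexed [time][batch][beam] in place of Tensors, `gather_tree_reference` is a hypothetical name, it raises `ValueError` where the op raises InvalidArgumentError, and it assumes that positions past a batch entry's max_sequence_lengths are also end_token.

```python
from typing import List

Tensor3D = List[List[List[int]]]  # nested lists indexed [time][batch][beam]


def gather_tree_reference(step_ids: Tensor3D, parent_ids: Tensor3D,
                          max_sequence_lengths: List[int],
                          end_token: int) -> Tensor3D:
    """Hypothetical pure-Python sketch of gather_tree semantics."""
    max_time = len(step_ids)
    batch_size = len(step_ids[0]) if max_time else 0
    beam_width = len(step_ids[0][0]) if batch_size else 0
    # Assumption: every position starts out as end_token.
    beams = [[[end_token] * beam_width for _ in range(batch_size)]
             for _ in range(max_time)]
    for b in range(batch_size):
        seq_len = min(max_time, max_sequence_lengths[b])
        if seq_len <= 0:
            continue
        for k in range(beam_width):
            # The last valid step of beam k keeps its own token.
            last = seq_len - 1
            beams[last][b][k] = step_ids[last][b][k]
            parent = parent_ids[last][b][k]
            # Walk backwards in time, following the parent pointers.
            for t in range(last - 1, -1, -1):
                if not 0 <= parent < beam_width:
                    raise ValueError(f"invalid parent index {parent}")
                beams[t][b][k] = step_ids[t][b][parent]
                parent = parent_ids[t][b][parent]
            # After the first end_token, overwrite the rest with end_token.
            finished = False
            for t in range(seq_len):
                if finished:
                    beams[t][b][k] = end_token
                elif beams[t][b][k] == end_token:
                    finished = True
    return beams


# One batch entry, beam_width = 2, max_time = 3, end_token = 4.
step_ids = [[[1, 2]], [[3, 4]], [[5, 6]]]
parent_ids = [[[0, 0]], [[1, 0]], [[1, 0]]]
print(gather_tree_reference(step_ids, parent_ids, [3], end_token=4))
# → [[[1, 2]], [[4, 3]], [[4, 6]]]
```

Reading a beam's column top to bottom gives its full sequence: beam 0 backtracks through its parents to [1, 4, 5], and because 4 is the end token, its final step is overwritten with 4.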

Args:
  step_ids: The predicted token IDs. An int32 Tensor of shape [max_time, batch_size, beam_width].
  parent_ids: The parent beam indices. An int32 Tensor of shape [max_time, batch_size, beam_width].
  max_sequence_lengths: The maximum sequence length of each batch entry. An int32 Tensor of shape [batch_size].
  end_token: The end token ID.

Returns:
  The token IDs reordered by following parent_ids, with the same shape as step_ids ([max_time, batch_size, beam_width]).

Raises:
  InvalidArgumentError: If parent_ids contains an invalid index.
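To locate bad entries before calling the op, a caller could scan parent_ids for indices outside [0, beam_width). The sketch below is a hypothetical helper, not part of TensorFlow Addons; it assumes nested lists indexed [time][batch][beam] and checks every entry, which may be stricter than the op itself.

```python
from typing import List, Tuple

Tensor3D = List[List[List[int]]]  # nested lists indexed [time][batch][beam]


def find_invalid_parent_ids(parent_ids: Tensor3D) -> List[Tuple[int, int, int]]:
    """Return (time, batch, beam) positions whose parent index is out of range."""
    bad = []
    for t, per_batch in enumerate(parent_ids):
        for b, per_beam in enumerate(per_batch):
            beam_width = len(per_beam)
            for k, parent in enumerate(per_beam):
                # A valid parent must name one of the beam_width beams.
                if not 0 <= parent < beam_width:
                    bad.append((t, b, k))
    return bad


print(find_invalid_parent_ids([[[0, 1]], [[1, 2]], [[-1, 0]]]))
# → [(1, 0, 1), (2, 0, 0)]
```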