Searches for the sequence of subtoken ids with the largest probability.
tfm.nlp.ops.sequence_beam_search(
symbols_to_logits_fn,
initial_ids,
initial_cache,
vocab_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode=False,
dtype='float32',
noise_multiplier: float = 0.0,
decoding_name=None
)
Args |
symbols_to_logits_fn
|
A function that takes ids, index, and cache as
arguments. The arguments have the following shapes: ids -> A tensor with
shape [batch_size * beam_size, index]. index -> A scalar. cache -> A
nested dictionary of tensors [batch_size * beam_size, ...]. The function
must return a tuple of logits and new cache: logits -> A tensor with shape
[batch_size * beam_size, vocab_size]. new cache -> A nested dictionary with the
same shape/structure as the input cache.
|
initial_ids
|
An int32 tensor with shape [batch_size]. Starting ids for each
batch item.
|
initial_cache
|
A dictionary containing the starting values of the decoder
variables.
|
vocab_size
|
An integer, the size of the vocabulary (the number of distinct tokens).
|
beam_size
|
An integer, the number of beams.
|
alpha
|
A float, defining the strength of length normalization.
|
max_decode_length
|
An integer, the maximum length of a decoded sequence.
|
eos_id
|
An integer, the ID of the end-of-sequence (EOS) token, used to determine
when a sequence has finished.
|
padded_decode
|
A bool, indicating whether max_sequence_length padding is used for
beam search.
|
dtype
|
A TensorFlow data type used for score computation. The default is
tf.float32.
|
noise_multiplier
|
A float, the amount of noise to apply. The default is 0.0.
|
decoding_name
|
An optional name for the decoding loop tensors.
|
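The contract described for symbols_to_logits_fn can be illustrated with a framework-free stand-in. This is only a sketch of the shapes involved: it uses nested Python lists where a real callback would receive and return tensors, and the names `toy_symbols_to_logits_fn`, `VOCAB_SIZE`, and the `last_id` cache key are hypothetical, not part of the library.

```python
VOCAB_SIZE = 5  # hypothetical vocabulary size for this sketch


def toy_symbols_to_logits_fn(ids, index, cache):
    """Hypothetical stand-in for one decoder step, using nested lists.

    ids:   one row per (batch item, beam) pair, i.e. batch_size * beam_size rows.
    index: the current decoding step (unused by this toy model).
    cache: dict whose values also hold one entry per row.
    """
    logits = []
    for row in ids:
        # Toy rule: favor the token that follows the most recent one.
        favored = (row[-1] + 1) % VOCAB_SIZE
        logits.append([1.0 if tok == favored else 0.0 for tok in range(VOCAB_SIZE)])
    # Return a cache with the same keys and per-row layout as the input cache.
    new_cache = {key: [list(entry) for entry in value] for key, value in cache.items()}
    new_cache["last_id"] = [[row[-1]] for row in ids]
    return logits, new_cache


batch_size, beam_size = 2, 3
rows = batch_size * beam_size
ids = [[0, 1] for _ in range(rows)]
cache = {"last_id": [[0] for _ in range(rows)]}
logits, new_cache = toy_symbols_to_logits_fn(ids, 1, cache)
print(len(logits), len(logits[0]))  # prints: 6 5
```

The point of the sketch is that the logits have one row per (batch item, beam) pair and one column per vocabulary entry, and that the returned cache keeps the structure of the cache passed in.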
Returns |
Top decoded sequences with shape [batch_size, beam_size, max_decode_length] and
sequence scores with shape [batch_size, beam_size].
|
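Given outputs with the shapes listed above, a caller will often want a single sequence per batch item. The following is a minimal sketch of that post-processing, written with nested lists instead of tensors. It assumes that a higher score means a more probable sequence and that positions after the first eos_id are padding; neither assumption is confirmed by this page, and `best_sequences` is a hypothetical helper, not a library function.

```python
def best_sequences(decoded_ids, scores, eos_id):
    """Pick the highest-scoring beam per batch item and cut it after EOS.

    decoded_ids: nested list shaped [batch_size, beam_size, max_decode_length].
    scores:      nested list shaped [batch_size, beam_size].
    """
    results = []
    for beams, beam_scores in zip(decoded_ids, scores):
        # Index of the beam with the largest score (assumed to be the best one).
        best = max(range(len(beam_scores)), key=lambda b: beam_scores[b])
        seq = list(beams[best])
        if eos_id in seq:
            # Keep everything up to and including the first EOS token.
            seq = seq[: seq.index(eos_id) + 1]
        results.append(seq)
    return results


# Made-up values, for illustration only: batch_size=2, beam_size=2, length=5.
decoded_ids = [
    [[0, 4, 7, 1, 0], [0, 4, 1, 0, 0]],
    [[0, 9, 1, 0, 0], [0, 9, 9, 9, 1]],
]
scores = [[-1.2, -0.8], [-0.5, -2.0]]
top = best_sequences(decoded_ids, scores, eos_id=1)
print(top)  # prints: [[0, 4, 1], [0, 9, 1]]
```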