tfm.nlp.ops.sequence_beam_search
Search for sequence of subtoken ids with the largest probability.
tfm.nlp.ops.sequence_beam_search(
    symbols_to_logits_fn,
    initial_ids,
    initial_cache,
    vocab_size,
    beam_size,
    alpha,
    max_decode_length,
    eos_id,
    padded_decode=False,
    dtype='float32',
    noise_multiplier: float = 0.0,
    decoding_name=None
)
Args:
  symbols_to_logits_fn: A function that takes ids, index, and cache as
    arguments, with the following shapes: ids -> a tensor with shape
    [batch_size * beam_size, index]; index -> a scalar; cache -> a nested
    dictionary of tensors with shape [batch_size * beam_size, ...]. The
    function must return a tuple (logits, new_cache): logits -> a tensor with
    shape [batch_size * beam_size, vocab_size]; new_cache -> a nested
    dictionary with the same shape and structure as the input cache.
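The callback contract can be illustrated without TensorFlow. The sketch below stands in nested Python lists for tensors and uses a hypothetical toy "model" (NEXT_TOKEN_PROBS, in which the next-token distribution depends only on the last token) so that the shapes are easy to follow. It is not a tfm API: a real symbols_to_logits_fn receives and returns TensorFlow tensors. Only the last token of each ids row is read, and index is accepted but unused.

```python
import math

# Hypothetical toy "model": row t gives P(next token | last token == t)
# for a vocabulary of 3 tokens.
VOCAB_SIZE = 3
NEXT_TOKEN_PROBS = [
    [0.1, 0.6, 0.3],
    [0.15, 0.25, 0.6],
    [0.4, 0.3, 0.3],
]


def symbols_to_logits_fn(ids, index, cache):
    """Toy callback mirroring the documented contract with Python lists.

    ids:   one row of token ids per (batch item, beam) pair.
    index: the current decoding step (unused by this toy model).
    cache: dict mapping names to one entry per row.
    Returns (logits, new_cache): one row of VOCAB_SIZE logits per input row,
    and a cache with the same keys and per-row layout as the input cache.
    """
    logits = [[math.log(p) for p in NEXT_TOKEN_PROBS[row[-1]]] for row in ids]
    new_cache = {"last_token": [[row[-1]] for row in ids]}
    return logits, new_cache


# batch_size = 2, beam_size = 3 -> 6 rows, each holding a prefix of length 1.
ids = [[0], [0], [0], [1], [1], [1]]
cache = {"last_token": [[0], [0], [0], [1], [1], [1]]}
logits, new_cache = symbols_to_logits_fn(ids, 0, cache)
```

The logits here happen to be normalized log-probabilities. In general any real-valued scores can serve as logits.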
  initial_ids: An int32 tensor with shape [batch_size]. The starting id for
    each batch item.
  initial_cache: A dictionary containing the starting state of the decoder
    variables. It is the initial value of the cache passed to
    symbols_to_logits_fn.
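initial_ids has shape [batch_size], while symbols_to_logits_fn receives rows shaped [batch_size * beam_size, ...]. The search must therefore replicate each batch item's start id and cache entries once per beam. The sketch below shows that bookkeeping with plain lists. It assumes the copies for one batch item stay adjacent (batch-major order) and that the cache's leading dimension is batch_size. expand_to_beams is a hypothetical helper, not part of tfm.

```python
def expand_to_beams(value, beam_size):
    """Repeat every batch entry beam_size times; nested dicts are handled recursively.

    Assumes batch-major order: the beam copies of one batch item stay adjacent.
    """
    if isinstance(value, dict):
        return {k: expand_to_beams(v, beam_size) for k, v in value.items()}
    return [entry for entry in value for _ in range(beam_size)]


initial_ids = [7, 8]                                # [batch_size], batch_size = 2
initial_cache = {"layer_0": {"k": [[0.0], [1.0]]}}  # leading dimension = batch_size
beam_size = 3

# First-step ids: one length-1 prefix per (batch item, beam) row.
first_ids = [[i] for i in expand_to_beams(initial_ids, beam_size)]
flat_cache = expand_to_beams(initial_cache, beam_size)
print(first_ids)  # [[7], [7], [7], [8], [8], [8]]
```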
  vocab_size: An integer, the size of the token vocabulary.
  beam_size: An integer, the number of beams (candidate sequences) kept for
    each batch item.
  alpha: A float defining the strength of the length normalization applied
    to sequence scores.
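As an illustration of what alpha controls, the sketch below assumes the GNMT-style length penalty used by the TensorFlow Transformer beam search. Under that assumption, a hypothesis with log-probability log P and length L is scored as log P / ((5 + L) / 6) ** alpha. Setting alpha = 0 leaves raw log-probabilities unchanged, and larger values favor longer sequences. Treat the exact formula as an assumption and check the source of the version you have installed. length_penalty and normalized_score are hypothetical helpers.

```python
def length_penalty(length, alpha):
    # Assumed GNMT-style normalization factor: ((5 + length) / 6) ** alpha.
    return ((5.0 + length) / 6.0) ** alpha


def normalized_score(log_prob, length, alpha):
    # Dividing a (negative) log-probability by a factor > 1 moves it toward 0,
    # which benefits longer sequences more as alpha grows.
    return log_prob / length_penalty(length, alpha)


short = (-1.0, 2)   # (log-probability, length)
long_ = (-1.5, 10)

# alpha = 0: no normalization, so the more probable short sequence wins.
print(normalized_score(*short, 0.0) > normalized_score(*long_, 0.0))  # True
# alpha = 1: -1.0 / (7/6) ~ -0.857 versus -1.5 / 2.5 = -0.6, so the long one wins.
print(normalized_score(*long_, 1.0) > normalized_score(*short, 1.0))  # True
```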
  max_decode_length: An integer, the maximum length of a decoded sequence.
  eos_id: An integer, the ID of the end-of-sequence (EOS) token, used to
    determine when a sequence has finished.
  padded_decode: A bool indicating whether the beam search pads sequences to
    max_decode_length (fixed-shape decoding). Defaults to False.
  dtype: A TensorFlow data type used for score computation. The default is
    tf.float32 (passed as the string 'float32').
  noise_multiplier: A float, the amount of noise to apply during the search.
    Defaults to 0.0.
  decoding_name: An optional name for the decoding loop tensors.
Returns:
  A tuple of two tensors:
    Top decoded sequences, with shape [batch_size, beam_size, max_decode_length].
    Sequence scores, with shape [batch_size, beam_size].
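To show how the arguments and the two return values fit together, here is a simplified, pure-Python beam search. It is not the library's implementation, which operates on TensorFlow tensors of shape [batch_size * beam_size, ...]. This sketch processes one batch item at a time, starts each item from a single hypothesis, and runs until max_decode_length instead of stopping early. toy_beam_search, pad_id, PROBS, and the toy model are all hypothetical. Padding finished sequences with pad_id (and excluding the start id from the output) are choices of this sketch. The length penalty assumes the GNMT form described under alpha.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def length_penalty(length, alpha):
    # Assumed GNMT-style normalization factor (see the alpha argument).
    return ((5.0 + length) / 6.0) ** alpha


def toy_beam_search(symbols_to_logits_fn, initial_ids, initial_cache,
                    vocab_size, beam_size, alpha, max_decode_length, eos_id,
                    pad_id=0):
    """Simplified beam search returning (sequences, scores) as nested lists.

    sequences: [batch_size][beam_size][max_decode_length], start id excluded,
               padded with the hypothetical pad_id.
    scores:    [batch_size][beam_size], length-normalized log-probabilities.
    """
    assert vocab_size >= beam_size, "this sketch assumes vocab_size >= beam_size"
    all_seqs, all_scores = [], []
    for b, start_id in enumerate(initial_ids):
        # Alive hypotheses: (tokens incl. start id, log-prob, cache row).
        alive = [([start_id], 0.0, {k: v[b] for k, v in initial_cache.items()})]
        finished = []  # (generated tokens, normalized score)
        for step in range(max_decode_length):
            ids = [tokens for tokens, _, _ in alive]
            cache = {k: [c[k] for _, _, c in alive] for k in initial_cache}
            logits, new_cache = symbols_to_logits_fn(ids, step, cache)
            candidates = []
            for row, (tokens, logp, _) in enumerate(alive):
                # Every child of this hypothesis inherits its updated cache row.
                row_cache = {k: new_cache[k][row] for k in new_cache}
                for tok, lp in enumerate(log_softmax(logits[row])):
                    candidates.append((tokens + [tok], logp + lp, row_cache))
            candidates.sort(key=lambda c: c[1], reverse=True)
            alive = []
            for tokens, logp, row_cache in candidates:
                if tokens[-1] == eos_id or step == max_decode_length - 1:
                    # Finished by EOS or by reaching max_decode_length.
                    gen = tokens[1:]
                    finished.append((gen, logp / length_penalty(len(gen), alpha)))
                else:
                    alive.append((tokens, logp, row_cache))
                    if len(alive) == beam_size:
                        break
        finished.sort(key=lambda f: f[1], reverse=True)
        best = finished[:beam_size]
        all_seqs.append(
            [g + [pad_id] * (max_decode_length - len(g)) for g, _ in best])
        all_scores.append([s for _, s in best])
    return all_seqs, all_scores


# Hypothetical toy model: next-token probabilities depend on the last token.
PROBS = {0: [0.1, 0.6, 0.3], 1: [0.15, 0.25, 0.6], 2: [0.4, 0.3, 0.3]}


def toy_symbols_to_logits_fn(ids, index, cache):
    logits = [[math.log(p) for p in PROBS[row[-1]]] for row in ids]
    return logits, {"last_token": [[row[-1]] for row in ids]}


seqs, scores = toy_beam_search(
    toy_symbols_to_logits_fn, initial_ids=[0, 1],
    initial_cache={"last_token": [[0], [1]]}, vocab_size=3, beam_size=2,
    alpha=0.0, max_decode_length=3, eos_id=2)
print(seqs)  # [[[1, 2, 0], [2, 0, 0]], [[2, 0, 0], [1, 2, 0]]]
```

With alpha = 0.0 the scores are plain log-probabilities. For the first batch item, for example, the best sequence [1, 2] scores log(0.6 * 0.6) = log(0.36).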
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-02-02 UTC.