An inference sampler that takes the argmax of the output distribution.
Inherits From: Sampler
tfa.seq2seq.GreedyEmbeddingSampler(
embedding_fn: Optional[Callable] = None
)
Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input.
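The greedy step described above can be sketched in plain Python. This is an illustration only and does not call the library: `argmax_ids`, `embedding_lookup`, and the toy `embedding` table are hypothetical names and values chosen for the example.

```python
from typing import List, Sequence


def argmax_ids(logits: Sequence[Sequence[float]]) -> List[int]:
    """Return the index of the largest logit in each batch row."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def embedding_lookup(embedding: Sequence[Sequence[float]],
                     ids: Sequence[int]) -> List[List[float]]:
    """Pick one embedding row per id (a stand-in for an embedding layer)."""
    return [list(embedding[i]) for i in ids]


# Hypothetical 3-token vocabulary with 2-dimensional embeddings.
embedding = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.9]]

# Decoder outputs for a batch of 2, treated as logits over the vocabulary.
logits = [[0.1, 2.0, -1.0], [0.3, 0.2, 1.5]]

sample_ids = argmax_ids(logits)                        # most probable id per row
next_inputs = embedding_lookup(embedding, sample_ids)  # fed to the next step
```

Because the choice is a plain argmax, the result is deterministic: the same logits always produce the same ids.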
| Args | |
|---|---|
embedding_fn
|
An optional callable that takes a vector tensor of ids
(argmax ids). The returned tensor is passed to the decoder
as its input. Defaults to tf.nn.embedding_lookup.
|
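As a rough illustration of the `embedding_fn` contract (ids in, decoder inputs out), the sketch below builds a callable that looks up an embedding row and scales it by the square root of the embedding size. The helper `make_scaled_embedding_fn`, the scaling choice, and the toy `table` are assumptions made for the example, and plain lists stand in for tensors.

```python
import math
from typing import Callable, List, Sequence


def make_scaled_embedding_fn(
    embedding: Sequence[Sequence[float]],
) -> Callable[[Sequence[int]], List[List[float]]]:
    """Build a callable mapping a vector of ids to scaled embedding rows."""
    scale = math.sqrt(len(embedding[0]))  # hypothetical scaling factor

    def embedding_fn(ids: Sequence[int]) -> List[List[float]]:
        # One output row per id; this is what the decoder would receive.
        return [[value * scale for value in embedding[i]] for i in ids]

    return embedding_fn


# Hypothetical 3-token vocabulary with 4-dimensional embeddings.
table = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 0.5, 0.0, 0.0],
]
embedding_fn = make_scaled_embedding_fn(table)
decoder_inputs = embedding_fn([1, 2])
```

With the real class, a callable of this shape (operating on tensors rather than lists) would be supplied through the `embedding_fn` constructor argument shown in the signature above.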
Methods
initialize
initialize(
embedding, start_tokens=None, end_token=None
)
Initialize the GreedyEmbeddingSampler.
| Args | |
|---|---|
embedding
|
Tensor that contains the embedding states matrix. It is
used to generate outputs with start_tokens and end_token.
The embedding is ignored if embedding_fn was provided
at __init__().
|
start_tokens
|
int32 vector shaped [batch_size], the start tokens.
|
end_token
|
int32 scalar, the token that marks end of decoding.
|
| Returns | |
|---|---|
Tuple of two items: (finished, self.start_inputs).
|
| Raises | |
|---|---|
ValueError
|
if start_tokens is not a 1D tensor or end_token is
not a scalar.
|
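The documented contract of `initialize` (shape checks, then a `(finished, start_inputs)` tuple) can be mimicked with plain lists. This is a sketch, not the library code: `initialize_sketch` is a hypothetical name, and it assumes that `finished` starts as all `False` with one entry per batch row and that the start inputs are the embedded start tokens.

```python
from typing import List, Sequence, Tuple


def initialize_sketch(
    embedding: Sequence[Sequence[float]],
    start_tokens: Sequence[int],
    end_token: int,
) -> Tuple[List[bool], List[List[float]]]:
    """Validate the arguments and build the initial (finished, start_inputs)."""
    # start_tokens must be a flat (1D) sequence of ints.
    if not isinstance(start_tokens, (list, tuple)) or any(
        not isinstance(token, int) for token in start_tokens
    ):
        raise ValueError("start_tokens must be a 1D sequence of ints")
    # end_token must be a single int.
    if not isinstance(end_token, int):
        raise ValueError("end_token must be a scalar int")

    finished = [False] * len(start_tokens)                   # assumed: nothing done yet
    start_inputs = [list(embedding[t]) for t in start_tokens]
    return finished, start_inputs


# Hypothetical 3-token vocabulary; token 1 starts decoding, token 2 ends it.
toy_embedding = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.9]]
finished, start_inputs = initialize_sketch(toy_embedding, [1, 1], end_token=2)
```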
next_inputs
next_inputs(
time, outputs, state, sample_ids
)
Computes the next inputs for GreedyEmbeddingSampler by passing sample_ids through the embedding.
sample
sample(
time, outputs, state
)
Samples ids for GreedyEmbeddingSampler by taking the argmax of outputs (treated as logits).
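To show how `sample` and `next_inputs` might cooperate inside a decoding loop, here is a self-contained toy. It is a sketch under stated assumptions rather than the library's implementation: `toy_decoder_step`, `TRANSITION`, and the one-hot `EMBEDDING` are hypothetical stand-ins for a real decoder cell and embedding, and a row is treated as finished once it emits `END_TOKEN`.

```python
from typing import List, Tuple

END_TOKEN = 3
# Hypothetical one-hot embedding for a 4-token vocabulary.
EMBEDDING = [[1.0 if j == i else 0.0 for j in range(4)] for i in range(4)]
# Toy "decoder": token k always predicts token k + 1; token 3 predicts itself.
TRANSITION = [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1]]


def toy_decoder_step(inputs: List[List[float]]) -> List[List[float]]:
    """Map embedded inputs to logits (a stand-in for a real decoder cell)."""
    return [
        [sum(x * TRANSITION[k][j] for k, x in enumerate(row)) for j in range(4)]
        for row in inputs
    ]


def sample(outputs: List[List[float]]) -> List[int]:
    """Greedy sampling: argmax over the logits of each batch row."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def next_inputs(sample_ids: List[int]) -> Tuple[List[bool], List[List[float]]]:
    """Flag rows that emitted END_TOKEN and embed the ids for the next step."""
    finished = [i == END_TOKEN for i in sample_ids]
    return finished, [list(EMBEDDING[i]) for i in sample_ids]


def greedy_decode(start_tokens: List[int], max_steps: int = 10) -> List[List[int]]:
    inputs = [list(EMBEDDING[t]) for t in start_tokens]
    finished = [False] * len(start_tokens)
    decoded: List[List[int]] = [[] for _ in start_tokens]
    for _ in range(max_steps):
        if all(finished):
            break
        ids = sample(toy_decoder_step(inputs))
        step_finished, inputs = next_inputs(ids)
        for row, (token, done) in enumerate(zip(ids, finished)):
            if not done:  # stop recording a row once it has finished
                decoded[row].append(token)
        finished = [a or b for a, b in zip(finished, step_finished)]
    return decoded


decoded = greedy_decode([0, 1])
```

Here the first row decodes `[1, 2, 3]` and the second `[2, 3]`; the loop keeps stepping until every row has produced the end token or `max_steps` is reached.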