An inference sampler that randomly samples from the output distribution.
Inherits From: GreedyEmbeddingSampler, Sampler
tfa.seq2seq.SampleEmbeddingSampler(
    embedding_fn: Optional[Callable] = None,
    softmax_temperature: Optional[TensorLike] = None,
    seed: Optional[TensorLike] = None
)
Instead of taking the argmax of the outputs, as GreedyEmbeddingSampler does, this sampler draws an id from the output distribution and passes it through the embedding to get the next input.
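The effect of softmax_temperature can be illustrated with a small pure-Python sketch. It assumes, without relying on the tfa source, that the logits are divided by the temperature before the softmax and that the temperature must be strictly positive. softmax_with_temperature is a hypothetical helper, not part of the tfa API.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Hypothetical helper: softmax(logits / temperature) as a list of probabilities."""
    # Assumption: a non-positive temperature is rejected.
    if temperature <= 0:
        raise ValueError("temperature must be strictly greater than 0")
    scaled = [x / temperature for x in logits]
    # Subtract the max before exponentiating for numerical stability.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

With these logits, a temperature below 1.0 concentrates probability on the largest logit (approaching argmax), while a temperature above 1.0 spreads it more evenly across ids.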
| Raises | |
|---|---|
| ValueError | If start_tokens is not a 1D tensor or end_token is not a scalar. |
Methods
initialize
initialize(
    embedding, start_tokens=None, end_token=None
)
Initialize the sampler. This method is inherited from GreedyEmbeddingSampler.
| Args | |
|---|---|
| embedding | Tensor containing the embedding matrix. It is used to generate outputs from start_tokens and end_token. The embedding is ignored if embedding_fn was provided at init(). |
| start_tokens | int32 vector shaped [batch_size], the start tokens. |
| end_token | int32 scalar, the token that marks the end of decoding. |
| Returns |
|---|
| Tuple of two items: (finished, self.start_inputs). |
| Raises | |
|---|---|
| ValueError | If start_tokens is not a 1D tensor or end_token is not a scalar. |
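A rough pure-Python analogue of initialize may help clarify the arguments and return value of this method. This is an illustrative sketch, not the library code: plain lists stand in for tensors, the embedding matrix is a list of rows, and the function and its optional embedding_fn argument are hypothetical. It assumes that every sequence starts unfinished and that start_inputs are the embeddings of start_tokens.

```python
def initialize(embedding, start_tokens, end_token, embedding_fn=None):
    """Hypothetical analogue of initialize(); returns (finished, start_inputs)."""
    # start_tokens must be a flat list of ints (1D); end_token a single int (scalar).
    if not isinstance(start_tokens, list) or not all(
        isinstance(t, int) for t in start_tokens
    ):
        raise ValueError("start_tokens must be a 1D list of ints")
    if not isinstance(end_token, int):
        raise ValueError("end_token must be a scalar int")
    # If embedding_fn is given, the embedding matrix is ignored.
    lookup = (
        embedding_fn
        if embedding_fn is not None
        else (lambda ids: [embedding[i] for i in ids])
    )
    # No sequence has finished before decoding starts.
    finished = [False] * len(start_tokens)
    start_inputs = lookup(start_tokens)
    return finished, start_inputs


matrix = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
print(initialize(matrix, [1, 2], end_token=0))
```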
next_inputs
next_inputs(
    time, outputs, state, sample_ids
)
Computes the next inputs (the next_inputs_fn for GreedyEmbeddingHelper).
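The documented signature suggests how next_inputs fits into a decoding step. The sketch below is a hypothetical pure-Python analogue, assuming that a sequence counts as finished once its sample id equals end_token, that sampled ids are embedded to form the next inputs, and that the method returns (finished, next_inputs, state). The extra end_token, start_inputs and embed parameters are assumptions standing in for values the real sampler keeps from initialize.

```python
def next_inputs(time, outputs, state, sample_ids, end_token, start_inputs, embed):
    """Hypothetical analogue of next_inputs(); returns (finished, next_inputs, state)."""
    del time, outputs  # unused in this sketch
    # A sequence is finished once it emits end_token.
    finished = [sid == end_token for sid in sample_ids]
    if all(finished):
        # Nothing is left to decode, so any valid inputs will do; reuse start_inputs.
        next_in = start_inputs
    else:
        # Feed the embedding of each sampled id back in as the next input.
        next_in = embed(sample_ids)
    return finished, next_in, state


embedding = [[0.0], [1.0], [2.0]]


def embed(ids):
    # Plain row lookup into the toy embedding matrix.
    return [embedding[i] for i in ids]


print(next_inputs(0, None, "state", [2, 1], 2, [[9.0], [9.0]], embed))
```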
sample
sample(
    time, outputs, state
)
Samples ids from the outputs (the sample function for SampleEmbeddingHelper).
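Likewise, sample can be pictured as drawing one id per batch row from the temperature-scaled softmax of outputs, seeded for reproducibility. This is a hypothetical pure-Python sketch using random.Random, not how tfa draws samples; the softmax_temperature and seed keyword arguments mirror the constructor parameters but are passed directly here as an assumption.

```python
import math
import random


def sample(time, outputs, state, softmax_temperature=1.0, seed=None):
    """Hypothetical analogue of sample(): one sampled id per row of logits."""
    del time, state  # unused in this sketch
    rng = random.Random(seed)  # the same seed gives the same draws
    sample_ids = []
    for logits in outputs:
        scaled = [x / softmax_temperature for x in logits]
        m = max(scaled)
        # Unnormalized softmax weights; random.choices normalizes them itself.
        weights = [math.exp(x - m) for x in scaled]
        sample_ids.append(rng.choices(range(len(logits)), weights=weights)[0])
    return sample_ids


logits_batch = [[0.0, 5.0, 0.0], [1.0, 1.0, 1.0]]
print(sample(0, logits_batch, None, seed=42))
```

As the temperature shrinks toward 0, the draws approach the argmax behavior of GreedyEmbeddingSampler.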