tf.contrib.seq2seq.SampleEmbeddingHelper

A helper for use during inference.

Inherits From: GreedyEmbeddingHelper

Instead of taking the argmax, samples an id from the categorical distribution defined by the logits, then passes the sampled ids through an embedding layer to get the next input.
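The difference from GreedyEmbeddingHelper can be sketched in plain Python for a single batch entry, with no TensorFlow involved. The function names `greedy_step` and `sample_step`, the toy `embedding_table`, and the logits are hypothetical, and a Python list lookup stands in for the embedding callable.

```python
import math
import random

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def greedy_step(logits, embedding_table):
    # GreedyEmbeddingHelper-style: always take the argmax id.
    sample_id = max(range(len(logits)), key=lambda i: logits[i])
    return sample_id, embedding_table[sample_id]

def sample_step(logits, embedding_table, rng):
    # SampleEmbeddingHelper-style: draw one id from softmax(logits).
    probs = softmax(logits)
    sample_id = rng.choices(range(len(logits)), weights=probs, k=1)[0]
    return sample_id, embedding_table[sample_id]

# Hypothetical 4-token vocabulary with 2-dimensional embeddings.
embedding_table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
logits = [0.1, 2.0, 1.5, -1.0]
rng = random.Random(0)

print(greedy_step(logits, embedding_table))       # → (1, [1.0, 0.0])
print(sample_step(logits, embedding_table, rng))  # id depends on the draw
```

Greedy decoding returns the same id every time, while sampling usually picks id 1 but sometimes picks the others, in proportion to their softmax probabilities.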

Args

embedding: A callable that takes a vector tensor of ids (the sampled ids), or the params argument for embedding_lookup. The returned tensor is passed to the decoder as its input.
start_tokens: int32 vector shaped [batch_size], the start tokens.
end_token: int32 scalar, the token that marks the end of decoding.
softmax_temperature: (Optional) float32 scalar, the value to divide the logits by before computing the softmax. Larger values (above 1.0) produce more random samples, while smaller values push the sampling distribution towards the argmax. Must be strictly greater than 0. Defaults to 1.0.
seed: (Optional) The sampling seed.
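The effect of softmax_temperature can be checked numerically with a pure-Python sketch of the formula described above (softmax of logits divided by the temperature). This is not the TensorFlow implementation, and `tempered_softmax` is a hypothetical name.

```python
import math

def tempered_softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, following the softmax_temperature description."""
    # The argument must be strictly greater than 0.
    if temperature <= 0:
        raise ValueError("softmax_temperature must be strictly greater than 0")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    # Higher temperature flattens the distribution; lower sharpens it toward the argmax.
    print(t, [round(p, 3) for p in tempered_softmax(logits, t)])
```

With these logits the probability of the top id shrinks as the temperature grows, which is why values above 1.0 give more varied samples.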

Raises

ValueError: if start_tokens is not a 1D tensor or end_token is not a scalar.

Attributes

batch_size: Batch size of the tensor returned by sample.

Returns a scalar int32 tensor.

sample_ids_dtype: DType of the tensor returned by sample.

Returns a DType.

sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension.

Returns a TensorShape.

Methods

initialize

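Returns (initial_finished, initial_inputs): initial_finished is a boolean vector of shape [batch_size] that is all False, and initial_inputs is the embedding of start_tokens.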
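A plain-Python analogue of initialize, assuming a list-based embedding table in place of the embedding callable; the function and variable names are hypothetical.

```python
def initialize(start_tokens, embedding_table):
    """Hypothetical plain-Python analogue of initialize()."""
    batch_size = len(start_tokens)
    # No sequence has finished before decoding starts.
    initial_finished = [False] * batch_size
    # The first decoder inputs are the embeddings of the start tokens.
    initial_inputs = [embedding_table[t] for t in start_tokens]
    return initial_finished, initial_inputs

embedding_table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # toy 3-token vocabulary
finished, inputs = initialize([2, 2], embedding_table)
print(finished, inputs)  # → [False, False] [[0.0, 1.0], [0.0, 1.0]]
```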

next_inputs

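Inherited from GreedyEmbeddingHelper. Returns (finished, next_inputs, state): finished is True for each batch entry whose sample_ids equal end_token; next_inputs is the embedding of sample_ids, or the start-token inputs once every entry has finished; state is passed through unchanged.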
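The same logic can be sketched in plain Python. This is a hypothetical analogue, not the TensorFlow code: lists stand in for tensors and a list lookup stands in for the embedding callable.

```python
def next_inputs(sample_ids, state, end_token, start_inputs, embedding_table):
    """Hypothetical plain-Python analogue of the inherited next_inputs."""
    # An entry is finished once it emits end_token.
    finished = [sid == end_token for sid in sample_ids]
    if all(finished):
        # Every entry is done, so the values no longer matter; reuse the start inputs.
        next_in = start_inputs
    else:
        # Otherwise embed the sampled ids to feed the next decoder step.
        next_in = [embedding_table[sid] for sid in sample_ids]
    # The decoder state is passed through unchanged.
    return finished, next_in, state

embedding_table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
start_inputs = [[0.0, 1.0], [0.0, 1.0]]
END = 0  # hypothetical end_token id
print(next_inputs([1, 0], "s", END, start_inputs, embedding_table))
# → ([False, True], [[1.0, 0.0], [0.0, 0.0]], 's')
```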

sample

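Draws sample_ids from a categorical distribution whose logits are outputs, divided by softmax_temperature if one was given, using seed. Returns an int32 tensor of shape [batch_size].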
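A batched sketch of what sample does, using Python's random module in place of TensorFlow's categorical sampling. The function name and the list-of-lists logits are assumptions. Fixing seed makes the draws reproducible here, which is the intent of the seed argument, though TensorFlow's own seeding rules differ in detail.

```python
import math
import random

def sample(logits_batch, softmax_temperature=None, seed=None):
    """Hypothetical analogue of sample(): one categorical draw per batch row."""
    rng = random.Random(seed)
    sample_ids = []
    for logits in logits_batch:
        if softmax_temperature is not None:
            logits = [x / softmax_temperature for x in logits]
        m = max(logits)
        # Unnormalized softmax weights; random.choices normalizes them itself.
        weights = [math.exp(x - m) for x in logits]
        sample_ids.append(rng.choices(range(len(logits)), weights=weights, k=1)[0])
    return sample_ids

batch = [[0.1, 2.0, 1.5], [3.0, -1.0, 0.0]]
print(sample(batch, softmax_temperature=0.7, seed=42))
```

With a very small temperature the draws collapse onto the argmax of each row, matching the description of softmax_temperature above.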