Warning: This project is deprecated. TensorFlow Addons has stopped development, and the project will only provide minimal maintenance releases until May 2024. See the full announcement on GitHub.


Interface for implementing sampling in seq2seq decoders.

Sampler classes implement the logic of sampling from the decoder output distribution and producing the inputs for the next decoding step. In most cases, they should not be used directly but passed to a tfa.seq2seq.BasicDecoder instance that will manage the sampling.
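To make the contract concrete, the sketch below restates it as a plain Python abstract base class. `SamplerInterface` is a hypothetical name used only for illustration and is not part of TensorFlow Addons. The method and property names follow this page; a real implementation would subclass `tfa.seq2seq.Sampler` instead.

```python
import abc


class SamplerInterface(abc.ABC):
    """Hypothetical stand-in that lists the members documented for a sampler."""

    @abc.abstractmethod
    def initialize(self, inputs, **kwargs):
        """Called exactly once; returns (initial_finished, initial_inputs)."""

    @abc.abstractmethod
    def sample(self, time, outputs, state):
        """Returns sample_ids for the current decoding step."""

    @abc.abstractmethod
    def next_inputs(self, time, outputs, state, sample_ids):
        """Returns (finished, next_inputs, next_state)."""

    @property
    @abc.abstractmethod
    def batch_size(self):
        """Batch size of the tensor returned by sample()."""

    @property
    @abc.abstractmethod
    def sample_ids_shape(self):
        """Shape of sample_ids, excluding the batch dimension."""

    @property
    @abc.abstractmethod
    def sample_ids_dtype(self):
        """Dtype of sample_ids."""
```

Because every member is abstract, the class cannot be instantiated until a subclass implements all six.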

Here is an example using a training sampler directly to implement a custom decoding loop:

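```python
import tensorflow as tf
import tensorflow_addons as tfa

batch_size = 4
max_time = 7
hidden_size = 16

sampler = tfa.seq2seq.TrainingSampler()
cell = tf.keras.layers.LSTMCell(hidden_size)

# Batch-major ground-truth inputs: [batch_size, max_time, hidden_size].
input_tensors = tf.random.uniform([batch_size, max_time, hidden_size])
initial_finished, initial_inputs = sampler.initialize(input_tensors)

cell_input = initial_inputs
cell_state = cell.get_initial_state(initial_inputs)

for time_step in tf.range(max_time):
    # Run one decoder step.
    cell_output, cell_state = cell(cell_input, cell_state)
    # Let the sampler pick ids from the cell output.
    sample_ids = sampler.sample(time_step, cell_output, cell_state)
    # Let the sampler produce the next inputs and mark finished sequences.
    finished, cell_input, cell_state = sampler.next_inputs(
        time_step, cell_output, cell_state, sample_ids)
    # Stop once every sequence in the batch has finished.
    if tf.reduce_all(finished):
        break
```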

batch_size: Batch size of the tensor returned by sample().

Returns a scalar int32 tensor. The value might not be available before initialize() is called; in that case, a ValueError is raised.

sample_ids_dtype: DType of the tensor returned by sample().

Returns a DType. The value might not be available before initialize() is called.

sample_ids_shape: Shape of the tensor returned by sample(), excluding the batch dimension.

Returns a TensorShape. The value might not be available before initialize() is called.
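The following NumPy sketch illustrates the lazy availability described above: `batch_size` raises `ValueError` until `initialize()` has seen the inputs, while the id shape and dtype describe one scalar `int32` id per batch entry, as an argmax-style sampler would produce. `PropertiesSketch` is hypothetical, and NumPy values stand in for TensorFlow tensors, `TensorShape`, and `DType` objects.

```python
import numpy as np


class PropertiesSketch:
    """Hypothetical sampler fragment showing how the three properties behave."""

    def __init__(self):
        self._batch_size = None  # unknown until initialize() sees the inputs

    def initialize(self, inputs):
        inputs = np.asarray(inputs)
        self._batch_size = inputs.shape[0]
        # (initial_finished, initial_inputs) for batch-major inputs.
        return np.zeros(self._batch_size, dtype=bool), inputs[:, 0]

    @property
    def batch_size(self):
        if self._batch_size is None:
            raise ValueError("batch_size is not available before initialize() is called")
        return np.int32(self._batch_size)

    @property
    def sample_ids_shape(self):
        # One scalar id per batch entry, so the shape excluding batch is ().
        return ()

    @property
    def sample_ids_dtype(self):
        return np.int32
```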



initialize(inputs, **kwargs)

Initializes the sampler with the input tensors.

This method must be invoked exactly once before calling other methods of the Sampler.

inputs: A (possibly nested) structure of input tensors, such as a nested tuple or a single tensor.
**kwargs: Other initialization arguments. These may include tensors, such as a mask for the inputs, or non-tensor parameters.

Returns: A tuple (initial_finished, initial_inputs).
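A minimal NumPy sketch of what `initialize()` might do for a teacher-forcing sampler, assuming batch-major inputs of shape `[batch, max_time, depth]` and an optional `sequence_length` keyword argument (both are assumptions for illustration, and `initialize_sketch` is a hypothetical helper). It marks zero-length sequences as already finished and returns the first ground-truth step as the initial input.

```python
import numpy as np


def initialize_sketch(inputs, sequence_length=None):
    """Hypothetical teacher-forcing initialize(): returns (initial_finished, initial_inputs).

    inputs: array of shape [batch, max_time, depth] (batch-major).
    sequence_length: optional int array of shape [batch]; defaults to max_time.
    """
    inputs = np.asarray(inputs)
    batch, max_time = inputs.shape[:2]
    if sequence_length is None:
        sequence_length = np.full(batch, max_time)
    sequence_length = np.asarray(sequence_length)
    # A sequence of length 0 is finished before decoding starts.
    initial_finished = sequence_length == 0
    # Teacher forcing: the first decoder input is ground-truth step 0.
    initial_inputs = inputs[:, 0]
    return initial_finished, initial_inputs
```

Keeping `initial_finished` per batch entry lets a decoder track sequences of different lengths independently.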


next_inputs(time, outputs, state, sample_ids)

Returns (finished, next_inputs, next_state).
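Under the same teacher-forcing assumption, the hypothetical `next_inputs_sketch` below shows one way to compute the returned triple: a sequence is finished once the next time step reaches its length, the next input is the following ground-truth step (or zeros once every sequence is done), and the state passes through unchanged. Unlike a real sampler method, it takes `inputs` and `sequence_length` as explicit arguments instead of reading them from state stored by `initialize()`.

```python
import numpy as np


def next_inputs_sketch(time, outputs, state, sample_ids, inputs, sequence_length):
    """Hypothetical teacher-forcing next_inputs(): returns (finished, next_inputs, next_state).

    outputs and sample_ids are accepted for signature parity but ignored,
    because teacher forcing feeds ground truth rather than model predictions.
    """
    inputs = np.asarray(inputs)
    next_time = time + 1
    finished = next_time >= np.asarray(sequence_length)
    if finished.all() or next_time >= inputs.shape[1]:
        # Nothing left to read: feed zeros shaped like a single input step.
        next_inputs = np.zeros_like(inputs[:, 0])
    else:
        next_inputs = inputs[:, next_time]
    # The cell state is returned unchanged.
    return finished, next_inputs, state
```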


sample(time, outputs, state)

Returns sample_ids.
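What `sample()` computes depends on the sampler. The NumPy sketch below contrasts two common strategies applied to one step of logits: greedy selection (argmax) and drawing an id from the softmax distribution. Both functions are hypothetical helpers, and `outputs` is assumed to have shape `[batch, num_classes]`.

```python
import numpy as np


def greedy_sample(outputs):
    """Pick the highest-scoring id per batch entry (argmax over the last axis)."""
    return np.argmax(outputs, axis=-1).astype(np.int32)


def categorical_sample(outputs, rng):
    """Draw one id per batch entry from softmax(outputs)."""
    outputs = np.asarray(outputs, dtype=np.float64)
    # Subtract the row maximum before exponentiating for numerical stability.
    logits = outputs - outputs.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(p.shape[-1], p=p) for p in probs], dtype=np.int32)
```

Greedy selection is deterministic, which suits evaluation; sampling trades that determinism for more varied outputs.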