


Base abstract class that allows the user to customize sampling.

Inherits From: Helper

Args:

initialize_fn: Callable that returns (finished, next_inputs) for the first iteration.
sample_fn: Callable that takes (time, outputs, state) and emits the tensor sample_ids.
next_inputs_fn: Callable that takes (time, outputs, state, sample_ids) and emits (finished, next_inputs, next_state).
sample_ids_shape: The shape of each value in the sample_ids batch, given either as a list of integers or as a 1-D int32 Tensor. Defaults to a scalar.
sample_ids_dtype: The dtype of the sample_ids tensor. Defaults to int32.
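As a rough illustration of the three callable contracts, the pure-Python sketch below implements greedy sampling over a toy vocabulary. It is an assumption-laden stand-in, not TensorFlow code: plain lists replace tensors, and the names START_ID, END_ID, EMBEDDINGS and BATCH_SIZE are hypothetical. With the real class, these functions would receive and return tensors.

```python
# Hypothetical toy setup: a 4-token vocabulary with 2-dim "embeddings".
# Plain Python lists stand in for tensors throughout.
START_ID = 0
END_ID = 3
EMBEDDINGS = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
BATCH_SIZE = 2


def initialize_fn():
    """Returns (finished, next_inputs) for the first iteration."""
    finished = [False] * BATCH_SIZE
    next_inputs = [EMBEDDINGS[START_ID] for _ in range(BATCH_SIZE)]
    return finished, next_inputs


def sample_fn(time, outputs, state):
    """Greedy sampling: index of the highest score for each batch entry."""
    return [max(range(len(logits)), key=logits.__getitem__) for logits in outputs]


def next_inputs_fn(time, outputs, state, sample_ids):
    """Returns (finished, next_inputs, next_state)."""
    finished = [sid == END_ID for sid in sample_ids]  # stop on the end token
    next_inputs = [EMBEDDINGS[sid] for sid in sample_ids]  # feed samples back
    return finished, next_inputs, state  # state passes through unchanged


ids = sample_fn(0, [[0.1, 0.9, 0.2, 0.0], [0.0, 0.1, 0.2, 0.7]], None)
print(ids)  # [1, 3]
print(next_inputs_fn(0, None, None, ids)[0])  # [False, True]
```

Here the second batch entry sampled END_ID, so next_inputs_fn marks it finished while the first entry keeps going.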

Attributes:

batch_size: Batch size of the tensor returned by sample. Returns a scalar int32 tensor.
sample_ids_dtype: DType of the tensor returned by sample. Returns a DType.
sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape.

Methods:

initialize(name=None)

Returns (initial_finished, initial_inputs).


next_inputs(time, outputs, state, sample_ids, name=None)

Returns (finished, next_inputs, next_state).


sample(time, outputs, state, name=None)

Returns sample_ids.
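A decoder typically drives these methods in a fixed order: initialize once, then, after each cell step, call sample followed by next_inputs until every batch entry is finished. The pure-Python sketch below mimics that order under stated assumptions. ToyHelper, toy_cell, decode and TARGETS are all hypothetical; this is not TensorFlow's own decoding loop, which operates on tensors.

```python
END_ID = 2
VOCAB_SIZE = 3
# Hypothetical per-entry token sequences that the toy cell "predicts".
TARGETS = [[1, 2], [1, 1, 2]]


class ToyHelper:
    """Hypothetical stand-in that forwards to the three user callables."""

    def __init__(self, initialize_fn, sample_fn, next_inputs_fn):
        self._initialize_fn = initialize_fn
        self._sample_fn = sample_fn
        self._next_inputs_fn = next_inputs_fn

    def initialize(self):
        return self._initialize_fn()

    def sample(self, time, outputs, state):
        return self._sample_fn(time, outputs, state)

    def next_inputs(self, time, outputs, state, sample_ids):
        return self._next_inputs_fn(time, outputs, state, sample_ids)


def toy_cell(inputs, state):
    """Hypothetical cell: emits one-hot logits for the next target token.

    `state` is a per-entry step counter; `inputs` is ignored in this toy.
    """
    outputs = []
    for b, step in enumerate(state):
        seq = TARGETS[b]
        token = seq[min(step, len(seq) - 1)]
        outputs.append([1.0 if i == token else 0.0 for i in range(VOCAB_SIZE)])
    return outputs, [step + 1 for step in state]


def decode(helper, cell, initial_state, max_steps=10):
    """Calls initialize once, then cell -> sample -> next_inputs per step."""
    finished, inputs = helper.initialize()
    state = initial_state
    tokens = [[] for _ in finished]
    for time in range(max_steps):
        if all(finished):
            break
        outputs, state = cell(inputs, state)
        sample_ids = helper.sample(time, outputs, state)
        for b, sid in enumerate(sample_ids):
            if not finished[b]:  # ignore entries that already ended
                tokens[b].append(sid)
        step_finished, inputs, state = helper.next_inputs(
            time, outputs, state, sample_ids)
        # An entry stays finished once it has emitted END_ID.
        finished = [f or s for f, s in zip(finished, step_finished)]
    return tokens


helper = ToyHelper(
    initialize_fn=lambda: ([False, False], [0, 0]),  # start token id 0
    sample_fn=lambda t, out, st: [
        max(range(len(logits)), key=logits.__getitem__) for logits in out],
    next_inputs_fn=lambda t, out, st, ids: ([i == END_ID for i in ids], ids, st),
)
print(decode(helper, toy_cell, initial_state=[0, 0]))  # [[1, 2], [1, 1, 2]]
```

The first entry finishes after two steps and the second after three; the loop stops only once both are finished.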