Warning: This project is deprecated. TensorFlow Addons has stopped development, and the
project will only provide minimal maintenance releases until May 2024. See the full
announcement here or on GitHub.
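tfa.seq2seq.Sampler
View source on GitHub: https://github.com/tensorflow/addons/blob/v0.20.0/tensorflow_addons/seq2seq/sampler.py#L29-L116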
Interface for implementing sampling in seq2seq decoders.
Sampler classes implement the logic of sampling from the decoder output distribution
and producing the inputs for the next decoding step. In most cases, they should not be
used directly but passed to a tfa.seq2seq.BasicDecoder instance, which manages the sampling.
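To make the contract concrete, here is a minimal pure-Python sketch of the same interface written with the standard abc module. It mirrors the members documented on this page (initialize, sample, next_inputs, and the three attributes), but it is not the TensorFlow Addons source: the class name SamplerSketch is hypothetical and no TensorFlow is involved.

```python
import abc


class SamplerSketch(abc.ABC):
    """Hypothetical pure-Python mirror of the members documented on this page."""

    @property
    @abc.abstractmethod
    def batch_size(self):
        """Batch size of the value returned by sample()."""

    @property
    @abc.abstractmethod
    def sample_ids_shape(self):
        """Shape of the value returned by sample(), excluding the batch dimension."""

    @property
    @abc.abstractmethod
    def sample_ids_dtype(self):
        """Type of the values returned by sample()."""

    @abc.abstractmethod
    def initialize(self, inputs, **kwargs):
        """Called exactly once; returns (initial_finished, initial_inputs)."""

    @abc.abstractmethod
    def sample(self, time, outputs, state):
        """Returns sample_ids for the current step."""

    @abc.abstractmethod
    def next_inputs(self, time, outputs, state, sample_ids):
        """Returns (finished, next_inputs, next_state)."""
```

Because every member is abstract, Python refuses to instantiate SamplerSketch itself, or any subclass that leaves one of the members unimplemented, and raises TypeError instead.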
Here is an example using a training sampler directly to implement a custom decoding
loop:
import tensorflow as tf
import tensorflow_addons as tfa

batch_size = 4
max_time = 7
hidden_size = 16

# TrainingSampler reads the next decoder input from input_tensors at each step.
sampler = tfa.seq2seq.TrainingSampler()
cell = tf.keras.layers.LSTMCell(hidden_size)

# Batch-major inputs: [batch_size, max_time, hidden_size].
input_tensors = tf.random.uniform([batch_size, max_time, hidden_size])
initial_finished, initial_inputs = sampler.initialize(input_tensors)

cell_input = initial_inputs
cell_state = cell.get_initial_state(initial_inputs)

for time_step in tf.range(max_time):
    cell_output, cell_state = cell(cell_input, cell_state)
    sample_ids = sampler.sample(time_step, cell_output, cell_state)
    finished, cell_input, cell_state = sampler.next_inputs(
        time_step, cell_output, cell_state, sample_ids)
    # Stop once every sequence in the batch has finished.
    if tf.reduce_all(finished):
        break
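The loop above relies on teacher forcing: the training sampler supplies the recorded input for step time + 1 rather than anything derived from the cell output. The following pure-Python sketch models that behavior on batch-major nested lists. The class name TeacherForcingSketch, the sequence_length keyword, the argmax in sample, and the zero vectors fed once every row is finished are all assumptions for illustration; this is a simplified model, not the Addons implementation.

```python
class TeacherForcingSketch:
    """Hypothetical pure-Python model of a training (teacher-forcing) sampler."""

    def initialize(self, inputs, sequence_length=None):
        # inputs: batch-major nested list [batch][time][features].
        self.inputs = inputs
        max_time = len(inputs[0])
        if sequence_length is None:
            # Assume every row uses all time steps.
            sequence_length = [max_time] * len(inputs)
        self.sequence_length = sequence_length
        finished = [length == 0 for length in sequence_length]
        return finished, [row[0] for row in inputs]

    def sample(self, time, outputs, state):
        # Greedy ids from [batch][vocab] outputs; teacher forcing does not
        # use them when choosing the next input.
        return [max(range(len(row)), key=row.__getitem__) for row in outputs]

    def next_inputs(self, time, outputs, state, sample_ids):
        next_time = time + 1
        finished = [next_time >= length for length in self.sequence_length]
        if all(finished):
            # Nothing left to read: feed zero vectors of the feature width.
            width = len(self.inputs[0][0])
            return finished, [[0.0] * width for _ in self.inputs], state
        # Recorded input for the next step, regardless of sample_ids.
        # Assumes every sequence_length <= max_time; finished rows are not masked here.
        return finished, [row[next_time] for row in self.inputs], state
```

With lengths [3, 2], the second row reports finished one step before the first, and once both are finished the sketch feeds zeros.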
Attributes
batch_size: Batch size of the tensor returned by sample. Returns a scalar int32 tensor. The value might not be available before initialize() is invoked; in that case, a ValueError is raised.
sample_ids_dtype: DType of the tensor returned by sample. Returns a DType. The value might not be available before initialize() is invoked.
sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape. The value might not be available before initialize() is invoked.
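The attribute descriptions imply that a sampler may only learn some of its properties, such as the batch size, from the inputs passed to initialize(). A minimal pure-Python sketch of that pattern follows; LazySampler is a hypothetical class, and representing the scalar sample id shape as () and the dtype as int is an assumption standing in for TensorShape and DType.

```python
class LazySampler:
    """Hypothetical sampler whose batch_size is only known after initialize()."""

    def __init__(self):
        self._batch_size = None  # unknown until initialize() sees the inputs

    @property
    def batch_size(self):
        if self._batch_size is None:
            # Mirrors the documented behavior: ValueError before initialize().
            raise ValueError("batch_size is not available before initialize()")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        return ()  # one scalar id per batch entry

    @property
    def sample_ids_dtype(self):
        return int

    def initialize(self, inputs):
        # inputs: batch-major nested list [batch][time].
        self._batch_size = len(inputs)
        return [False] * self._batch_size, [row[0] for row in inputs]
```

The shape and dtype here do not depend on the data, so they are available immediately; only batch_size has to wait for initialize().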
Methods
initialize
View source: https://github.com/tensorflow/addons/blob/v0.20.0/tensorflow_addons/seq2seq/sampler.py#L62-L78
@abc.abstractmethod
initialize(
inputs, **kwargs
)
Initializes the sampler with the input tensors.
This method must be invoked exactly once before calling other
methods of the Sampler.
Args
inputs: A single input tensor or a (possibly nested) structure of input tensors, such as a nested tuple.
**kwargs: Other keyword arguments for initialization. These can include tensors, such as a mask for the inputs, or non-tensor parameters.
Returns
(initial_finished, initial_inputs).
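Since **kwargs may carry a mask for the inputs, one plausible way for an implementation to use it is to turn the mask into per-row lengths and mark empty rows as finished from the start. The pure-Python sketch below assumes right-padded masks (valid steps first); the function name initialize_with_mask and its extra lengths return value are hypothetical and not part of the documented signature.

```python
def initialize_with_mask(inputs, mask=None):
    """Return (initial_finished, initial_inputs, lengths).

    inputs: batch-major nested list [batch][time][features].
    mask:   optional [batch][time] booleans, True for valid steps; assumed to be
            right-padded so that counting True values gives each row's length.
    """
    max_time = len(inputs[0])
    if mask is None:
        lengths = [max_time] * len(inputs)  # no mask: every step is valid
    else:
        lengths = [sum(1 for valid in row if valid) for row in mask]
    # A row with no valid steps is finished before decoding starts.
    initial_finished = [length == 0 for length in lengths]
    initial_inputs = [row[0] for row in inputs]
    return initial_finished, initial_inputs, lengths
```

Returning the lengths alongside the documented pair is only a convenience of this sketch; a class-based sampler would more likely store them for later next_inputs calls.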
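next_inputs
View source: https://github.com/tensorflow/addons/blob/v0.20.0/tensorflow_addons/seq2seq/sampler.py#L85-L88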
@abc.abstractmethod
next_inputs(
time, outputs, state, sample_ids
)
Returns (finished, next_inputs, next_state) for the given time step, cell outputs, cell state, and the sample_ids returned by sample.
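For inference-time decoding, next_inputs typically feeds the model's own prediction back in instead of recorded inputs. The pure-Python sketch below illustrates that idea, in the general spirit of greedy embedding decoding: the embeddings lookup table, the end_token value, and the function name greedy_next_inputs are assumptions for illustration, not the Addons implementation.

```python
def greedy_next_inputs(time, outputs, state, sample_ids, embeddings, end_token):
    """Return (finished, next_inputs, next_state).

    sample_ids: one int id per batch entry, e.g. the result of sample().
    embeddings: hypothetical lookup table, embeddings[id] -> feature list.
    time and outputs are part of the interface but unused in this sketch.
    """
    # A row is finished once it emits the end token.
    finished = [sid == end_token for sid in sample_ids]
    # Feed the embedding of each predicted id as the next decoder input.
    next_inputs = [list(embeddings[sid]) for sid in sample_ids]
    return finished, next_inputs, state  # state passes through unchanged
```

The state is returned unchanged because, in this sketch, the sampler only decides the next input and the finished flags; the cell owns the recurrent state.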
sample
View source: https://github.com/tensorflow/addons/blob/v0.20.0/tensorflow_addons/seq2seq/sampler.py#L80-L83
@abc.abstractmethod
sample(
time, outputs, state
)
Returns sample_ids for the given time step, cell outputs, and cell state.
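sample maps one step of cell outputs to sample_ids. Two common choices are sketched below in pure Python, assuming outputs are unnormalized logits laid out as [batch][vocab]: greedy argmax, and drawing from a temperature-scaled softmax. Both function names are hypothetical, and neither is claimed to match a specific Addons sampler.

```python
import math
import random


def greedy_sample(time, outputs, state):
    """Return the index of the largest logit in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def softmax_sample(time, outputs, state, temperature=1.0, rng=random):
    """Draw one id per row from softmax(logits / temperature); temperature > 0."""
    sample_ids = []
    for row in outputs:
        scaled = [x / temperature for x in row]
        top = max(scaled)  # subtract the max for numerical stability
        weights = [math.exp(x - top) for x in scaled]
        # random.choices normalizes the weights internally.
        sample_ids.append(rng.choices(range(len(row)), weights=weights, k=1)[0])
    return sample_ids
```

Lower temperatures concentrate probability on the largest logits, so softmax_sample approaches greedy_sample as the temperature shrinks; passing a seeded random.Random as rng makes draws reproducible.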
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.
Last updated 2023-05-25 UTC.