tf.contrib.seq2seq.ScheduledOutputTrainingHelper
A training helper that adds scheduled sampling directly to outputs.
Inherits From: TrainingHelper
tf.contrib.seq2seq.ScheduledOutputTrainingHelper(
inputs, sequence_length, sampling_probability, time_major=False, seed=None,
next_inputs_fn=None, auxiliary_inputs=None, name=None
)
The `sample` method returns boolean `sample_ids`: `False` for batch entries where no sampling took place and `True` elsewhere.
| Args | |
|---|---|
| `inputs` | A (structure of) input tensors. |
| `sequence_length` | An `int32` vector tensor holding the length of each sequence in the batch. |
| `sampling_probability` | A 0-D `float32` tensor: the probability of sampling from the outputs instead of reading directly from the inputs. |
| `time_major` | Python bool. Whether the tensors in `inputs` are time-major (`[max_time, batch_size, ...]`). If `False` (default), they are assumed to be batch-major (`[batch_size, max_time, ...]`). |
| `seed` | The sampling seed. |
| `next_inputs_fn` | (Optional) Callable to apply to the RNN outputs to create the next input when sampling. If `None` (default), the RNN outputs are used as the next inputs. |
| `auxiliary_inputs` | An optional (structure of) auxiliary input tensors with a shape that matches `inputs` in all but (potentially) the final dimension. These tensors are concatenated to the sampled output, or to the `inputs` when not sampling, for use as the next input. |
| `name` | Name scope for any created operations. |
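To make the shape conventions above concrete, here is a minimal NumPy sketch (not the library's code) of how a batch-major tensor relates to the time-major layout, and how `auxiliary_inputs`, whose shape matches `inputs` in all but the final dimension, can be concatenated along that final axis. The helper names `to_time_major` and `concat_auxiliary` are hypothetical.

```python
import numpy as np


def to_time_major(x):
    """Hypothetical helper: [batch_size, max_time, ...] -> [max_time, batch_size, ...]."""
    return np.swapaxes(x, 0, 1)


def concat_auxiliary(x, aux):
    """Hypothetical helper: append auxiliary features along the final axis.

    All dimensions except the last must match, as the argument docs require.
    """
    if x.shape[:-1] != aux.shape[:-1]:
        raise ValueError("shapes must match in all but the final dimension")
    return np.concatenate([x, aux], axis=-1)


batch_size, max_time, depth, aux_depth = 2, 3, 4, 1
inputs = np.zeros((batch_size, max_time, depth), dtype=np.float32)
aux = np.ones((batch_size, max_time, aux_depth), dtype=np.float32)

combined = concat_auxiliary(inputs, aux)   # batch-major, depth + aux_depth features
time_major = to_time_major(combined)       # time-major view of the same data
print(combined.shape, time_major.shape)    # (2, 3, 5) (3, 2, 5)
```

Only the final dimension may differ, which is why the concatenation is along `axis=-1`.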
| Raises | |
|---|---|
| `ValueError` | If `sampling_probability` is not a scalar or vector. |
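The rank check behind this `ValueError` can be sketched as follows, using NumPy ranks in place of TensorFlow static shapes. `check_sampling_probability` is a hypothetical name and its error message is illustrative, not the library's. Note that the check admits a vector as well as a scalar, even though the argument is described above as a 0-D tensor.

```python
import numpy as np


def check_sampling_probability(p):
    """Hypothetical validator: accept rank 0 (scalar) or rank 1 (vector) only."""
    p = np.asarray(p, dtype=np.float32)
    if p.ndim not in (0, 1):
        raise ValueError(
            "sampling_probability must be a scalar or a vector, got shape %s"
            % (p.shape,))
    return p


check_sampling_probability(0.25)          # scalar: accepted
check_sampling_probability([0.1, 0.2])    # vector: accepted
```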
| Attributes | |
|---|---|
| `batch_size` | Batch size of the tensor returned by `sample`. Returns a scalar `int32` tensor. |
| `inputs` | |
| `sample_ids_dtype` | DType of the tensor returned by `sample`. Returns a `DType`. |
| `sample_ids_shape` | Shape of the tensor returned by `sample`, excluding the batch dimension. Returns a `TensorShape`. |
| `sequence_length` | |
Methods
initialize
initialize(
name=None
)
Returns `(initial_finished, initial_inputs)`.
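A minimal NumPy sketch of what the `(initial_finished, initial_inputs)` pair could contain, assuming this class behaves like its parent `TrainingHelper`: a sequence whose length is 0 starts out finished, and the first decoder input is the time-step-0 slice of the (time-major) inputs, or zeros if every sequence is already finished. `initialize_sketch` is a hypothetical function, not the library method.

```python
import numpy as np


def initialize_sketch(time_major_inputs, sequence_length):
    """Hypothetical model of initialize().

    time_major_inputs: array of shape [max_time, batch_size, depth].
    sequence_length: int array of shape [batch_size].
    """
    sequence_length = np.asarray(sequence_length)
    # Assumption (mirrors TrainingHelper): zero-length sequences are finished at once.
    initial_finished = sequence_length == 0
    if initial_finished.all():
        # Nothing left to read: feed zeros shaped like a single time step.
        initial_inputs = np.zeros_like(time_major_inputs[0])
    else:
        # Otherwise feed the ground-truth inputs at time step 0.
        initial_inputs = time_major_inputs[0]
    return initial_finished, initial_inputs
```

For example, with `sequence_length=[2, 0]` the second batch entry is reported as finished before decoding starts.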
next_inputs
next_inputs(
time, outputs, state, sample_ids, name=None
)
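Returns `(finished, next_inputs, next_state)`. For batch entries where `sample_ids` is `True`, the next input is built from the RNN `outputs` (passed through `next_inputs_fn` when one is given, then concatenated with any `auxiliary_inputs`); elsewhere it is read from `inputs` at the next time step.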
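The per-step choice between model outputs and ground-truth inputs can be sketched in NumPy. This is a hypothetical, simplified model of the behaviour described by the `sampling_probability`, `next_inputs_fn`, and `auxiliary_inputs` arguments, not the library source; `scheduled_next_inputs` and its parameters are illustrative names. For simplicity it applies `next_inputs_fn` to every row and then selects, whereas an implementation might transform only the sampled rows; the result is the same when the function acts on each row independently.

```python
import numpy as np


def scheduled_next_inputs(outputs, ground_truth_next, sample_ids,
                          next_inputs_fn=None, auxiliary_next=None):
    """Hypothetical sketch: pick each batch row's next input for one step.

    outputs: [batch_size, out_depth] RNN outputs at the current step.
    ground_truth_next: [batch_size, in_depth] inputs for the next step
        (already including any auxiliary features).
    sample_ids: [batch_size] bools; True means "feed back the model output".
    next_inputs_fn: optional callable mapping outputs to next-input features.
    auxiliary_next: optional [batch_size, aux_depth] auxiliary features for
        the next step, appended to the sampled outputs.
    """
    sample_ids = np.asarray(sample_ids, dtype=bool)
    sampled = outputs if next_inputs_fn is None else next_inputs_fn(outputs)
    if auxiliary_next is not None:
        sampled = np.concatenate([sampled, auxiliary_next], axis=-1)
    # Broadcast the per-row decision across the feature axis.
    return np.where(sample_ids[:, None], sampled, ground_truth_next)
```

The sampled rows and the ground-truth rows must end up with the same feature depth, which is why any auxiliary features are appended to the sampled outputs before selection.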
sample
sample(
time, outputs, state, name=None
)
Returns `sample_ids`.
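A minimal sketch of the sampling decision: each batch entry independently becomes `True` with probability `sampling_probability`, giving a boolean vector like the `sample_ids` described above. `draw_sample_ids` is a hypothetical name, and it uses NumPy's random generator rather than TensorFlow's ops, so the same `seed` will not reproduce the library's draws.

```python
import numpy as np


def draw_sample_ids(sampling_probability, batch_size, seed=None):
    """Hypothetical sketch: one Bernoulli(sampling_probability) draw per entry.

    Returns a bool array of shape [batch_size]: True where the model output
    is fed back, False where the ground-truth input is read instead.
    """
    # Assumption: NumPy's generator stands in for TensorFlow's sampler.
    rng = np.random.default_rng(seed)
    # Uniform draws in [0, 1) fall below p with probability p.
    return rng.random(batch_size) < sampling_probability
```

With `sampling_probability=0.0` every entry reads the ground-truth input; with `1.0` every entry feeds back the model output.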
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.