tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper
A training helper that adds scheduled sampling.
Inherits From: TrainingHelper
```python
tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
    inputs, sequence_length, embedding, sampling_probability, time_major=False,
    seed=None, scheduling_seed=None, name=None
)
```
The `sample` method returns -1 in `sample_ids` for batch entries where no sampling took place, and a valid sampled id elsewhere.
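The rule behind those -1 values can be sketched in plain Python. This is a hypothetical illustration, not the TensorFlow implementation. For each batch entry, a Bernoulli draw with probability `sampling_probability` decides whether to sample. If it does, an id is drawn categorically from that entry's outputs; otherwise the entry gets -1.

Several simplifications are assumptions of this sketch:
- The outputs are treated as unnormalized logits.
- The probability is a scalar.
- A `random.Random` instance stands in for the `seed` and `scheduling_seed` arguments.
- `scheduled_sample` is a made-up name.

```python
import math
import random
from typing import List, Optional, Sequence


def scheduled_sample(
    logits: Sequence[Sequence[float]],
    sampling_probability: float,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical sketch: one id per batch entry, or -1 where no sampling took place."""
    if rng is None:
        rng = random.Random()
    sample_ids: List[int] = []
    for row in logits:
        # Scheduling decision: a Bernoulli draw with probability sampling_probability.
        if rng.random() < sampling_probability:
            # Categorical draw over the output ids, treating the row as logits
            # (softmax weights; subtracting the max keeps exp() from overflowing).
            top = max(row)
            weights = [math.exp(x - top) for x in row]
            sample_ids.append(rng.choices(range(len(row)), weights=weights)[0])
        else:
            # Not sampled: the next input will come from the ground-truth inputs.
            sample_ids.append(-1)
    return sample_ids
```

For example, `scheduled_sample([[0.1, 2.0, -1.0]] * 4, 0.0)` returns `[-1, -1, -1, -1]`, because with probability 0 no entry is ever sampled.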
Args

| Argument | Description |
|---|---|
| `inputs` | A (structure of) input tensors. |
| `sequence_length` | An int32 vector tensor. |
| `embedding` | A callable that takes a vector tensor of `ids` (the sampled ids), or the `params` argument for `embedding_lookup`. |
| `sampling_probability` | A 0-D `float32` tensor: the probability of sampling categorically from the output ids instead of reading directly from the inputs. |
| `time_major` | Python bool. Whether the tensors in `inputs` are time major. If `False` (default), they are assumed to be batch major. |
| `seed` | The seed for categorically sampling the output ids. |
| `scheduling_seed` | The seed for sampling the scheduling decision, i.e. whether to sample or to read from the inputs. |
| `name` | Name scope for any created operations. |
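The `embedding` argument accepts either a lookup function or a parameter table. Below is a minimal sketch of how such an argument could be normalized into a single lookup function.

The sketch makes these assumptions:
- Nested Python lists stand in for tensors.
- `as_embedding_fn` is a hypothetical helper, not part of TensorFlow.
- The list-based lookup only mirrors `embedding_lookup` for a single, unpartitioned `params` table.

```python
from typing import Callable, List, Sequence, Union

Vector = List[float]
EmbeddingFn = Callable[[Sequence[int]], List[Vector]]


def as_embedding_fn(embedding: Union[EmbeddingFn, Sequence[Vector]]) -> EmbeddingFn:
    """Hypothetical helper: turn `embedding` into a function from ids to vectors."""
    if callable(embedding):
        # Already a lookup function: use it unchanged.
        return embedding
    params = embedding

    def lookup(ids: Sequence[int]) -> List[Vector]:
        # Row lookup, analogous to embedding_lookup(params, ids) for one table.
        return [list(params[i]) for i in ids]

    return lookup


table = [[0.0, 0.0], [1.0, 0.5], [2.0, -1.0]]
embed = as_embedding_fn(table)
print(embed([2, 1]))  # [[2.0, -1.0], [1.0, 0.5]]
```

Normalizing once, up front, lets the rest of the helper call a single function regardless of which form the caller passed.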
Raises

| Exception | Condition |
|---|---|
| `ValueError` | If `sampling_probability` is neither a scalar nor a vector. |
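A rough sketch of the documented check, with nested Python lists standing in for tensors. Rank 0 is a plain number and rank 1 is a flat list. `rank` and `check_sampling_probability` are hypothetical helpers, not TensorFlow functions.

```python
from typing import Any


def rank(value: Any) -> int:
    """Rank of a nested-list stand-in for a tensor: 0 for a number, 1 for a flat list."""
    r = 0
    while isinstance(value, (list, tuple)):
        r += 1
        # Assumes rectangular nesting; an empty list stops the descent.
        if not value:
            break
        value = value[0]
    return r


def check_sampling_probability(p: Any) -> None:
    """Raise ValueError unless p is a scalar (rank 0) or a vector (rank 1)."""
    r = rank(p)
    if r not in (0, 1):
        raise ValueError(
            "sampling_probability must be a scalar or a vector, got rank %d" % r)
```

For example, `check_sampling_probability(0.25)` and `check_sampling_probability([0.1, 0.9])` pass silently, while `check_sampling_probability([[0.5]])` raises `ValueError`.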
Attributes

| Attribute | Description |
|---|---|
| `batch_size` | Batch size of the tensor returned by `sample`. Returns a scalar int32 tensor. |
| `inputs` | |
| `sample_ids_dtype` | DType of the tensor returned by `sample`. Returns a `DType`. |
| `sample_ids_shape` | Shape of the tensor returned by `sample`, excluding the batch dimension. Returns a `TensorShape`. |
| `sequence_length` | |
Methods
initialize

```python
initialize(
    name=None
)
```

Returns `(initial_finished, initial_inputs)`.
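A minimal sketch of what those two values could look like. It assumes the behavior inherited from `TrainingHelper`, which this page does not spell out:
- An entry counts as finished from the start if its sequence length is 0.
- The initial inputs are the time-0 slice of each entry.
- Zeros are fed instead if every entry is already finished.

Inputs are taken as batch-major nested lists. `initialize_sketch` is a hypothetical name.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def initialize_sketch(
    inputs: Sequence[Sequence[Vector]],  # batch-major: [batch][time][depth]
    sequence_length: Sequence[int],
) -> Tuple[List[bool], List[Vector]]:
    """Hypothetical sketch of (initial_finished, initial_inputs)."""
    # An entry is already finished if it has no time steps at all.
    initial_finished = [length == 0 for length in sequence_length]
    if all(initial_finished):
        # Nothing to read: feed zeros of the input depth
        # (assumes the padded inputs still have at least one time step).
        depth = len(inputs[0][0])
        return initial_finished, [[0.0] * depth for _ in inputs]
    # Otherwise read the time-0 slice of every batch entry.
    return initial_finished, [list(seq[0]) for seq in inputs]
```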
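next_inputs

```python
next_inputs(
    time, outputs, state, sample_ids, name=None
)
```

The `next_inputs_fn` for `TrainingHelper`, with scheduled sampling applied. Where `sample_ids` holds a valid id, the next input is the embedding of that id; where it holds -1, the next input is read from `inputs`. Returns `(finished, next_inputs, next_state)`.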
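The per-step choice can be sketched with plain lists as follows. The base behavior is assumed to follow `TrainingHelper`:
- An entry is finished once `time + 1 >= sequence_length`.
- The ground-truth input comes from step `time + 1`.
- Zeros are fed once every entry is finished.

The names `next_inputs_sketch` and `embedding_fn` are hypothetical. Inputs are batch-major.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def next_inputs_sketch(
    time: int,
    inputs: Sequence[Sequence[Vector]],  # batch-major ground truth: [batch][time][depth]
    sequence_length: Sequence[int],
    sample_ids: Sequence[int],
    embedding_fn: Callable[[Sequence[int]], List[Vector]],
) -> Tuple[List[bool], List[Vector]]:
    """Hypothetical sketch of the scheduled-sampling choice of next inputs."""
    next_time = time + 1
    # An entry is finished once the next step would run past its length.
    finished = [next_time >= length for length in sequence_length]
    if all(finished):
        # Every entry is done: feed zeros, no embedding lookup needed.
        depth = len(inputs[0][0])
        return finished, [[0.0] * depth for _ in sample_ids]
    # Embed only the entries that were sampled (id > -1).
    sampled = [b for b, sid in enumerate(sample_ids) if sid > -1]
    embedded = dict(zip(sampled, embedding_fn([sample_ids[b] for b in sampled])))
    # Sampled entries take their embedding; the rest read the ground truth.
    next_inputs = [
        embedded[b] if b in embedded else list(inputs[b][next_time])
        for b in range(len(sample_ids))
    ]
    return finished, next_inputs
```

Looking up embeddings only for the sampled entries mirrors the idea that unsampled entries never need the model's prediction.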
sample

```python
sample(
    time, outputs, state, name=None
)
```

Returns `sample_ids`.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.