tf.contrib.seq2seq.GreedyEmbeddingHelper
A helper for use during inference.
Inherits From: Helper
tf.contrib.seq2seq.GreedyEmbeddingHelper(
embedding, start_tokens, end_token
)
Uses the argmax of the output (treated as logits) and passes the
result through an embedding layer to get the next input.
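The loop below is a minimal pure-Python sketch of the idea for a single sequence. At each step it takes the argmax of the decoder output and embeds it to form the next input. Plain lists stand in for tensors. The names `greedy_decode`, `toy_decoder` and `argmax` are hypothetical and are not TensorFlow APIs. The real helper works on batched tensors inside a decoder loop.

```python
# Pure-Python sketch of greedy embedding decoding for a single sequence.
# greedy_decode, toy_decoder and argmax are hypothetical names, not TensorFlow APIs.


def argmax(logits):
    """Index of the largest logit (the first one wins on ties)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def greedy_decode(decoder_step, embedding, start_token, end_token, max_steps):
    """Feed argmax ids back through an embedding table until end_token or max_steps."""
    inputs = embedding[start_token]    # first input: the embedded start token
    ids = []
    for _ in range(max_steps):
        logits = decoder_step(inputs)  # decoder output, treated as logits
        token = argmax(logits)         # greedy choice: the most likely id
        ids.append(token)
        if token == end_token:         # stop once the end token is emitted
            break
        inputs = embedding[token]      # embed the choice to form the next input
    return ids


# Toy setup: one-hot embeddings and a "decoder" that always prefers the next id.
embedding = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]


def toy_decoder(x):
    logits = [0.0, 0.0, 0.0]
    logits[(argmax(x) + 1) % 3] = 1.0
    return logits


print(greedy_decode(toy_decoder, embedding, start_token=0, end_token=2, max_steps=10))  # [1, 2]
```

The emitted ids include the end token itself. Decoding stops on that step because it is the step where the end token is sampled.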
Args
embedding | Either a callable that takes a vector tensor of ids (the argmax ids), or the params argument for embedding_lookup. The returned tensor will be passed to the decoder input.
start_tokens | int32 vector shaped [batch_size], the start tokens.
end_token | int32 scalar, the token that marks the end of decoding.
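The `embedding` argument accepts two forms. The helper below is a hypothetical pure-Python sketch of how both forms can be normalized into one function from ids to embedding rows. A callable is used unchanged. Anything else is treated as a lookup table indexed by id, which plays the role of the params of embedding_lookup.

```python
# Hypothetical helper showing the two accepted forms of `embedding`.


def make_embedding_fn(embedding):
    """Return a function that maps a list of ids to a list of embedding rows.

    A callable is used as-is; anything else is treated like the params
    argument of embedding_lookup, i.e. a table whose rows are indexed by id.
    """
    if callable(embedding):
        return embedding
    return lambda ids: [embedding[i] for i in ids]


table = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]  # vocab_size=3, embedding_dim=2
lookup = make_embedding_fn(table)
print(lookup([2, 0]))  # [[2.0, 2.0], [0.0, 0.0]]

custom = make_embedding_fn(lambda ids: [[float(i)] for i in ids])
print(custom([3]))     # [[3.0]]
```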
Raises
ValueError | If start_tokens is not a 1-D tensor or end_token is not a scalar.
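The documented ValueError conditions are shape checks. Below is a hypothetical pure-Python sketch in which nested lists stand in for tensors. `rank` and `check_tokens` are illustrative names, and the error messages are made up. The function also returns the number of start tokens, because the batch size is the number of start tokens.

```python
# Hypothetical validation mirroring the documented ValueError conditions.
# Nested Python lists stand in for tensors; rank() is an illustrative helper.


def rank(value):
    """Number of nested list levels: 0 for a scalar, 1 for a flat list, ..."""
    r = 0
    while isinstance(value, (list, tuple)):
        r += 1
        value = value[0] if value else None
    return r


def check_tokens(start_tokens, end_token):
    """Raise ValueError under the documented conditions; return the batch size."""
    if rank(start_tokens) != 1:
        raise ValueError("start_tokens must be a 1-D vector")
    if rank(end_token) != 0:
        raise ValueError("end_token must be a scalar")
    return len(start_tokens)  # one start token per sequence in the batch


print(check_tokens([1, 1, 1], 2))  # 3
try:
    check_tokens([[1], [1]], 2)
except ValueError as err:
    print(err)  # start_tokens must be a 1-D vector
```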
Attributes
batch_size | Batch size of the tensor returned by sample. Returns a scalar int32 tensor.
sample_ids_dtype | DType of the tensor returned by sample. Returns a DType.
sample_ids_shape | Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape.
Methods
initialize
initialize(
name=None
)
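Returns (initial_finished, initial_inputs). Here initial_finished is False for every batch element, and initial_inputs are the embedded start_tokens.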
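A minimal pure-Python analogue of initialize follows, assuming an `embedding_fn` that maps a list of ids to a list of rows. Both the function and its parameters are hypothetical stand-ins for the helper's internals. No sequence starts out finished, and the first decoder inputs are the embedded start tokens.

```python
# Hypothetical pure-Python analogue of initialize().


def initialize(start_tokens, embedding_fn):
    """Return (initial_finished, initial_inputs) for a batch of start tokens."""
    initial_finished = [False] * len(start_tokens)  # no sequence has ended yet
    initial_inputs = embedding_fn(start_tokens)     # embedded start tokens
    return initial_finished, initial_inputs


table = [[0.0], [1.0], [2.0]]


def embed(ids):
    return [table[i] for i in ids]


finished, inputs = initialize([1, 1], embed)
print(finished, inputs)  # [False, False] [[1.0], [1.0]]
```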
next_inputs
next_inputs(
time, outputs, state, sample_ids, name=None
)
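next_inputs_fn for GreedyEmbeddingHelper. Returns (finished, next_inputs, next_state). An element is finished when its sample id equals end_token. next_inputs are the embeddings of sample_ids, or the start inputs once every element has finished. The state is passed through unchanged.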
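The step logic can be sketched in pure Python as follows. `next_inputs`, `embed` and the list-based "tensors" are hypothetical stand-ins, not the TensorFlow implementation. Once every element has sampled end_token, the next inputs are never consumed, so any placeholder (here the start inputs) will do.

```python
# Hypothetical pure-Python analogue of next_inputs() for one decoding step.


def next_inputs(sample_ids, state, end_token, embedding_fn, start_inputs):
    """Return (finished, next_inputs, next_state) for one decoding step."""
    finished = [i == end_token for i in sample_ids]  # done once end_token is sampled
    if all(finished):
        # Every sequence has ended, so the next inputs are never used.
        next_in = start_inputs
    else:
        next_in = embedding_fn(sample_ids)  # embed the greedy choices
    return finished, next_in, state         # state passes through unchanged


table = [[0.0], [1.0], [2.0]]


def embed(ids):
    return [table[i] for i in ids]


print(next_inputs([2, 1], "s", end_token=2, embedding_fn=embed, start_inputs=[[0.0], [0.0]]))
# ([True, False], [[2.0], [1.0]], 's')
```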
sample
sample(
time, outputs, state, name=None
)
sample for GreedyEmbeddingHelper. Treats outputs as logits and returns their argmax along the last axis as int32 sample ids.
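Greedy sampling is an argmax over the vocabulary axis for each batch element. Below is a hypothetical pure-Python sketch over a `[batch_size, vocab_size]` nested list of logits. It is not the TensorFlow implementation, which operates on tensors.

```python
# Hypothetical pure-Python analogue of sample(): one greedy id per batch element.


def sample(outputs):
    """Greedy sample ids: argmax over the last axis of [batch_size, vocab_size] logits."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


logits = [[0.1, 0.7, 0.2],
          [3.0, -1.0, 0.5]]
print(sample(logits))  # [1, 0]
```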
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.