tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper

Class ScheduledEmbeddingTrainingHelper

A training helper that adds scheduled sampling.

Inherits From: TrainingHelper

Returns -1 in sample_ids for batch entries where no sampling took place, and valid sampled ids elsewhere.
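Because -1 is a sentinel and not a real id, code that consumes sample_ids has to filter it out. For example, you might want to log how often the decoder was fed its own predictions. The following is a plain-Python sketch in which lists stand in for tensors. sampled_mask and sampled_fraction are hypothetical helpers and are not part of the TensorFlow API:

```python
from typing import List, Sequence


def sampled_mask(sample_ids: Sequence[int]) -> List[bool]:
    """True where an id was sampled, False where the entry holds the -1 sentinel."""
    return [sid > -1 for sid in sample_ids]


def sampled_fraction(step_sample_ids: Sequence[Sequence[int]]) -> float:
    """Share of (time, batch) positions that were fed a sampled id."""
    total = sum(len(ids) for ids in step_sample_ids)
    if total == 0:
        return 0.0
    hits = sum(sum(sampled_mask(ids)) for ids in step_sample_ids)
    return hits / total


# Two decoding steps for a batch of three examples.
steps = [[-1, 4, -1], [7, -1, 2]]
print(sampled_mask(steps[0]))   # [False, True, False]
print(sampled_fraction(steps))  # 0.5
```

This count includes positions past each example's sequence_length. Mask those positions out first if padding matters for your statistic.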

__init__

__init__(
    inputs,
    sequence_length,
    embedding,
    sampling_probability,
    time_major=False,
    seed=None,
    scheduling_seed=None,
    name=None
)

Initializer.

Args:

  • inputs: A (structure of) input tensors.
  • sequence_length: An int32 vector tensor.
  • embedding: A callable that takes a vector tensor of sampled ids, or the params argument for embedding_lookup.
  • sampling_probability: A 0D float32 tensor: the probability, for each batch entry, of feeding in an id sampled categorically from the output logits instead of reading the next input directly from inputs.
  • time_major: Python bool. Whether the tensors in inputs are time major. If False (default), they are assumed to be batch major.
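  • seed: The seed for sampling ids categorically from the outputs.
  • scheduling_seed: The seed for the per-entry decision of whether to sample at all (the draw governed by sampling_probability).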
  • name: Name scope for any created operations.
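The helper takes sampling_probability as an input and leaves it to the caller to decide how the value changes over training. One schedule commonly used for scheduled sampling is an inverse-sigmoid decay of the teacher-forcing probability, k / (k + exp(step / k)) with k >= 1. Its complement gives a sampling probability that starts near 0 and rises toward 1. The sketch below assumes that schedule and uses a hypothetical function name. It is one option among several, and this class does not prescribe it:

```python
import math


def scheduled_sampling_probability(step: int, k: float = 1000.0) -> float:
    """Hypothetical schedule: 1 - k / (k + exp(step / k)).

    The complement of the inverse-sigmoid teacher-forcing probability equals
    sigmoid(step / k - ln k). It is evaluated in that form so that large step
    values do not overflow math.exp.
    """
    if k < 1.0:
        raise ValueError("k must be >= 1")
    z = step / k - math.log(k)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


for step in (0, 5000, 10000, 20000):
    print(step, round(scheduled_sampling_probability(step), 4))
```

You could feed the returned float to a scalar placeholder passed as sampling_probability. A larger k keeps teacher forcing dominant for longer.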

Raises:

  • ValueError: if sampling_probability is not a scalar or vector.
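The embedding argument accepts two forms: a callable, or a table to look rows up from. The plain-Python sketch below shows one way to normalize both forms to a single lookup function. It also mirrors the rank check described under Raises. Lists stand in for tensors, and both function names are hypothetical:

```python
from typing import Callable, List, Sequence, Union

Vector = List[float]
EmbeddingFn = Callable[[Sequence[int]], List[Vector]]


def resolve_embedding_fn(embedding: Union[EmbeddingFn, Sequence[Vector]]) -> EmbeddingFn:
    # A callable is used as-is.
    if callable(embedding):
        return embedding
    # Otherwise treat the argument as an embedding table and gather one row
    # per id, as embedding_lookup does for a single, unpartitioned params tensor.
    table = embedding
    return lambda ids: [list(table[i]) for i in ids]


def check_sampling_probability_rank(shape: Sequence[int]) -> None:
    # Only rank 0 (scalar) or rank 1 (vector) passes; anything else raises.
    if len(shape) not in (0, 1):
        raise ValueError(
            f"sampling_probability must be a scalar or vector; saw shape {tuple(shape)}"
        )


lookup = resolve_embedding_fn([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]])
print(lookup([2, 0]))  # [[2.0, 2.1], [0.0, 0.1]]
```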

Properties

batch_size

inputs

sample_ids_dtype

sample_ids_shape

sequence_length

Methods

initialize

initialize(name=None)
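In the TF 1.x source, initialize defers to TrainingHelper.initialize. That method marks zero-length sequences as finished and returns the first ground-truth input step, or zeros shaped like one step if every sequence is empty. The following is a plain-Python sketch of that behavior. It assumes batch-major inputs padded to a common length, and the function name is hypothetical:

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def initialize_sketch(
    inputs: Sequence[Sequence[Vector]], sequence_length: Sequence[int]
) -> Tuple[List[bool], List[Vector]]:
    # inputs[b][t] is the input vector for example b at time step t.
    finished = [length == 0 for length in sequence_length]
    if all(finished):
        # Nothing to decode: emit zeros shaped like the first time step.
        return finished, [[0.0] * len(seq[0]) for seq in inputs]
    # Otherwise every example starts from its ground-truth input at time 0.
    return finished, [list(seq[0]) for seq in inputs]


print(initialize_sketch([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [0.0, 0.0]]], [2, 1]))
# ([False, False], [[1.0, 2.0], [5.0, 6.0]])
```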

next_inputs

next_inputs(
    time,
    outputs,
    state,
    sample_ids,
    name=None
)
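In the TF 1.x source, next_inputs works in two stages. It first calls TrainingHelper.next_inputs to get the finished flags and the ground-truth inputs for time + 1. Then, unless every sequence has finished, it replaces each entry whose sample id is greater than -1 with the embedding of that id, and it returns the state unchanged. The plain-Python sketch below shows only the merge. It assumes batch-major inputs padded to a common length and uses a hypothetical function name. The real method performs the same selection on tensors with gather_nd and scatter_nd:

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def next_inputs_sketch(
    time: int,
    sample_ids: Sequence[int],
    inputs: Sequence[Sequence[Vector]],
    sequence_length: Sequence[int],
    embedding_fn: Callable[[Sequence[int]], List[Vector]],
) -> Tuple[List[bool], List[Vector]]:
    # inputs[b][t] is example b's ground-truth input at step t (padded batch).
    next_time = time + 1
    finished = [next_time >= length for length in sequence_length]
    if all(finished):
        # Every sequence is done: feed zeros shaped like one input step.
        return finished, [[0.0] * len(seq[0]) for seq in inputs]
    # Teacher-forced inputs for the next step.
    base = [list(seq[next_time]) for seq in inputs]
    # Embed all sampled ids in one batched call, then put them back in order.
    sampled = iter(embedding_fn([sid for sid in sample_ids if sid > -1]))
    merged = [next(sampled) if sid > -1 else row for sid, row in zip(sample_ids, base)]
    return finished, merged


emb = lambda ids: [[float(i), float(i)] for i in ids]
batch = [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]]
print(next_inputs_sketch(0, [-1, 9], batch, [3, 2], emb))
# ([False, False], [[2.0, 2.0], [9.0, 9.0]])
```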

sample

sample(
    time,
    outputs,
    state,
    name=None
)
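For each batch entry, sample first makes a Bernoulli decision with probability sampling_probability, seeded by scheduling_seed. Where that decision is true, it samples an id categorically from the outputs treated as logits, seeded by seed. Elsewhere it returns -1. The following is a plain-Python sketch using the random module, and the function name is hypothetical. The TF code draws an id for every entry and then selects with where. This version draws only for the selected entries, so the distribution is the same but seeded results will not match TensorFlow's:

```python
import math
import random
from typing import List, Optional, Sequence


def sample_sketch(
    logits: Sequence[Sequence[float]],
    sampling_probability: float,
    seed: Optional[int] = None,
    scheduling_seed: Optional[int] = None,
) -> List[int]:
    schedule_rng = random.Random(scheduling_seed)  # decides *whether* to sample
    id_rng = random.Random(seed)                   # decides *which* id to sample
    sample_ids = []
    for row in logits:
        if schedule_rng.random() < sampling_probability:
            # Softmax weights, shifted by the max logit for numerical stability.
            m = max(row)
            weights = [math.exp(x - m) for x in row]
            sample_ids.append(id_rng.choices(range(len(row)), weights=weights, k=1)[0])
        else:
            sample_ids.append(-1)  # no sampling: the next input is read from inputs
    return sample_ids


print(sample_sketch([[0.0, 5.0, 1.0], [2.0, 0.0, 0.0]], 0.5, seed=1, scheduling_seed=2))
```

With sampling_probability at 0 every entry is -1 (pure teacher forcing). At 1 every entry is a sampled id.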