Module: tfa.seq2seq

Additional layers for sequence to sequence models.

Classes

class AttentionMechanism: Base class for attention mechanisms.

class AttentionWrapper: Wraps another RNN cell with attention.

class AttentionWrapperState: State of a tfa.seq2seq.AttentionWrapper.

class BahdanauAttention: Implements Bahdanau-style (additive) attention.

class BahdanauMonotonicAttention: Monotonic attention mechanism with Bahdanau-style energy function.

class BaseDecoder: An RNN Decoder that is based on a Keras layer.

class BasicDecoder: Basic sampling decoder for training and inference.

class BasicDecoderOutput: Outputs of a tfa.seq2seq.BasicDecoder step.

class BeamSearchDecoder: Beam search decoder.

class BeamSearchDecoderOutput: Outputs of a tfa.seq2seq.BeamSearchDecoder step.

class BeamSearchDecoderState: State of a tfa.seq2seq.BeamSearchDecoder.

class CustomSampler: Base abstract class that allows the user to customize sampling.

class Decoder: Abstract interface for RNN decoders.

class FinalBeamSearchDecoderOutput: Final outputs returned by the beam search after all decoding is finished.

class GreedyEmbeddingSampler: An inference sampler that takes the argmax of the output distribution.

class InferenceSampler: An inference sampler that uses a custom sampling function.

class LuongAttention: Implements Luong-style (multiplicative) attention scoring.

class LuongMonotonicAttention: Monotonic attention mechanism with Luong-style energy function.

class SampleEmbeddingSampler: An inference sampler that randomly samples from the output distribution.

class Sampler: Interface for implementing sampling in seq2seq decoders.

class ScheduledEmbeddingTrainingSampler: A training sampler that adds scheduled sampling.

class ScheduledOutputTrainingSampler: A training sampler that adds scheduled sampling directly to outputs.

class SequenceLoss: Weighted cross-entropy loss for a sequence of logits.

class TrainingSampler: A training sampler that simply reads its inputs.
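To make the difference between BahdanauAttention (additive) and LuongAttention (multiplicative) concrete, here is a minimal pure-Python sketch of the two scoring rules for a single query. It is not the tfa implementation: it assumes the memory has already been projected to keys, uses hypothetical weights (w_query, v), skips batching, and omits the optional scaling and normalization variants. The scores are turned into alignments with a softmax.

```python
import math

def softmax(scores):
    # numerically stable softmax over a list of scores
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def luong_scores(query, keys):
    # multiplicative scoring: dot product of the query with each key
    return [sum(q * k for q, k in zip(query, key)) for key in keys]

def bahdanau_scores(query, keys, w_query, v):
    # additive scoring: v^T tanh(W_q q + k) for each key
    processed_query = [sum(w * q for w, q in zip(row, query)) for row in w_query]
    return [
        sum(vi * math.tanh(pq + k) for vi, pq, k in zip(v, processed_query, key))
        for key in keys
    ]

keys = [[1.0, 0.0], [0.0, 1.0]]      # memory assumed already projected to keys
query = [2.0, 0.0]
identity = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical query projection W_q
v = [1.0, 1.0]                       # hypothetical attention vector

luong_alignments = softmax(luong_scores(query, keys))
bahdanau_alignments = softmax(bahdanau_scores(query, keys, identity, v))
```

With these toy weights the two rules prefer different memory positions, because the tanh in the additive score saturates; in practice the projection weights are learned.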
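The two inference samplers listed above differ only in how they pick the next id from the output logits: GreedyEmbeddingSampler takes the argmax, while SampleEmbeddingSampler draws from the output distribution. The following is a minimal sketch of that choice in plain Python, not the tfa code; the embedding table and the temperature parameter are illustrative assumptions. In both cases the embedding of the chosen id becomes the next decoder input.

```python
import math
import random

def greedy_sample(logits):
    # greedy choice: index of the largest logit
    return max(range(len(logits)), key=lambda i: logits[i])

def categorical_sample(logits, rng, temperature=1.0):
    # sampling choice: draw an id with probability softmax(logits / temperature)
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

embedding = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.9]]  # hypothetical embedding table
logits = [0.1, 2.0, -1.0]
rng = random.Random(0)

greedy_id = greedy_sample(logits)
sampled_id = categorical_sample(logits, rng)
next_input = embedding[greedy_id]  # embedding of the chosen id feeds the next step
```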

Functions

dynamic_decode(...): Runs dynamic decoding with a decoder.

gather_tree(...): Calculates the full beams from the per-step ids and parent beam ids.

gather_tree_from_array(...): Calculates the full beams for a TensorArray.

hardmax(...): Returns batched one-hot vectors, with the 1 placed at the index of each row's maximum logit.

monotonic_attention(...): Computes monotonic attention distribution from choosing probabilities.

safe_cumprod(...): Computes the cumulative product of x in log space, using cumsum, to avoid underflow.

sequence_loss(...): Computes the weighted cross-entropy loss for a sequence of logits.

tile_batch(...): Tiles the batch dimension of a (possibly nested structure of) tensor(s).
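The control flow behind dynamic_decode can be pictured as a loop that calls the decoder's step until every sequence is finished or an iteration cap is reached. This is a simplified single-sequence sketch under stated assumptions: CountdownDecoder is a hypothetical toy decoder, the initialize/step tuple shapes only loosely mirror the Decoder interface, and the real function also handles batching, TensorArrays, and per-sequence lengths.

```python
class CountdownDecoder:
    """Hypothetical toy decoder: emits its state value and counts down to zero."""

    def initialize(self, start):
        # returns (finished, first inputs, initial state)
        return start <= 0, None, start

    def step(self, time, inputs, state):
        # returns (output, next state, next inputs, finished)
        next_state = state - 1
        return state, next_state, None, next_state <= 0

def dynamic_decode(decoder, start, maximum_iterations=None):
    finished, inputs, state = decoder.initialize(start)
    outputs = []
    time = 0
    while not finished:
        if maximum_iterations is not None and time >= maximum_iterations:
            break  # stop early even though the decoder has not finished
        output, state, inputs, finished = decoder.step(time, inputs, state)
        outputs.append(output)
        time += 1
    return outputs, state, time  # time doubles as the decoded length here

outputs, final_state, length = dynamic_decode(CountdownDecoder(), 3)
```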
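The gather_tree entry rebuilds each full beam by backtracking from the last step through the parent beam ids. A minimal pure-Python sketch for a single batch entry follows; it is not the tfa kernel, and it ignores the maximum sequence lengths and end-token handling that the real function also takes into account. step_ids and parent_ids are indexed as [time][beam].

```python
def gather_tree(step_ids, parent_ids):
    max_time = len(step_ids)
    beam_width = len(step_ids[0])
    beams = [[0] * beam_width for _ in range(max_time)]
    for b in range(beam_width):
        slot = b  # beam slot occupied by this hypothesis at the current time
        for t in range(max_time - 1, -1, -1):
            beams[t][b] = step_ids[t][slot]
            slot = parent_ids[t][slot]  # move to the slot it came from at t - 1
    return beams

step_ids = [[2, 5], [3, 7], [4, 9]]
parent_ids = [[0, 0], [1, 0], [1, 1]]
beams = gather_tree(step_ids, parent_ids)
```

Here both final hypotheses descend from slot 1 at step 1, so they share the prefix 2, 7 and differ only in the last token.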
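As a quick illustration of what hardmax returns, here is a plain-Python sketch over a batch of logit rows. It assumes ties go to the first maximal index, which is what Python's max does; it is a sketch, not the tfa function.

```python
def hardmax(logits):
    out = []
    for row in logits:
        best = max(range(len(row)), key=lambda i: row[i])  # first index of the max
        out.append([1.0 if i == best else 0.0 for i in range(len(row))])
    return out

one_hot = hardmax([[0.1, 2.0, 0.5], [3.0, 3.0, 1.0]])
```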
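monotonic_attention turns per-position choosing probabilities into an attention distribution that can only move forward in memory. The sketch below shows one way to compute it, a recursive formulation in the style of Raffel et al. for a single example: q_0 = a_0^prev, q_j = (1 - p_{j-1}) q_{j-1} + a_j^prev, and a_j = p_j q_j. It is an assumption-laden illustration, not the tfa function: only this recursive computation is shown (the tfa function also offers other modes), and the one-hot starting attention is an assumed initial condition.

```python
def monotonic_attention_recursive(p_choose, previous_attention):
    attention = []
    q = 0.0
    for j, (p, prev) in enumerate(zip(p_choose, previous_attention)):
        # probability mass carried forward from the previous position
        carry = 1.0 if j == 0 else 1.0 - p_choose[j - 1]
        q = carry * q + prev
        attention.append(p * q)
    return attention

start = [1.0, 0.0, 0.0]  # assumed initial attention: all mass on position 0
hard = monotonic_attention_recursive([0.0, 1.0, 1.0], start)
soft = monotonic_attention_recursive([0.5, 0.5, 0.5], start)
```

With choosing probabilities of exactly 0 or 1 the result is one-hot; with softer probabilities the attention sums to less than one, and the missing mass corresponds to not attending anywhere on this step.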
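The safe_cumprod entry computes a cumulative product as exp(cumsum(log(x))). Here is a minimal pure-Python sketch of that idea with an exclusive option. Clipping inputs into [tiny, 1] so the log stays finite is this sketch's choice and assumes x holds probabilities.

```python
import math

def safe_cumprod(x, exclusive=False, tiny=1e-300):
    out = []
    log_sum = 0.0  # running sum of logs, i.e. log of the running product
    for value in x:
        if exclusive:
            out.append(math.exp(log_sum))  # product of entries strictly before this one
        log_sum += math.log(min(max(value, tiny), 1.0))
        if not exclusive:
            out.append(math.exp(log_sum))  # product up to and including this entry
    return out

inclusive = safe_cumprod([0.5, 0.5, 0.5])
exclusive = safe_cumprod([0.5, 0.5, 0.5], exclusive=True)
```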
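sequence_loss and SequenceLoss compute a weighted cross-entropy over a sequence of logits. The sketch below covers only one reduction: summing the weighted per-step losses and dividing by the total weight, which averages over both timesteps and batch. The real function exposes further averaging and summing options not modeled here. logits are [batch][time][vocab]; targets and weights are [batch][time], and a weight of 0 masks a step such as padding.

```python
import math

def sequence_loss(logits, targets, weights):
    total = 0.0
    total_weight = 0.0
    for seq_logits, seq_targets, seq_weights in zip(logits, targets, weights):
        for step_logits, target, w in zip(seq_logits, seq_targets, seq_weights):
            m = max(step_logits)
            log_z = m + math.log(sum(math.exp(l - m) for l in step_logits))
            total += w * (log_z - step_logits[target])  # weighted -log softmax[target]
            total_weight += w
    return total / total_weight if total_weight else 0.0

uniform = sequence_loss([[[0.0, 0.0], [0.0, 0.0]]], [[0, 1]], [[1.0, 1.0]])
masked = sequence_loss([[[0.0, 0.0], [10.0, -10.0]]], [[0, 1]], [[1.0, 0.0]])
```

In the second call the badly predicted second step has weight 0, so it contributes nothing, and the loss equals that of the first step alone.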
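tile_batch repeats each batch entry consecutively (t[0], t[0], ..., t[1], t[1], ...), which is how encoder outputs are typically expanded to batch_size * beam_width before beam search. A minimal sketch follows, assuming dicts and tuples act as the nested structure and lists act as batch-major "tensors"; this convention is for illustration only.

```python
def tile_batch(t, multiplier):
    if isinstance(t, dict):
        return {k: tile_batch(v, multiplier) for k, v in t.items()}
    if isinstance(t, tuple):
        return tuple(tile_batch(v, multiplier) for v in t)
    # a list is treated as a batch-major tensor: repeat each entry consecutively
    return [entry for entry in t for _ in range(multiplier)]

tiled = tile_batch([[1, 2], [3, 4]], 2)
nested = tile_batch({"h": [1, 2], "c": ([5], [6])}, 3)
```

Repeating entries consecutively, rather than tiling the whole batch end to end, keeps all beams of one example adjacent.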