Additional layers for sequence to sequence models.
Classes
class AttentionMechanism: Base class for attention mechanisms.
class AttentionWrapper: Wraps another RNN cell with attention.
class AttentionWrapperState: State of a tfa.seq2seq.AttentionWrapper.
class BahdanauAttention: Implements Bahdanau-style (additive) attention.
class BahdanauMonotonicAttention: Monotonic attention mechanism with Bahdanau-style energy function.
class BaseDecoder: An RNN Decoder that is based on a Keras layer.
class BasicDecoder: Basic sampling decoder for training and inference.
class BasicDecoderOutput: Outputs of a tfa.seq2seq.BasicDecoder step.
class BeamSearchDecoder: Beam search decoder.
class BeamSearchDecoderOutput: Outputs of a tfa.seq2seq.BeamSearchDecoder step.
class BeamSearchDecoderState: State of a tfa.seq2seq.BeamSearchDecoder.
class CustomSampler: Base abstract class that allows the user to customize sampling.
class Decoder: An RNN Decoder abstract interface object.
class FinalBeamSearchDecoderOutput: Final outputs returned by the beam search after all decoding is finished.
class GreedyEmbeddingSampler: An inference sampler that takes the maximum from the output distribution.
class InferenceSampler: An inference sampler that uses a custom sampling function.
class LuongAttention: Implements Luong-style (multiplicative) attention scoring.
class LuongMonotonicAttention: Monotonic attention mechanism with Luong-style energy function.
class SampleEmbeddingSampler: An inference sampler that randomly samples from the output distribution.
class Sampler: Interface for implementing sampling in seq2seq decoders.
class ScheduledEmbeddingTrainingSampler: A training sampler that adds scheduled sampling.
class ScheduledOutputTrainingSampler: A training sampler that adds scheduled sampling directly to outputs.
class SequenceLoss: Weighted cross-entropy loss for a sequence of logits.
class TrainingSampler: A training sampler that simply reads its inputs.
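The difference between the greedy and random inference samplers listed above can be illustrated with a minimal pure-Python sketch. This is not the library's implementation: the helper names `greedy_sample` and `random_sample` are hypothetical, and the sketch works on a single list of logits rather than on batched tensors and embeddings.

```python
import math
import random


def greedy_sample(logits):
    """Hypothetical helper: return the index of the largest logit (argmax)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def random_sample(logits, rng):
    """Hypothetical helper: draw an index from softmax(logits)."""
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(logits)
    weights = [math.exp(x - peak) for x in logits]
    # random.choices normalizes the weights internally.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [0.1, 2.5, -1.0]
greedy_id = greedy_sample(logits)                      # always the argmax
sampled_id = random_sample(logits, random.Random(0))   # varies with the seed
```

Greedy selection is deterministic for a given set of logits, while random sampling can return any index with nonzero probability.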
Functions
dynamic_decode(...): Runs dynamic decoding with a decoder.
gather_tree(...): Calculates the full beams from the per-step ids and parent beam ids.
gather_tree_from_array(...): Calculates the full beams for a TensorArray.
hardmax(...): Returns batched one-hot vectors.
monotonic_attention(...): Computes monotonic attention distribution from choosing probabilities.
safe_cumprod(...): Computes cumprod of x in logspace using cumsum to avoid underflow.
sequence_loss(...): Computes the weighted cross-entropy loss for a sequence of logits.
tile_batch(...): Tiles the batch dimension of a (possibly nested structure of) tensor(s).
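As a rough illustration of what three of these utilities compute, the following pure-Python sketch mirrors the one-line descriptions above for `hardmax`, `safe_cumprod`, and `tile_batch`. The helper names (`hardmax_rows`, `cumprod_logspace`, `tile_rows`) and the `tiny` clipping constant are assumptions made for this example; the real functions operate on tensors and may differ in detail.

```python
import math


def hardmax_rows(logits):
    """Hypothetical helper: one-hot vector at the argmax of each row."""
    result = []
    for row in logits:
        best = max(range(len(row)), key=lambda i: row[i])
        result.append([1.0 if i == best else 0.0 for i in range(len(row))])
    return result


def cumprod_logspace(xs, tiny=1e-30):
    """Hypothetical helper: cumulative product computed as exp(cumsum(log(x))).

    Assumes inputs are probabilities in [0, 1]; values are clipped to
    [tiny, 1] so that log(0) is never evaluated.
    """
    running_log = 0.0
    result = []
    for x in xs:
        running_log += math.log(min(max(x, tiny), 1.0))
        result.append(math.exp(running_log))
    return result


def tile_rows(batch, multiplier):
    """Hypothetical helper: repeat each batch entry `multiplier` times in a row."""
    return [list(row) for row in batch for _ in range(multiplier)]


one_hot = hardmax_rows([[0.1, 2.0, -1.0], [3.0, 0.0, 0.0]])
products = cumprod_logspace([0.5, 0.5, 0.5])
tiled = tile_rows([[1, 2], [3, 4]], multiplier=2)
```

Tiling in this entry-by-entry order keeps the copies of each batch element adjacent, which is the layout a beam search over `multiplier` beams would expect.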