Additional layers for sequence to sequence models.
Classes
class AttentionMechanism
: Base class for attention mechanisms.
class AttentionWrapper
: Wraps another RNN cell with attention.
class AttentionWrapperState
: State of a tfa.seq2seq.AttentionWrapper.
class BahdanauAttention
: Implements Bahdanau-style (additive) attention.
class BahdanauMonotonicAttention
: Monotonic attention mechanism with Bahdanau-style energy function.
class BaseDecoder
: An RNN Decoder that is based on a Keras layer.
class BasicDecoder
: Basic sampling decoder for training and inference.
class BasicDecoderOutput
: Outputs of a tfa.seq2seq.BasicDecoder step.
class BeamSearchDecoder
: Beam search decoder.
class BeamSearchDecoderOutput
: Outputs of a tfa.seq2seq.BeamSearchDecoder step.
class BeamSearchDecoderState
: State of a tfa.seq2seq.BeamSearchDecoder.
class CustomSampler
: Base abstract class that allows the user to customize sampling.
class Decoder
: An RNN Decoder abstract interface object.
class FinalBeamSearchDecoderOutput
: Final outputs returned by the beam search after all decoding is finished.
class GreedyEmbeddingSampler
: An inference sampler that takes the maximum from the output distribution.
class InferenceSampler
: An inference sampler that uses a custom sampling function.
class LuongAttention
: Implements Luong-style (multiplicative) attention scoring.
class LuongMonotonicAttention
: Monotonic attention mechanism with Luong-style energy function.
class SampleEmbeddingSampler
: An inference sampler that randomly samples from the output distribution.
class Sampler
: Interface for implementing sampling in seq2seq decoders.
class ScheduledEmbeddingTrainingSampler
: A training sampler that adds scheduled sampling.
class ScheduledOutputTrainingSampler
: A training sampler that adds scheduled sampling directly to outputs.
class SequenceLoss
: Weighted cross-entropy loss for a sequence of logits.
class TrainingSampler
: A training sampler that simply reads its inputs.
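As a rough illustration of what "weighted cross-entropy loss for a sequence of logits" can mean, the sketch below computes it in plain Python, without TensorFlow. It is not the library's implementation: the function name `sequence_loss_sketch`, the nested-list inputs, and the choice to divide by the total weight (averaging over both time steps and batch) are assumptions made for this example.

```python
import math


def sequence_loss_sketch(logits, targets, weights):
    """Weighted cross-entropy averaged over all weighted time steps.

    Hypothetical helper for illustration only.
    logits:  [batch][time][num_classes] floats
    targets: [batch][time] integer class ids
    weights: [batch][time] floats, typically 0.0 on padding steps
    """
    total, weight_sum = 0.0, 0.0
    for seq_logits, seq_targets, seq_weights in zip(logits, targets, weights):
        for step_logits, target, weight in zip(seq_logits, seq_targets, seq_weights):
            # Log-sum-exp with the maximum subtracted for numerical stability.
            peak = max(step_logits)
            log_norm = peak + math.log(sum(math.exp(v - peak) for v in step_logits))
            # Cross-entropy of the target class, scaled by the step weight.
            total += weight * (log_norm - step_logits[target])
            weight_sum += weight
    # Assumption: average over the total weight; return 0.0 if all weights are zero.
    return total / weight_sum if weight_sum else 0.0


# One sequence of three steps; the last step is padding and carries weight 0.
example_logits = [[[0.0, 0.0], [0.0, 0.0], [9.0, -9.0]]]
example_targets = [[0, 1, 1]]
example_weights = [[1.0, 1.0, 0.0]]
example_loss = sequence_loss_sketch(example_logits, example_targets, example_weights)
```

With uniform logits over two classes, each unpadded step contributes `log(2)`, and the zero-weight padding step has no effect on the result.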
Functions
dynamic_decode(...)
: Runs dynamic decoding with a decoder.
gather_tree(...)
: Calculates the full beams from the per-step ids and parent beam ids.
gather_tree_from_array(...)
: Calculates the full beams for a TensorArray
.
hardmax(...)
: Returns batched one-hot vectors.
monotonic_attention(...)
: Computes monotonic attention distribution from choosing probabilities.
safe_cumprod(...)
: Computes cumprod of x in logspace using cumsum to avoid underflow.
sequence_loss(...)
: Computes the weighted cross-entropy loss for a sequence of logits.
tile_batch(...)
: Tiles the batch dimension of a (possibly nested structure of) tensor(s).
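Several of the one-line descriptions above name a small algorithm. The pure-Python sketches below illustrate those ideas on nested lists and are not the `tfa.seq2seq` implementations. All four `*_sketch` names are hypothetical. The clipping range in `safe_cumprod_sketch` is an assumption that the inputs are probabilities, and `gather_tree_sketch` handles a single batch entry and leaves out sequence-length and end-token handling.

```python
import math


def safe_cumprod_sketch(xs, tiny=1e-38):
    """Cumulative product computed as exp(cumsum(log(x))) to avoid underflow."""
    out, running_log = [], 0.0
    for x in xs:
        # Assumption: inputs are probabilities, so clip to [tiny, 1] before log().
        clipped = min(max(x, tiny), 1.0)
        running_log += math.log(clipped)
        out.append(math.exp(running_log))
    return out


def hardmax_sketch(rows):
    """Return a one-hot vector marking the argmax of each row."""
    result = []
    for row in rows:
        best = max(range(len(row)), key=row.__getitem__)  # first maximum on ties
        result.append([1.0 if i == best else 0.0 for i in range(len(row))])
    return result


def tile_batch_sketch(batch, multiplier):
    """Repeat each batch entry `multiplier` times, keeping the copies adjacent."""
    return [item for item in batch for _ in range(multiplier)]


def gather_tree_sketch(step_ids, parent_ids):
    """Rebuild full beams by following parent pointers backwards in time.

    Both inputs are indexed [time][beam] for a single batch entry.
    """
    max_time, beam_width = len(step_ids), len(step_ids[0])
    beams = [[None] * beam_width for _ in range(max_time)]
    for beam in range(beam_width):
        parent = beam
        for t in range(max_time - 1, -1, -1):
            beams[t][beam] = step_ids[t][parent]
            parent = parent_ids[t][parent]
    return beams


cumprod = safe_cumprod_sketch([0.5, 0.5, 0.5])    # approximately [0.5, 0.25, 0.125]
one_hot = hardmax_sketch([[0.1, 0.7, 0.2]])       # [[0.0, 1.0, 0.0]]
tiled = tile_batch_sketch(["a", "b"], 2)          # ["a", "a", "b", "b"]
beams = gather_tree_sketch(
    step_ids=[[1, 2], [3, 4], [5, 6]],
    parent_ids=[[0, 0], [0, 0], [1, 0]],
)                                                 # [[1, 1], [4, 3], [5, 6]]
```

In the last call, beam 0 ends in token 5, whose parent at the final step is beam 1, so its reconstructed sequence reads 1, 4, 5 down the first column of the result.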