API de decodificación

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno

Descripción general

En el pasado reciente, ha habido mucha investigación en la generación de lenguajes con modelos autoregresivos. En la generación automática de idioma regresiva, la distribución de probabilidad de contador en el paso de tiempo K depende de token-predicciones del modelo hasta la etapa K-1. Para estos modelos, estrategias de decodificación como la búsqueda Beam, Greedy, Top-p, y Top-k son componentes críticos del modelo e influyen en gran medida el estilo / naturaleza de la salida token generado en un momento en el paso K dado.

Por ejemplo, haz de búsqueda reduce el riesgo de perder fichas de alta probabilidad ocultos por mantener los num_beams más probable de hipótesis en cada paso de tiempo y, finalmente, la elección de la hipótesis de que tiene la probabilidad más alta en general. Murray y col. (2018) y Yang et al. (2018) muestran que el haz de búsqueda funciona bien en tareas de traducción automática. Tanto en la búsqueda de haz y estrategias de Greedy tienen una posibilidad de generación de tokens que se repiten.

Fan et. al (2018) introdujo muestreo Top-K, en la que el K más probables tokens son filtrados y la masa de probabilidad se redistribuye entre sólo aquellos tokens K.

Ari Holtzman y col. al (2019) introdujo muestreo Top-p, que elige desde el más pequeño posible de fichas con probabilidad acumulada que añade hasta la probabilidad p. Luego, la masa de probabilidad se redistribuye entre este conjunto. De esta manera, el tamaño del conjunto de tokens puede aumentar y disminuir dinámicamente. Top-p, Top-k se utilizan generalmente en tareas tales como la historia de generación.

La API de decodificación proporciona una interfaz para experimentar con diferentes estrategias de decodificación en modelos autoregresivos.

  1. Las siguientes estrategias de muestreo se proporcionan en sampling_module.py, que hereda de la clase decodificación base:

  2. La búsqueda de haces se proporciona en beam_search.py. github

Configuración

pip install -q -U tensorflow-text
pip install -q tf-models-nightly
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

from official import nlp
from official.nlp.modeling.ops import sampling_module
from official.nlp.modeling.ops import beam_search
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release
  PkgResourcesDeprecationWarning,

Inicialice el módulo de muestreo en TF-NLP.

  • symbols_to_logits_fn: Usar este cierre para llamar al modelo para predecir las logits para el index+1 paso. Las entradas y salidas para este cierre son las siguientes:
Args:
  1] ids : Current decoded sequences. int tensor with shape (batch_size, index + 1 or 1 if padded_decode is True)],
  2] index [scalar] : current decoded step,
  3] cache [nested dictionary of tensors] : Only used for faster decoding to store pre-computed attention hidden states for keys and values. More explanation in the cell below.
Returns:
  1] tensor for next-step logits [batch_size, vocab]
  2] the updated_cache [nested dictionary of tensors].

La caché se utiliza para una decodificación más rápida. Aquí es una referencia de aplicación para el cierre anterior.

  • length_normalization_fn: Usar este cierre para el retorno de parámetro de longitud de normalización.
Args: 
  1] length : scalar for decoded step index.
  2] dtype : data-type of output tensor
Returns:
  1] value of length normalization factor.
  • vocab_size: el tamaño del vocabulario de salida.

  • max_decode_length: escalar para el número total de etapas de decodificación.

  • eos_id: La decodificación se detendrá si todos los identificadores de salida decodificada en el lote tienen esta eos_id.

  • padded_decode: Ponga esto a TRUE si se ejecuta en TPU. Los tensores se rellenan a max_decoding_length si esto es True.

  • top_k: top_k está habilitado si este valor es> 1.

  • top_p: top_p está habilitado si este valor es> 0 y <1.0

  • sampling_temperature: Se utiliza para volver a estimar la salida softmax. La temperatura sesga la distribución hacia fichas de alta probabilidad y reduce la masa en la distribución de la cola. El valor tiene que ser positivo. La temperatura baja es equivalente a la codicia y hace que la distribución sea más nítida, mientras que la temperatura alta la hace más plana.

  • enable_greedy: De forma predeterminada, esto es cierto y decodificación voraz está activado. Para experimentar con otras estrategias, establezca esto en Falso.

Inicializar los hiperparámetros del modelo

params = {}
params['num_heads'] = 2
params['num_layers'] = 2
params['batch_size'] = 2
params['n_dims'] = 256
params['max_decode_length'] = 4

En las arquitecturas de auto-regresivos como base Transformador codificador-decodificador de modelos, la caché se utiliza para la decodificación secuencial rápido. Es un diccionario anidado que almacena estados ocultos calculados previamente (clave y valores en los bloques de atención propia y en los bloques de atención cruzada) para cada capa.

Inicializar caché.

cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], int(params['n_dims']/params['num_heads'])], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", cache['layer_1']['k'].shape)
cache key shape for layer 1 : (2, 4, 2, 128)

Defina el cierre para la normalización de la longitud si es necesario.

Se utiliza para normalizar las puntuaciones finales de las secuencias generadas y es opcional.

def length_norm(length, dtype):
  """Return length normalization factor."""
  return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)

Crear model_fn

En la práctica, esto va a ser sustituido por un modelo de aplicación real, como aquí

Args:
i : Step that is being decoded.
Returns:
  logit probabilities of size [batch_size, 1, vocab_size]
probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                            [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],
                              [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])
def model_fn(i):
  return probabilities[:, i, :]

Inicializar symbols_to_logits_fn

def _symbols_to_logits_fn():
  """Calculates logits of the next tokens."""
  def symbols_to_logits_fn(ids, i, temp_cache):
    del ids
    logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)
    return logits, temp_cache
  return symbols_to_logits_fn

Avaro

Decodificación Greedy selecciona el ID de símbolo con la probabilidad más alta como su próximo id: \(id_t = argmax_{w}P(id | id_{1:t-1})\) en cada paso de tiempo \(t\). El siguiente esquema muestra una codificación codiciosa.

greedy_obj = sampling_module.SamplingModule(
    length_normalization_fn=None,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False)
ids, _ = greedy_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("Greedy Decoded Ids:", ids)
Greedy Decoded Ids: tf.Tensor(
[[9 1 2 2 2]
 [1 1 1 2 2]], shape=(2, 5), dtype=int32)

muestreo top_k

En el muestreo Top-K, la K más probable ids siguiente token se filtró y la masa de probabilidad se redistribuye entre sólo aquellos ids K.

top_k_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_k=tf.constant(3),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_k_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-k sampled Ids:", ids)
top-k sampled Ids: tf.Tensor(
[[9 1 0 2 2]
 [1 0 1 2 2]], shape=(2, 5), dtype=int32)

muestreo top_p

En lugar de muestreo sólo desde el más probable K contador ids, en Top-p muestreo elige a partir de la más pequeña posible de ids cuya probabilidad acumulada excede la probabilidad p.

top_p_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_p=tf.constant(0.9),
    padded_decode=False,
    enable_greedy=False)
ids, _ = top_p_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print("top-p sampled Ids:", ids)
top-p sampled Ids: tf.Tensor(
[[9 1 1 2 2]
 [1 1 1 0 2]], shape=(2, 5), dtype=int32)

Decodificación de búsqueda de haz

La búsqueda de haces reduce el riesgo de perder los identificadores de token de alta probabilidad ocultos al mantener el número más probable de haces de hipótesis en cada paso de tiempo y, finalmente, al elegir la hipótesis que tiene la probabilidad general más alta.

beam_size = 2
params['batch_size'] = 1
beam_cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32),
        'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32)
        } for layer in range(params['num_layers'])
    }
print("cache key shape for layer 1 :", beam_cache['layer_1']['k'].shape)
ids, _ = beam_search.sequence_beam_search(
    symbols_to_logits_fn=_symbols_to_logits_fn(),
    initial_ids=tf.constant([9], tf.int32),
    initial_cache=beam_cache,
    vocab_size=3,
    beam_size=beam_size,
    alpha=0.6,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False,
    dtype=tf.float32)
print("Beam search ids:", ids)
cache key shape for layer 1 : (1, 4, 2, 256)
Beam search ids: tf.Tensor(
[[[9 0 1 2 2]
  [9 1 2 2 2]]], shape=(1, 2, 5), dtype=int32)