Ajuste fino de Wav2Vec2 con un cabezal LM

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno Ver modelo TF Hub

En este cuaderno, vamos a cargar el modelo wav2vec2 pre-formados a partir TFHub y afinará en LibriSpeech conjunto de datos añadiendo la cabeza Modeling Language (LM) sobre la parte superior de nuestro modelo de pre-formados. La tarea subyacente es construir un modelo de reconocimiento de voz automático es decir, dado un poco de discurso, el modelo debe ser capaz de transcribir en texto.

Configuración

Antes de ejecutar este portátil, por favor asegúrese de que está en la GPU en tiempo de ejecución ( Runtime > Change runtime type > GPU ). La celda siguiente instalará gsoc-wav2vec2 paquete y sus dependencias.

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
The following packages were automatically installed and are no longer required:
  linux-gcp-5.4-headers-5.4.0-1040 linux-gcp-5.4-headers-5.4.0-1043
  linux-gcp-5.4-headers-5.4.0-1044 linux-gcp-5.4-headers-5.4.0-1049
  linux-headers-5.4.0-1049-gcp linux-image-5.4.0-1049-gcp
  linux-modules-5.4.0-1049-gcp linux-modules-extra-5.4.0-1049-gcp
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libogg-dev libvorbis-dev libvorbisfile3
The following NEW packages will be installed:
  libflac-dev libogg-dev libsndfile1-dev libvorbis-dev libvorbisfile3
0 upgraded, 5 newly installed, 0 to remove and 143 not upgraded.
Need to get 1040 kB of archives.
After this operation, 4481 kB of additional disk space will be used.
Get:1 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libogg-dev amd64 1.3.2-1 [156 kB]
Get:2 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libflac-dev amd64 1.3.2-1 [260 kB]
Get:3 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbisfile3 amd64 1.3.5-4.2 [16.0 kB]
Get:4 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbis-dev amd64 1.3.5-4.2 [321 kB]
Get:5 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic-updates/main amd64 libsndfile1-dev amd64 1.0.28-4ubuntu0.18.04.2 [287 kB]
Fetched 1040 kB in 1s (1041 kB/s)
Selecting previously unselected package libogg-dev:amd64.
(Reading database ... 282211 files and directories currently installed.)
Preparing to unpack .../libogg-dev_1.3.2-1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../libflac-dev_1.3.2-1_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libvorbisfile3:amd64.
Preparing to unpack .../libvorbisfile3_1.3.5-4.2_amd64.deb ...
Unpacking libvorbisfile3:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../libvorbis-dev_1.3.5-4.2_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../libsndfile1-dev_1.0.28-4ubuntu0.18.04.2_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Setting up libvorbisfile3:amd64 (1.3.5-4.2) ...
Setting up libogg-dev:amd64 (1.3.2-1) ...
Setting up libvorbis-dev:amd64 (1.3.5-4.2) ...
Setting up libflac-dev:amd64 (1.3.2-1) ...
Setting up libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Processing triggers for libc-bin (2.27-3ubuntu1.2) ...

Configuración modelo usando TFHub

Comenzaremos importando algunas bibliotecas / módulos.

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
TF version: 2.7.0

En primer lugar, vamos a descargar nuestro modelo de TFHub y envolverá nuestra firma modelo con hub.KerasLayer para poder utilizar este modelo como cualquier otra capa Keras. Afortunadamente, hub.KerasLayer puede hacer tanto en tan sólo 1 línea.

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

Se puede hacer referencia a esta secuencia de comandos en caso de estar interesado en el guión exportar el modelo. Objeto pretrained_layer es la versión freezed de Wav2Vec2Model . Estos pesos pre-formados fueron convertidos de HuggingFace PyTorch pesos entrenada previamente usando esta secuencia de comandos .

Originalmente, wav2vec2 se entrenó previamente con un enfoque de modelado de lenguaje enmascarado con el objetivo de identificar la verdadera representación cuantificada del habla latente para un paso de tiempo enmascarado. Puede leer más sobre el objetivo de la capacitación en el papel- wav2vec 2.0: Un marco para la auto-aprendizaje supervisado de expresión Representaciones .

Ahora, definiremos algunas constantes e hiperparámetros que serán útiles en las siguientes celdas. AUDIO_MAXLEN se establece intencionalmente a 246000 como el modelo de la firma sólo acepta longitud de la secuencia estática de 246000 .

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

En la celda siguiente, vamos a envolver pretrained_layer y una densa capa (cabeza LM) con el API funcional de Keras .

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

La capa densa (definido anteriormente) está teniendo una dimensión de salida de vocab_size como queremos predecir probabilidades de cada contador en el vocabulario en cada paso de tiempo.

Configurando el estado de entrenamiento

En TensorFlow, pesas modelo se construyó solamente cuando model.call o model.build se llama por primera vez, por lo que la célula después construirá los pesos modelo para nosotros. Además, estaremos corriendo model.summary() para comprobar el número total de parámetros entrenables.

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
=================================================================
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________

Ahora, tenemos que definir el loss_fn y un optimizador para poder entrenar el modelo. La siguiente celda lo hará por nosotros. Nosotros vamos a usar el Adam optimizador para la simplicidad. CTCLoss es un tipo de pérdida común que se utiliza para tareas (como ASR ) donde sub-partes de entrada no pueden ser fácilmente alineados con salida sub-partes. Puede leer más sobre CTC-pérdida de esta increíble entrada de blog .

CTCLoss (de gsoc-wav2vec2 paquete) acepta 3 argumentos: config , model_input_shape y division_factor . Si division_factor=1 , entonces la pérdida de simplemente obtener resumió, por lo que pasar division_factor en consecuencia para obtener lotes más de media.

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

Carga y preprocesamiento de datos

Ahora vamos a descargar el conjunto de datos LibriSpeech desde el sitio web oficial y configurarlo.

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2021-11-05 11:43:09--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  11.6MB/s    in 31s     

2021-11-05 11:43:42 (10.3 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]
ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

Nuestro conjunto de datos se encuentra en el directorio LibriSpeech. Exploremos estos archivos.

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0015.flac', '2428-83705-0004.flac', '2428-83705-0006.flac', '2428-83705-0026.flac', '2428-83705-0023.flac', '2428-83705-0001.flac', '2428-83705-0005.flac', '2428-83705-0040.flac', '2428-83705-0038.flac', '2428-83705-0042.flac', '2428-83705-0008.flac', '2428-83705-0019.flac', '2428-83705-0021.flac', '2428-83705-0002.flac', '2428-83705-0039.flac', '2428-83705-0034.flac', '2428-83705-0028.flac', '2428-83705-0000.flac', '2428-83705-0029.flac', '2428-83705-0041.flac', '2428-83705-0035.flac', '2428-83705-0032.flac', '2428-83705-0020.flac', '2428-83705-0025.flac', '2428-83705-0010.flac', '2428-83705-0014.flac', '2428-83705-0003.flac', '2428-83705-0031.flac', '2428-83705-0017.flac', '2428-83705-0027.flac', '2428-83705-0012.flac', '2428-83705-0043.flac', '2428-83705-0030.flac', '2428-83705-0022.flac', '2428-83705-0016.flac', '2428-83705-0037.flac', '2428-83705-0011.flac', '2428-83705-0036.flac', '2428-83705-0009.flac', '2428-83705-0013.flac', '2428-83705-0007.flac', '2428-83705-0018.flac', '2428-83705-0024.flac', '2428-83705-0033.flac']

Muy bien, por lo que cada subdirectorio tiene muchas .flac archivos y un .txt archivo. El .txt archivo contiene transcripciones de texto para todas las muestras de voz (es decir .flac archivos) presente en ese subdirectorio.

Podemos cargar estos datos de texto de la siguiente manera:

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples

Del mismo modo, vamos a definir una función para cargar una muestra de voz de un .flac archivo.

REQUIRED_SAMPLE_RATE se establece en 16000 como wav2vec2 fue pre-entrenó con 16K frecuencia y se recomienda ajustar sin ningún cambio importante en la distribución de datos debido a la frecuencia.

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

Ahora, elegiremos algunas muestras aleatorias e intentaremos visualizarlas.

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: HE HAS GIVEN US FREE PASSES ALL THE WAY TO THE END OF OUR JOURNEY AND ALL THE WAY BACK AGAIN AND COUPONS FOR FREE BOARD AND LODGING AT THE HOTEL IT'S A WEDDING PRESENT 
Audio:

Ahora, combinaremos todas las muestras de voz y texto y definiremos la función (en la siguiente celda) para ese propósito.

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

Es hora de echar un vistazo a algunas muestras ...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([ 6.10351562e-05,  9.15527344e-05,  9.15527344e-05, ...,
         -3.05175781e-04, -5.79833984e-04, -8.23974609e-04]),
  'WHEN SHE HEARD OF MY ENGAGEMENT WITH MARY ANN SHE WROTE AND SUGGESTED THAT WE SHOULD SPEND OUR HONEYMOON IN HER COTTAGE OR PIGSTYE AND THAT I SHOULD PAY HER RENT FOR IT'),
 (array([-0.00112915, -0.00131226, -0.00158691, ...,  0.00067139,
          0.00091553,  0.00100708]),
  "IT MIGHT JUST AS WELL BE SOME ONE ELSE'S WEDDING SO UNIMPORTANT IS THE PART WHICH I AM SET TO PLAY IN IT"),
 (array([ 3.05175781e-05, -6.10351562e-05,  2.13623047e-04, ...,
         -5.18798828e-04, -2.13623047e-04, -2.74658203e-04]),
  'THE ACCIDENT IN QUESTION OCCURRED UPON THE SUNDAY EVENING'),
 (array([ 3.05175781e-04,  3.05175781e-05, -1.83105469e-04, ...,
          7.62939453e-04,  6.10351562e-04,  5.79833984e-04]),
  "OF COURSE THERE ARE SOME PEOPLE WITH WHOM YOU CAN'T BE PERFECTLY PLAIN BUT I SHALL BE AS PLAIN AS I CAN THERE'S A WAY AND A MANNER OF DOING THAT KIND OF THING"),
 (array([ 6.10351562e-05, -3.05175781e-05,  0.00000000e+00, ...,
         -3.66210938e-04, -7.93457031e-04, -1.19018555e-03]),
  'I KNOW WHAT MAMMA CAN AFFORD TO GIVE AND I WILL SEE SHE GIVES IT')]

¡Procesemos previamente los datos ahora!

Vamos a definir primero el tokenizer y procesador utilizando gsoc-wav2vec2 paquete. Luego, haremos un preprocesamiento muy sencillo. processor se normalizará el habla prima WRTO marcos eje y tokenizer convertirán nuestros resultados del modelo en la cadena (usando el vocabulario definido) y se hará cargo de la eliminación de fichas especiales (dependiendo de la configuración tokenizer).

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE

Ahora, definiremos el generador de Python para llamar a las funciones de preprocesamiento que definimos en las celdas anteriores.

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

La creación de tf.data.Dataset

Tras la instalación de células voluntad tf.data.Dataset objeto mediante su .from_generator(...) método. Vamos a utilizar el generator objeto, se definió en la celda anterior.

Se puede hacer referencia a esta secuencia de comandos para más detalles sobre cómo convertir los datos en LibriSpeech tfrecords.

output_signature = (
    tf.TensorSpec(shape=(None),  dtype=tf.float32),
    tf.TensorSpec(shape=(None), dtype=tf.int32),
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

Pasaremos el conjunto de datos a varios lotes, así que preparemos los lotes en la siguiente celda. Ahora, todas las secuencias de un lote deben rellenarse a una longitud constante. Vamos a utilizar el .padded_batch(...) el método para tal fin.

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

Los aceleradores (como las GPU / TPU) son muy rápidos y, a menudo, la carga de datos (y el preprocesamiento) se convierte en el cuello de botella durante el entrenamiento, ya que la parte de carga de datos ocurre en las CPU. Esto puede aumentar el tiempo de entrenamiento de manera significativa, especialmente cuando hay mucho procesamiento previo en línea involucrado o cuando los datos se transmiten en línea desde depósitos de GCS. Para hacer frente a esos problemas, tf.data.Dataset ofrece la .prefetch(...) método. Este método ayuda a preparar los siguientes lotes en paralelo (en CPU) mientras el modelo hace predicciones (en GPU / TPU) en el lote actual.

dataset = dataset.prefetch(tf.data.AUTOTUNE)

Desde este cuaderno se hace con fines de demostración, vamos a tomar primeros num_train_batches y llevará a cabo la capacitación sobre sólo eso. Sin embargo, se le anima a entrenar en todo el conjunto de datos. Del mismo modo, vamos a evaluar sólo num_val_batches .

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

Entrenamiento de modelos

Para la formación de nuestro modelo, estaremos llamando directamente .fit(...) el método después de compilar nuestro modelo con .compile(...) .

model.compile(optimizer, loss=loss_fn)

La celda anterior configurará nuestro estado de entrenamiento. Ahora podemos iniciar el entrenamiento con el .fit(...) método.

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
10/10 [==============================] - 32s 2s/step - loss: 649.3215 - val_loss: 315.0721
Epoch 2/3
10/10 [==============================] - 17s 2s/step - loss: 242.1202 - val_loss: 336.5721
Epoch 3/3
10/10 [==============================] - 17s 2s/step - loss: 222.1239 - val_loss: 253.0467
{'loss': [649.321533203125, 242.1201629638672, 222.1239013671875],
 'val_loss': [315.0721435546875, 336.5721130371094, 253.0466766357422]}

Salvemos nuestro modelo con .save(...) método para poder realizar inferencias más tarde. También puede exportar esta SavedModel a TFHub siguiendo la documentación TFHub .

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
2021-11-05 11:44:54.280793: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 855). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

Evaluación

Ahora calcularemos la tasa de error de palabras sobre el conjunto de datos de validación

Tasa de error de palabra (WER) es una métrica común para medir el rendimiento de un sistema de reconocimiento automático del habla. El WER se deriva de la distancia de Levenshtein, trabajando a nivel de palabra. La tasa de error de palabras se puede calcular como: WER = (S + D + I) / N = (S + D + I) / (S + D + C) donde S es el número de sustituciones, D es el número de eliminaciones , I es el número de inserciones, C es el número de palabras correctas, N es el número de palabras en la referencia (N = S + D + C). Este valor indica el porcentaje de palabras que se predijeron incorrectamente.

Se puede hacer referencia a este documento para aprender más sobre WER.

Vamos a utilizar load_metric(...) la función de los conjuntos de datos HuggingFace biblioteca. Primero vamos a instalar el datasets la biblioteca usando el pip y luego definen la metric objeto.

!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

Es hora de ejecutar la evaluación en datos de validación ahora.

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2021-11-05 11:45:11.575128: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:57] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform

Estamos utilizando el tokenizer.decode(...) procedimiento para decodificar nuestras predicciones y etiquetas de nuevo en el texto y se sumará a la métrica para WER cálculo posterior.

Ahora, calculemos el valor métrico en la siguiente celda:

metric.compute()
1.0

Inferencia

Ahora que estamos satisfechos con el proceso de formación y ha guardado el modelo en el save_dir , vamos a ver cómo este modelo puede ser utilizado para la inferencia.

En primer lugar, vamos a cargar nuestro modelo usando tf.keras.models.load_model(...) .

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

Descarguemos algunas muestras de voz para realizar inferencias. También puede reemplazar la siguiente muestra con su muestra de voz.

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2021-11-05 11:45:28--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 13.114.40.48
Connecting to github.com (github.com)|13.114.40.48|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2021-11-05 11:45:28--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.02s   

2021-11-05 11:45:29 (5.38 MB/s) - ‘SA2.wav’ saved [94252/94252]

Ahora, vamos a leer la muestra de voz utilizando soundfile.read(...) y la almohadilla para AUDIO_MAXLEN para satisfacer el modelo de la firma. A continuación vamos a normalizar la que se muestra de voz utilizando el Wav2Vec2Processor ejemplo y se alimentan en el modelo.

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 5.5087714 , -1.0872856 , -1.0728477 , ..., -1.3125695 ,
         -0.7992846 , -0.94512135],
        [ 5.508977  , -1.0873723 , -1.0727195 , ..., -1.3125291 ,
         -0.79928476, -0.9449429 ],
        [ 5.5091047 , -1.0871643 , -1.0728203 , ..., -1.312533  ,
         -0.7992611 , -0.94483167],
        ...,
        [ 5.5094743 , -1.0874028 , -1.0729864 , ..., -1.3126655 ,
         -0.7994431 , -0.9449925 ],
        [ 5.509465  , -1.0873648 , -1.072943  , ..., -1.3126557 ,
         -0.79943836, -0.94500387],
        [ 5.509408  , -1.0872416 , -1.0728781 , ..., -1.3125473 ,
         -0.7993649 , -0.9449776 ]]], dtype=float32)>

Decodificar los números de entrar otra vez en la secuencia de texto con el Wav2Vec2tokenizer ejemplo, hemos definido anteriormente.

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['']

Esta predicción es bastante aleatoria ya que el modelo nunca se entrenó con datos grandes en este cuaderno (ya que este cuaderno no está diseñado para realizar un entrenamiento completo). Obtendrá buenas predicciones si entrena este modelo en un conjunto de datos LibriSpeech completo.

Finalmente, hemos llegado al final de este cuaderno. Pero no es un fin de aprender TensorFlow para las tareas relacionadas con el habla, este repositorio contiene algunos tutoriales más sorprendentes. En caso de que usted encontró cualquier error en este portátil, por favor crea un problema aquí .