Тонкая настройка Wav2Vec2 с головкой LM

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть на GitHub Скачать блокнот См. Модель TF Hub

В этом ноутбуке, мы будем загружать заранее подготовленные модели wav2vec2 от TFHub и подстроить его на LibriSpeech наборе данных , прилагая голова языка моделирования (LM) поверх нашей предварительно обучена модели. Базовая задача состоит в том, чтобы построить модель для автоматического распознавания речи , т.е. учитывая некоторую речь, модель должна быть в состоянии транскрибировать его в текст.

Настройка

Перед запуском этого ноутбука, убедитесь , что вы на GPU выполнения ( Runtime > Change runtime type > GPU ). Следующая ячейка будет установить gsoc-wav2vec2 пакет & его зависимости.

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
The following packages were automatically installed and are no longer required:
  linux-gcp-5.4-headers-5.4.0-1040 linux-gcp-5.4-headers-5.4.0-1043
  linux-gcp-5.4-headers-5.4.0-1044 linux-gcp-5.4-headers-5.4.0-1049
  linux-headers-5.4.0-1049-gcp linux-image-5.4.0-1049-gcp
  linux-modules-5.4.0-1049-gcp linux-modules-extra-5.4.0-1049-gcp
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libogg-dev libvorbis-dev libvorbisfile3
The following NEW packages will be installed:
  libflac-dev libogg-dev libsndfile1-dev libvorbis-dev libvorbisfile3
0 upgraded, 5 newly installed, 0 to remove and 143 not upgraded.
Need to get 1040 kB of archives.
After this operation, 4481 kB of additional disk space will be used.
Get:1 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libogg-dev amd64 1.3.2-1 [156 kB]
Get:2 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libflac-dev amd64 1.3.2-1 [260 kB]
Get:3 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbisfile3 amd64 1.3.5-4.2 [16.0 kB]
Get:4 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbis-dev amd64 1.3.5-4.2 [321 kB]
Get:5 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic-updates/main amd64 libsndfile1-dev amd64 1.0.28-4ubuntu0.18.04.2 [287 kB]
Fetched 1040 kB in 1s (1041 kB/s)
Selecting previously unselected package libogg-dev:amd64.
(Reading database ... 282211 files and directories currently installed.)
Preparing to unpack .../libogg-dev_1.3.2-1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../libflac-dev_1.3.2-1_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libvorbisfile3:amd64.
Preparing to unpack .../libvorbisfile3_1.3.5-4.2_amd64.deb ...
Unpacking libvorbisfile3:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../libvorbis-dev_1.3.5-4.2_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../libsndfile1-dev_1.0.28-4ubuntu0.18.04.2_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Setting up libvorbisfile3:amd64 (1.3.5-4.2) ...
Setting up libogg-dev:amd64 (1.3.2-1) ...
Setting up libvorbis-dev:amd64 (1.3.5-4.2) ...
Setting up libflac-dev:amd64 (1.3.2-1) ...
Setting up libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Processing triggers for libc-bin (2.27-3ubuntu1.2) ...

Модель установки с использованием TFHub

Мы начнем с импорта некоторых библиотек / модулей.

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
TF version: 2.7.0

Во- первых, мы будем загружать нашу модель от TFHub и окутает нашу модель подписи с hub.KerasLayer , чтобы иметь возможность использовать эту модель , как и любой другой слой Keras. К счастью, hub.KerasLayer можно сделать как в только 1 линии.

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

Вы можете обратиться к этому сценарию в случае , если вы заинтересованы в экспорте сценария модели. Объект pretrained_layer является замороженной версией Wav2Vec2Model . Эти заранее подготовленные веса были преобразованы из HuggingFace PyTorch предварительно подготовленного вес с помощью этого сценария .

Первоначально wav2vec2 был предварительно обучен подходу к моделированию замаскированного языка с целью идентифицировать истинное квантованное представление скрытой речи для замаскированного временного шага. Вы можете прочитать более о цели обучения в бумажно wav2vec 2.0: основе для Self-поднадзорного обучения речевых представлений .

Теперь мы определим несколько констант и гиперпараметров, которые будут полезны в следующих нескольких ячейках. AUDIO_MAXLEN намеренно установлена в 246000 , так как модели подпись принимает только статическую длину последовательности , равные 246000 .

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

В следующей ячейке, мы будем заворачивать pretrained_layer & плотный слой (LM голова) с функциональной API Keras в .

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

Плотный слой (определенный выше) , имеющий выходную размерность vocab_size , как мы хотим предсказать вероятности каждого маркера в словаре на каждом временном шаге.

Настройка состояния обучения

В TensorFlow, модели весов построены только тогда , когда model.call или model.build вызывается в первый раз, так что следующая ячейка будет строить веса модели для нас. Кроме того, мы будем работать model.summary() , чтобы проверить общее количество обучаемых параметров.

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
=================================================================
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________

Теперь нам нужно определить loss_fn и оптимизатор , чтобы иметь возможность для обучения модели. Следующая ячейка сделает это за нас. Мы будем с помощью Adam оптимизатора для простоты. CTCLoss является распространенным типом потери , которые используются для выполнения задач (например , ASR ) , где входные суб-части не может быть легко совмещен с выходными подразделами. Вы можете прочитать больше о СТС-потери от этого удивительного блога .

CTCLoss (от gsoc-wav2vec2 пакет) принимает 3 аргумента: config , model_input_shape & division_factor . Если division_factor=1 , то потери будут просто получить суммируются, поэтому проходят division_factor соответственно , чтобы получить среднюю за партию.

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

Загрузка и предварительная обработка данных

Теперь давайте скачать LibriSpeech набор данных с официального сайта и установить его.

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2021-11-05 11:43:09--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  11.6MB/s    in 31s     

2021-11-05 11:43:42 (10.3 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]
ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

Наш набор данных находится в каталоге LibriSpeech. Давайте исследуем эти файлы.

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0015.flac', '2428-83705-0004.flac', '2428-83705-0006.flac', '2428-83705-0026.flac', '2428-83705-0023.flac', '2428-83705-0001.flac', '2428-83705-0005.flac', '2428-83705-0040.flac', '2428-83705-0038.flac', '2428-83705-0042.flac', '2428-83705-0008.flac', '2428-83705-0019.flac', '2428-83705-0021.flac', '2428-83705-0002.flac', '2428-83705-0039.flac', '2428-83705-0034.flac', '2428-83705-0028.flac', '2428-83705-0000.flac', '2428-83705-0029.flac', '2428-83705-0041.flac', '2428-83705-0035.flac', '2428-83705-0032.flac', '2428-83705-0020.flac', '2428-83705-0025.flac', '2428-83705-0010.flac', '2428-83705-0014.flac', '2428-83705-0003.flac', '2428-83705-0031.flac', '2428-83705-0017.flac', '2428-83705-0027.flac', '2428-83705-0012.flac', '2428-83705-0043.flac', '2428-83705-0030.flac', '2428-83705-0022.flac', '2428-83705-0016.flac', '2428-83705-0037.flac', '2428-83705-0011.flac', '2428-83705-0036.flac', '2428-83705-0009.flac', '2428-83705-0013.flac', '2428-83705-0007.flac', '2428-83705-0018.flac', '2428-83705-0024.flac', '2428-83705-0033.flac']

Хорошо, так что каждый подкаталог имеет много .flac файлов и .txt файл. .txt файл содержит текстовые транскрипции для всех образцов речи (т.е. .flac файлов) , присутствующие в этом подкаталоге.

Мы можем загрузить эти текстовые данные следующим образом:

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples

Кроме того , мы определим функцию для загрузки образца речи из .flac файла.

REQUIRED_SAMPLE_RATE установлен в 16000 , как wav2vec2 был предварительно обучен с 16K частотой и рекомендуется отладить его без каких - либо серьезных изменений в распределении данных из - за частоты.

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

Теперь мы выберем несколько случайных выборок и попытаемся их визуализировать.

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: HE HAS GIVEN US FREE PASSES ALL THE WAY TO THE END OF OUR JOURNEY AND ALL THE WAY BACK AGAIN AND COUPONS FOR FREE BOARD AND LODGING AT THE HOTEL IT'S A WEDDING PRESENT 
Audio:

Теперь мы объединим все образцы речи и текста и определим функцию (в следующей ячейке) для этой цели.

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

Пришло время взглянуть на несколько образцов ...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([ 6.10351562e-05,  9.15527344e-05,  9.15527344e-05, ...,
         -3.05175781e-04, -5.79833984e-04, -8.23974609e-04]),
  'WHEN SHE HEARD OF MY ENGAGEMENT WITH MARY ANN SHE WROTE AND SUGGESTED THAT WE SHOULD SPEND OUR HONEYMOON IN HER COTTAGE OR PIGSTYE AND THAT I SHOULD PAY HER RENT FOR IT'),
 (array([-0.00112915, -0.00131226, -0.00158691, ...,  0.00067139,
          0.00091553,  0.00100708]),
  "IT MIGHT JUST AS WELL BE SOME ONE ELSE'S WEDDING SO UNIMPORTANT IS THE PART WHICH I AM SET TO PLAY IN IT"),
 (array([ 3.05175781e-05, -6.10351562e-05,  2.13623047e-04, ...,
         -5.18798828e-04, -2.13623047e-04, -2.74658203e-04]),
  'THE ACCIDENT IN QUESTION OCCURRED UPON THE SUNDAY EVENING'),
 (array([ 3.05175781e-04,  3.05175781e-05, -1.83105469e-04, ...,
          7.62939453e-04,  6.10351562e-04,  5.79833984e-04]),
  "OF COURSE THERE ARE SOME PEOPLE WITH WHOM YOU CAN'T BE PERFECTLY PLAIN BUT I SHALL BE AS PLAIN AS I CAN THERE'S A WAY AND A MANNER OF DOING THAT KIND OF THING"),
 (array([ 6.10351562e-05, -3.05175781e-05,  0.00000000e+00, ...,
         -3.66210938e-04, -7.93457031e-04, -1.19018555e-03]),
  'I KNOW WHAT MAMMA CAN AFFORD TO GIVE AND I WILL SEE SHE GIVES IT')]

Давайте сейчас предварительно обработаем данные !!!

Сначала мы определим Tokenizer & процессор используя gsoc-wav2vec2 пакет. Затем мы выполним очень простую предварительную обработку. processor нормализуется сырец речи wrto кадров оси и tokenizer преобразуют наши модели выходов в строку ( с использованием определенного словаря) и будет заботиться о удалении специальных маркеров ( в зависимости от конфигурации Tokenizer).

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE

Теперь мы определим генератор Python для вызова функций предварительной обработки, которые мы определили в ячейках выше.

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

Настройка tf.data.Dataset

После установки ячейки будет tf.data.Dataset объект , используя его .from_generator(...) метод. Мы будем использовать generator объект, мы определили в вышеприведенном ячейке.

Вы можете обратиться к этому сценарию для получения более подробной информации о том , как преобразовать данные LibriSpeech в tfrecords.

output_signature = (
    tf.TensorSpec(shape=(None),  dtype=tf.float32),
    tf.TensorSpec(shape=(None), dtype=tf.int32),
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

Мы передадим набор данных в несколько пакетов, поэтому давайте подготовим пакеты в следующей ячейке. Теперь все последовательности в пакете должны быть дополнены до постоянной длины. Мы будем использовать .padded_batch(...) метод для этой цели.

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

Ускорители (например, графические процессоры / TPU) очень быстрые, и часто загрузка данных (и предварительная обработка) становятся узким местом во время обучения, поскольку часть загрузки данных происходит на процессорах. Это может значительно увеличить время обучения, особенно когда требуется много предварительной онлайн-обработки или данные передаются в потоковом режиме из корзин GCS. Для обработки этих вопросов, tf.data.Dataset предлагает .prefetch(...) метод. Этот метод помогает в подготовке следующих нескольких пакетов параллельно (на процессорах), пока модель делает прогнозы (на графических процессорах / TPU) для текущего пакета.

dataset = dataset.prefetch(tf.data.AUTOTUNE)

Поскольку этот ноутбук сделан для демонстрационных целей, мы будем принимать первые num_train_batches и будут проводить обучение в течение только. Тем не менее, вам рекомендуется обучаться всему набору данных. Кроме того , мы будем оценивать только num_val_batches .

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

Модельное обучение

Для обучения нашей модели, мы будем прямым вызов .fit(...) метод после компиляции нашей модели с .compile(...) .

model.compile(optimizer, loss=loss_fn)

Вышеупомянутая ячейка установит наше состояние обучения. Теперь мы можем начать обучение с .fit(...) метода.

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
10/10 [==============================] - 32s 2s/step - loss: 649.3215 - val_loss: 315.0721
Epoch 2/3
10/10 [==============================] - 17s 2s/step - loss: 242.1202 - val_loss: 336.5721
Epoch 3/3
10/10 [==============================] - 17s 2s/step - loss: 222.1239 - val_loss: 253.0467
{'loss': [649.321533203125, 242.1201629638672, 222.1239013671875],
 'val_loss': [315.0721435546875, 336.5721130371094, 253.0466766357422]}

Спасем нашу модель с .save(...) метод , чтобы быть в состоянии выполнить умозаключение позже. Вы также можете экспортировать этот SavedModel в TFHub следуя TFHub документации .

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
2021-11-05 11:44:54.280793: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 855). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

Оценка

Теперь мы будем вычислять частоту ошибок Word по набору данных проверки.

Слово частота ошибок (WER) является общим показателем для измерения производительности автоматической системы распознавания речи. WER является производным от расстояния Левенштейна, работающего на уровне слов. Затем коэффициент ошибок в словах можно вычислить как: WER = (S + D + I) / N = (S + D + I) / (S + D + C), где S - количество замен, D - количество удалений. , I - количество вставок, C - количество правильных слов, N - количество слов в ссылке (N = S + D + C). Это значение указывает процент неверно предсказанных слов.

Вы можете обратиться к этой статье , чтобы узнать больше о WER.

Мы будем использовать load_metric(...) функции из HuggingFace наборов данных библиотеки. Давайте сначала установить datasets библиотеки с помощью pip , а затем определить metric объект.

!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

Пришло время запустить оценку данных проверки.

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2021-11-05 11:45:11.575128: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:57] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform

Мы используем tokenizer.decode(...) способ декодирования наших предсказаний и этикетки обратно в текст и добавить их к метрике для WER вычисления позже.

Теперь давайте рассчитаем значение метрики в следующей ячейке:

metric.compute()
1.0

Вывод

Теперь, когда мы удовлетворены процессом обучения и сохранили модель в save_dir , мы увидим , как эта модель может быть использована для вывода.

Во- первых, мы загружаем нашу модель , используя tf.keras.models.load_model(...) .

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

Давайте загрузим образцы речи для выполнения логического вывода. Вы также можете заменить следующий образец своим речевым образцом.

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2021-11-05 11:45:28--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 13.114.40.48
Connecting to github.com (github.com)|13.114.40.48|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2021-11-05 11:45:28--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.02s   

2021-11-05 11:45:29 (5.38 MB/s) - ‘SA2.wav’ saved [94252/94252]

Теперь мы будем читать образец речи , используя soundfile.read(...) и блокнот его AUDIO_MAXLEN удовлетворять модели подписи. Тогда мы будем нормировать , что речевой образец , используя Wav2Vec2Processor экземпляр и будет кормить его в модель.

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 5.5087714 , -1.0872856 , -1.0728477 , ..., -1.3125695 ,
         -0.7992846 , -0.94512135],
        [ 5.508977  , -1.0873723 , -1.0727195 , ..., -1.3125291 ,
         -0.79928476, -0.9449429 ],
        [ 5.5091047 , -1.0871643 , -1.0728203 , ..., -1.312533  ,
         -0.7992611 , -0.94483167],
        ...,
        [ 5.5094743 , -1.0874028 , -1.0729864 , ..., -1.3126655 ,
         -0.7994431 , -0.9449925 ],
        [ 5.509465  , -1.0873648 , -1.072943  , ..., -1.3126557 ,
         -0.79943836, -0.94500387],
        [ 5.509408  , -1.0872416 , -1.0728781 , ..., -1.3125473 ,
         -0.7993649 , -0.9449776 ]]], dtype=float32)>

Числа декодирующих Давайте обратно в текстовую последовательность , используя Wav2Vec2tokenizer экземпляр, мы определили выше.

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['']

Этот прогноз является довольно случайным, поскольку модель никогда не обучалась на больших данных в этом блокноте (поскольку этот блокнот не предназначен для выполнения полного обучения). Вы получите хорошие прогнозы, если обучите эту модель на полном наборе данных LibriSpeech.

Наконец-то мы подошли к концу этой записной книжки. Но это не конец обучения TensorFlow для речевых задач , связанных, это хранилище содержит некоторые более удивительные учебники. В случае , если вы столкнулись с какой - либо ошибкой в этом ноутбуке, пожалуйста , создайте вопрос здесь .