Fine-tuning Wav2Vec2 with an LM head


In this notebook, we will load the pre-trained wav2vec2 model from TFHub and fine-tune it on the LibriSpeech dataset by appending a language modeling (LM) head on top of the pre-trained model. The underlying task is to build a model for automatic speech recognition, i.e., given some speech, the model should be able to transcribe it into text.

Setup

Before running this notebook, please ensure that you are on a GPU runtime (Runtime > Change runtime type > GPU). The following cells will install the gsoc-wav2vec2 package and its dependencies.

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
The following packages were automatically installed and are no longer required:
  libatasmart4 libblockdev-fs2 libblockdev-loop2 libblockdev-part-err2
  libblockdev-part2 libblockdev-swap2 libblockdev-utils2 libblockdev2
  libparted-fs-resize0
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libflac8 libogg-dev libvorbis-dev
The following NEW packages will be installed:
  libflac-dev libogg-dev libsndfile1-dev libvorbis-dev
The following packages will be upgraded:
  libflac8
1 upgraded, 4 newly installed, 0 to remove and 170 not upgraded.
Need to get 1012 kB of archives.
After this operation, 4279 kB of additional disk space will be used.
Get:1 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libflac8 amd64 1.3.3-1ubuntu0.1 [103 kB]
Get:2 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/main amd64 libogg-dev amd64 1.3.4-0ubuntu1 [161 kB]
Get:3 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libflac-dev amd64 1.3.3-1ubuntu0.1 [151 kB]
Get:4 http://us-central1.gce.archive.ubuntu.com/ubuntu focal/main amd64 libvorbis-dev amd64 1.3.6-2ubuntu1 [316 kB]
Get:5 http://us-central1.gce.archive.ubuntu.com/ubuntu focal-updates/main amd64 libsndfile1-dev amd64 1.0.28-7ubuntu0.1 [280 kB]
Fetched 1012 kB in 2s (428 kB/s)
(Reading database ... 140734 files and directories currently installed.)
Preparing to unpack .../libflac8_1.3.3-1ubuntu0.1_amd64.deb ...
Unpacking libflac8:amd64 (1.3.3-1ubuntu0.1) over (1.3.3-1build1) ...
Selecting previously unselected package libogg-dev:amd64.
Preparing to unpack .../libogg-dev_1.3.4-0ubuntu1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.4-0ubuntu1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../libflac-dev_1.3.3-1ubuntu0.1_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.3-1ubuntu0.1) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../libvorbis-dev_1.3.6-2ubuntu1_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.6-2ubuntu1) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../libsndfile1-dev_1.0.28-7ubuntu0.1_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-7ubuntu0.1) ...
Setting up libogg-dev:amd64 (1.3.4-0ubuntu1) ...
Setting up libflac8:amd64 (1.3.3-1ubuntu0.1) ...
Setting up libvorbis-dev:amd64 (1.3.6-2ubuntu1) ...
Setting up libflac-dev:amd64 (1.3.3-1ubuntu0.1) ...
Setting up libsndfile1-dev (1.0.28-7ubuntu0.1) ...
Processing triggers for libc-bin (2.31-0ubuntu9.9) ...

Model setup using TFHub

We will start by importing some libraries/modules.

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
2022-12-14 22:29:28.070678: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:29:28.070793: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:29:28.070803: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
TF version: 2.11.0

First, we will download our model from TFHub and wrap the model signature with hub.KerasLayer so that we can use it like any other Keras layer. Fortunately, hub.KerasLayer can do both in just 1 line of code.

Note: When loading the model with hub.KerasLayer, the model becomes somewhat opaque; when we need finer control over the model, we can instead load it with tf.keras.models.load_model(...).
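As a rough sketch of that alternative (not used in the rest of this notebook), hub.resolve(...) can be used to locate the downloaded SavedModel on disk, and tf.keras.models.load_model(...) can then load it directly. Whether every custom object restores cleanly depends on how the model was exported, so treat this as illustrative only:

# Hedged sketch (not used below): load the underlying SavedModel directly for finer control.
# hub.resolve returns the local filesystem path of the downloaded, extracted model.
local_model_path = hub.resolve("https://tfhub.dev/vasudevgupta7/wav2vec2/1")
wav2vec2_model = tf.keras.models.load_model(local_model_path)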

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.

You can refer to this script in case you are interested in the model export script. The object pretrained_layer is a frozen version of Wav2Vec2Model. These pre-trained weights were converted from the HuggingFace PyTorch pre-trained weights using this script.

Originally, wav2vec2 was pre-trained with a masked language modeling approach, with the objective of identifying the true quantized latent speech representation for a masked time step. You can read more about the training objective in the paper: wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations.

Now, we will define a few constants and hyperparameters that will be useful in the next few cells. AUDIO_MAXLEN is intentionally set to 246000 because the model signature only accepts a static sequence length of 246000.

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

In the following cell, we will wrap pretrained_layer and a dense layer (the LM head) with Keras's functional API.

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089

The dense layer defined above has an output dimension of vocab_size because we want to predict the probability of each vocabulary token at each time step.

Setting up the training state

In TensorFlow, model weights are built only when model.call or model.build is called for the first time, so the following cell will build the model weights for us. Further, we will run model.summary() to check the total number of trainable parameters.

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
=================================================================
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________

Now, we need to define the loss_fn and the optimizer to be able to train the model. The following cell will do that for us. We will use the Adam optimizer for simplicity. CTCLoss is a common loss type used for tasks (like ASR) where input sub-parts cannot be easily aligned with output sub-parts. You can read more about CTC loss in this great blog post.

CTCLoss (from the gsoc-wav2vec2 package) accepts 3 arguments: config, model_input_shape and division_factor. If division_factor=1, the loss is simply summed, so pass division_factor accordingly to get the mean over the batch.

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
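For intuition about what a CTC loss consumes, here is a minimal, hedged sketch on dummy data using TensorFlow's built-in tf.nn.ctc_loss: per-time-step logits, padded integer label sequences, and the true lengths of both. The blank_index=0 choice mirrors wav2vec2's padding token but is an assumption here; the notebook itself keeps using CTCLoss from the gsoc-wav2vec2 package.

# Illustrative sketch only (dummy shapes); actual training uses CTCLoss from gsoc-wav2vec2.
dummy_logits = tf.random.normal((BATCH_SIZE, 768, config.vocab_size))      # (batch, time, vocab)
dummy_labels = tf.constant([[5, 12, 7, 0], [9, 3, 0, 0]], dtype=tf.int32)  # padded label ids
label_length = tf.constant([3, 2], dtype=tf.int32)  # true length of each label sequence
logit_length = tf.fill([BATCH_SIZE], 768)           # every time step is valid here

per_example_loss = tf.nn.ctc_loss(
    labels=dummy_labels,
    logits=dummy_logits,
    label_length=label_length,
    logit_length=logit_length,
    logits_time_major=False,
    blank_index=0,  # assumption: index 0 acts as the CTC blank token
)
print(tf.reduce_mean(per_example_loss))  # mean over the batch, like division_factor=BATCH_SIZE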

Loading & pre-processing data

Let's now download the LibriSpeech dataset from the official website and set it up.

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2022-12-14 22:29:47--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: http://us.openslr.org/resources/12/dev-clean.tar.gz [following]
--2022-12-14 22:29:47--  http://us.openslr.org/resources/12/dev-clean.tar.gz
Resolving us.openslr.org (us.openslr.org)... 46.101.158.64
Connecting to us.openslr.org (us.openslr.org)|46.101.158.64|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  26.4MB/s    in 13s     

2022-12-14 22:30:01 (24.5 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]

Note: We are using the dev-clean configuration because this notebook is for demonstration purposes only, so we only need a small amount of data. The complete training data can easily be downloaded from the LibriSpeech website.

ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

Our dataset lies in the LibriSpeech directory. Let's explore these files.

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0001.flac', '2428-83705-0037.flac', '2428-83705-0026.flac', '2428-83705-0025.flac', '2428-83705-0014.flac', '2428-83705-0022.flac', '2428-83705-0018.flac', '2428-83705-0024.flac', '2428-83705-0038.flac', '2428-83705-0006.flac', '2428-83705-0032.flac', '2428-83705-0012.flac', '2428-83705-0019.flac', '2428-83705-0036.flac', '2428-83705-0000.flac', '2428-83705-0003.flac', '2428-83705-0031.flac', '2428-83705-0017.flac', '2428-83705-0034.flac', '2428-83705-0042.flac', '2428-83705-0016.flac', '2428-83705-0028.flac', '2428-83705-0010.flac', '2428-83705-0009.flac', '2428-83705-0015.flac', '2428-83705-0040.flac', '2428-83705-0033.flac', '2428-83705-0023.flac', '2428-83705-0011.flac', '2428-83705-0039.flac', '2428-83705-0013.flac', '2428-83705-0005.flac', '2428-83705-0008.flac', '2428-83705-0021.flac', '2428-83705-0035.flac', '2428-83705-0030.flac', '2428-83705-0007.flac', '2428-83705-0020.flac', '2428-83705-0002.flac', '2428-83705-0043.flac', '2428-83705-0027.flac', '2428-83705-0004.flac', '2428-83705-0029.flac', '2428-83705-0041.flac']

Each sub-directory has many .flac files and one .txt file. This .txt file contains the text transcriptions of all the speech samples (i.e. the .flac files) present in that sub-directory.

We can load this text data as follows:

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples

Similarly, we will define a function for loading a speech sample from a .flac file.

REQUIRED_SAMPLE_RATE is set to 16000 because wav2vec2 was pre-trained on 16 kHz audio, and it is recommended to fine-tune it without any major change in data distribution due to the sampling frequency.

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}
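LibriSpeech audio is already sampled at 16 kHz, so no resampling is needed here. If you bring your own audio at a different rate, one option is to resample it first. A minimal sketch using scipy (an extra dependency, not used elsewhere in this notebook):

# Hedged sketch: resample arbitrary-rate audio to 16 kHz before feeding it to this pipeline.
# Assumes scipy is installed; LibriSpeech itself does not need this step.
from fractions import Fraction

from scipy.signal import resample_poly

def load_and_resample(file_path, target_rate=REQUIRED_SAMPLE_RATE):
  audio, source_rate = sf.read(file_path)
  if source_rate != target_rate:
    ratio = Fraction(target_rate, source_rate)  # e.g. 16000/44100 -> 160/441
    audio = resample_poly(audio, up=ratio.numerator, down=ratio.denominator)
  return audio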

Now, we will pick some random samples and try to visualize them.

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: THAT WAS WHAT MISSUS MACPHERSON SAID TO ME ONLY THE OTHER DAY 
Audio:

Now, we will combine all the speech and text samples, and we will define a function for that purpose (in the next cell).

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

It's time to have a look at a few samples...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([0.00054932, 0.00033569, 0.00021362, ..., 0.00061035, 0.00054932,
         0.00048828]),
  'BUT IT IS QUITE PLAIN TO ME THAT ALL THE ARRANGEMENTS FOR MY WEDDING ARE GOING TO BE MADE BY THE SNELLINGS'),
 (array([-6.10351562e-05, -6.10351562e-05, -3.05175781e-05, ...,
         -2.13623047e-04, -9.15527344e-05, -3.05175781e-05]),
  'I CANNOT PRETEND TO EXPLAIN WHY EXCEPT ON THE SUPPOSITION THAT ROMANCE IS DEAD AT LEAST IN THAT CIRCLE OF SOCIETY IN WHICH THE SNELLINGS MOVE'),
 (array([ 3.05175781e-04,  3.05175781e-05, -1.83105469e-04, ...,
          7.62939453e-04,  6.10351562e-04,  5.79833984e-04]),
  "OF COURSE THERE ARE SOME PEOPLE WITH WHOM YOU CAN'T BE PERFECTLY PLAIN BUT I SHALL BE AS PLAIN AS I CAN THERE'S A WAY AND A MANNER OF DOING THAT KIND OF THING"),
 (array([-0.00018311, -0.00021362, -0.00021362, ...,  0.00073242,
          0.0007019 ,  0.00057983]),
  "I HAVE DRAWN UP A LIST OF ALL THE PEOPLE WHO OUGHT TO GIVE US A PRESENT AND I SHALL TELL THEM WHAT THEY OUGHT TO GIVE IT WON'T BE MY FAULT IF I DON'T GET IT"),
 (array([-0.00027466, -0.00033569, -0.00036621, ...,  0.00021362,

          0.        ,  0.        ]),
  'THERE SHE OWNS A COTTAGE OR IT MAY BE A PIGSTYE FOR ALL I KNOW')]

Note: We load this data into memory because we are only working with a small amount of data in this notebook. But for training on the complete dataset (~300 GB), you will have to load the data lazily. You can refer to this script to learn more about that.
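As a rough illustration of what lazy loading could look like, the generator below walks the LibriSpeech tree and reads one (speech, text) pair at a time instead of materializing everything in memory up front. This is only a hedged sketch reusing the helper functions defined above; the referenced script remains the authoritative approach for full-scale training.

# Hedged sketch of lazy loading: yield one (speech, text) pair at a time.
def lazy_samples_generator(root_dir):
  for dirpath, _, filenames in os.walk(root_dir):
    txt_files = [f for f in filenames if f.endswith(".txt")]
    if not txt_files:
      continue
    transcripts = {}
    for f in txt_files:
      transcripts.update(read_txt_file(os.path.join(dirpath, f)))
    for f in filenames:
      if not f.endswith(".flac"):
        continue
      file_id = f[:-len(".flac")]
      speech = read_flac_file(os.path.join(dirpath, f))[file_id]
      if file_id in transcripts and len(speech) < AUDIO_MAXLEN:
        yield speech, transcripts[file_id]

# Example usage (files are read only as they are consumed):
# lazy_samples = lazy_samples_generator("./data/train/LibriSpeech/dev-clean")
# first_speech, first_text = next(lazy_samples)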

Now, let's pre-process the data!

We will first define the tokenizer and the processor using the gsoc-wav2vec2 package. Then, we will do some very simple pre-processing: the processor will normalize the raw speech with respect to the frames axis, and the tokenizer will convert our model outputs into strings (using the defined vocabulary) and take care of removing special tokens (depending on your tokenizer configuration).

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE

Now, we will define a Python generator to call the preprocessing functions we defined in the cells above.

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

Setting up tf.data.Dataset

The following cell will set up the tf.data.Dataset object using its .from_generator(...) method. We will use the generator object we defined in the cell above.

Note: For distributed training (especially on TPUs), .from_generator(...) doesn't currently work, and it is recommended to train on data stored in the .tfrecord format (note: the TFRecords should ideally be stored inside a GCS bucket for the TPUs to work to the fullest extent).

You can refer to this script for more details on how to convert the LibriSpeech data into tfrecords.
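To get a feel for what that conversion involves, here is a hedged, minimal sketch of serializing (speech, text) pairs into a .tfrecord file. The feature names "speech" and "label" are illustrative choices, not the exact schema used by the referenced script.

# Hedged sketch: write (speech, transcription) pairs into a TFRecord file.
def to_tf_example(speech, text):
  feature = {
      "speech": tf.train.Feature(float_list=tf.train.FloatList(value=speech.tolist())),
      "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[text.encode("utf-8")])),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

# with tf.io.TFRecordWriter("librispeech-dev-clean.tfrecord") as writer:
#   for speech, text in samples:
#     writer.write(to_tf_example(speech, text).SerializeToString())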

output_signature = (
    tf.TensorSpec(shape=(None),  dtype=tf.float32),
    tf.TensorSpec(shape=(None), dtype=tf.int32),
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

We will feed the dataset to the model in multiple batches, so let's prepare the batches in the following cell. Now, all the sequences in a batch should be padded to a constant length. We will use the .padded_batch(...) method for that purpose.

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

Accelerators (like GPUs/TPUs) are very fast, and data loading (and pre-processing) often becomes the bottleneck during training, since the data-loading part runs on CPUs. This can increase training time significantly, especially when a lot of online pre-processing is involved or data is streamed online from GCS buckets. To handle these issues, tf.data.Dataset offers the .prefetch(...) method, which helps prepare the next few batches in parallel (on CPUs) while the model is making predictions (on GPUs/TPUs) on the current batch.

dataset = dataset.prefetch(tf.data.AUTOTUNE)

Since this notebook is made for demonstration purposes, we will take the first num_train_batches and perform training on only those. You are encouraged to train on the whole dataset, though. Similarly, we will evaluate only num_val_batches.

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

Model training

For training our model, we will directly call the .fit(...) method after compiling the model with .compile(...).

model.compile(optimizer, loss=loss_fn)

The cell above sets up our training state. Now we can initiate training with the .fit(...) method.

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1467: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1467: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1450: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/ctc_ops.py:1450: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
10/10 [==============================] - 56s 2s/step - loss: 950.4785 - val_loss: 392.1649
Epoch 2/3
10/10 [==============================] - 17s 2s/step - loss: 432.3022 - val_loss: 501.2292
Epoch 3/3
10/10 [==============================] - 17s 2s/step - loss: 383.9614 - val_loss: 337.3086
{'loss': [950.478515625, 432.3021545410156, 383.96142578125],
 'val_loss': [392.1649475097656, 501.22918701171875, 337.30859375]}

Let's save our model with the .save(...) method so that we can perform inference later. You can also export this SavedModel to TFHub by following the TFHub documentation.

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 342). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

Note: We set include_optimizer=False because we only want to use this model for inference.

Evaluation

Now we will compute the Word Error Rate over the validation dataset.

Word error rate (WER) is a common metric for measuring the performance of an automatic speech recognition system. The WER is derived from the Levenshtein distance, working at the word level. The word error rate is computed as: WER = (S + D + I) / N = (S + D + I) / (S + D + C), where S is the number of substitutions, D is the number of deletions, I is the number of insertions, C is the number of correct words, and N is the number of words in the reference (N = S + D + C). This value indicates the percentage of words that were incorrectly predicted.
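As a concrete example: if the reference is "the cat sat on the mat" (N = 6) and the hypothesis is "the cat sit on mat", there is one substitution (sit for sat) and one deletion (the), so WER = (1 + 1 + 0) / 6 ≈ 0.33. The small sketch below (plain Python, independent of the rest of the notebook) computes the same word-level edit distance directly; the actual evaluation further down uses the "wer" metric from HuggingFace instead.

# Hedged sketch: word-level Levenshtein distance divided by the reference length.
def word_error_rate(reference, hypothesis):
  ref, hyp = reference.split(), hypothesis.split()
  # dp[i][j] = edit distance between the first i reference words and the first j hypothesis words
  dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
  for i in range(len(ref) + 1):
    dp[i][0] = i
  for j in range(len(hyp) + 1):
    dp[0][j] = j
  for i in range(1, len(ref) + 1):
    for j in range(1, len(hyp) + 1):
      substitution_cost = 0 if ref[i - 1] == hyp[j - 1] else 1
      dp[i][j] = min(dp[i - 1][j] + 1,                      # deletion
                     dp[i][j - 1] + 1,                      # insertion
                     dp[i - 1][j - 1] + substitution_cost)  # substitution / match
  return dp[-1][-1] / len(ref)

print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))  # 2/6 ≈ 0.33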

You can refer to this paper to learn more about WER.

We will use the load_metric(...) function from the HuggingFace datasets library. Let's first install the datasets library using pip and then define the metric object.

!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
/tmpfs/tmp/ipykernel_102470/1786190190.py:4: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  metric = load_metric("wer")
Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

Now it's time to perform the evaluation on the validation data.

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2022-12-14 22:31:54.557870: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:57] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform

We use the tokenizer.decode(...) method to decode our predictions and labels back into text and add them to the metric so that the WER can be computed later.

Now, let's compute the metric value in the following cell:

metric.compute()
1.0

Note: The metric value here doesn't mean much, since the model was trained on very little data, and ASR-like tasks often require a large amount of data to learn a mapping from speech to text. You should probably train on a large amount of data to get good results. This notebook gives you a template for fine-tuning a pre-trained speech model.

Inference

Now that we are satisfied with the training process and have saved the model in save_dir, we will see how this model can be used for inference.

First, we will load our model using tf.keras.models.load_model(...).

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

Let's download a speech sample for performing inference. You can also replace the following sample with your own speech sample.

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2022-12-14 22:32:11--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://github.com/thevasudevgupta/gsoc-wav2vec2/raw/main/data/SA2.wav [following]
--2022-12-14 22:32:11--  https://github.com/thevasudevgupta/gsoc-wav2vec2/raw/main/data/SA2.wav
Reusing existing connection to github.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/thevasudevgupta/gsoc-wav2vec2/main/data/SA2.wav [following]
--2022-12-14 22:32:12--  https://raw.githubusercontent.com/thevasudevgupta/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.01s   

2022-12-14 22:32:12 (6.25 MB/s) - ‘SA2.wav’ saved [94252/94252]

Now, we will read the speech sample using soundfile.read(...) and pad it to AUDIO_MAXLEN to satisfy the model signature. Then we will normalize that speech sample using the Wav2Vec2Processor instance and feed it into the model.

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 0.6918487 , -0.9387155 , -1.0434096 , ..., -0.5661708 ,
         -0.6654788 , -0.23725727],
        [ 0.62171304, -0.9108545 , -1.1502925 , ..., -0.53505874,
         -0.6549468 , -0.2891952 ],
        [ 0.6588618 , -0.9884922 , -1.1409014 , ..., -0.55896515,
         -0.60972726, -0.2960359 ],
        ...,
        [ 0.54219365, -0.8249726 , -1.3309978 , ..., -0.2326689 ,
         -0.7179676 , -0.56859434],
        [ 0.53588253, -0.8336109 , -1.3327637 , ..., -0.23921491,
         -0.70584655, -0.56336695],
        [ 0.5281298 , -0.8327328 , -1.3418187 , ..., -0.24567655,
         -0.70036346, -0.5704396 ]]], dtype=float32)>

Let's decode the numbers back into a text sequence using the Wav2Vec2 tokenizer instance we defined above.

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['O']

This prediction is fairly random, as the model was never trained on a large amount of data in this notebook (the notebook is not meant for complete training). You will get good predictions if you train this model on the complete LibriSpeech dataset.

Finally, we have reached the end of this notebook. But this is not the end of learning TensorFlow for speech-related tasks; this repository contains some more amazing tutorials. In case you encounter any bug in this notebook, please create an issue here.