Menyesuaikan Wav2Vec2 dengan kepala LM

Lihat di TensorFlow.org Jalankan di Google Colab Lihat di GitHub Unduh buku catatan Lihat model TF Hub

Dalam notebook ini, kami akan memuat pra-model terlatih wav2vec2 dari TFHub dan akan menyempurnakan pada LibriSpeech dataset dengan menambahkan kepala Bahasa Modeling (LM) di atas model pra-dilatih kami. Tugas yang mendasari adalah untuk membangun model untuk Automatic Speech Recognition yaitu diberikan beberapa pidato, model harus bisa menuliskan ke dalam teks.

Pengaturan

Sebelum menjalankan notebook ini, pastikan bahwa Anda berada di GPU runtime ( Runtime > Change runtime type > GPU ). Sel berikut akan menginstal gsoc-wav2vec2 paket & dependensinya.

pip3 install -q git+https://github.com/vasudevgupta7/gsoc-wav2vec2@main
sudo apt-get install -y libsndfile1-dev
pip3 install -q SoundFile
The following packages were automatically installed and are no longer required:
  linux-gcp-5.4-headers-5.4.0-1040 linux-gcp-5.4-headers-5.4.0-1043
  linux-gcp-5.4-headers-5.4.0-1044 linux-gcp-5.4-headers-5.4.0-1049
  linux-headers-5.4.0-1049-gcp linux-image-5.4.0-1049-gcp
  linux-modules-5.4.0-1049-gcp linux-modules-extra-5.4.0-1049-gcp
Use 'sudo apt autoremove' to remove them.
The following additional packages will be installed:
  libflac-dev libogg-dev libvorbis-dev libvorbisfile3
The following NEW packages will be installed:
  libflac-dev libogg-dev libsndfile1-dev libvorbis-dev libvorbisfile3
0 upgraded, 5 newly installed, 0 to remove and 143 not upgraded.
Need to get 1040 kB of archives.
After this operation, 4481 kB of additional disk space will be used.
Get:1 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libogg-dev amd64 1.3.2-1 [156 kB]
Get:2 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libflac-dev amd64 1.3.2-1 [260 kB]
Get:3 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbisfile3 amd64 1.3.5-4.2 [16.0 kB]
Get:4 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic/main amd64 libvorbis-dev amd64 1.3.5-4.2 [321 kB]
Get:5 http://asia-east1.gce.archive.ubuntu.com/ubuntu bionic-updates/main amd64 libsndfile1-dev amd64 1.0.28-4ubuntu0.18.04.2 [287 kB]
Fetched 1040 kB in 1s (1041 kB/s)
Selecting previously unselected package libogg-dev:amd64.
(Reading database ... 282211 files and directories currently installed.)
Preparing to unpack .../libogg-dev_1.3.2-1_amd64.deb ...
Unpacking libogg-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libflac-dev:amd64.
Preparing to unpack .../libflac-dev_1.3.2-1_amd64.deb ...
Unpacking libflac-dev:amd64 (1.3.2-1) ...
Selecting previously unselected package libvorbisfile3:amd64.
Preparing to unpack .../libvorbisfile3_1.3.5-4.2_amd64.deb ...
Unpacking libvorbisfile3:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libvorbis-dev:amd64.
Preparing to unpack .../libvorbis-dev_1.3.5-4.2_amd64.deb ...
Unpacking libvorbis-dev:amd64 (1.3.5-4.2) ...
Selecting previously unselected package libsndfile1-dev.
Preparing to unpack .../libsndfile1-dev_1.0.28-4ubuntu0.18.04.2_amd64.deb ...
Unpacking libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Setting up libvorbisfile3:amd64 (1.3.5-4.2) ...
Setting up libogg-dev:amd64 (1.3.2-1) ...
Setting up libvorbis-dev:amd64 (1.3.5-4.2) ...
Setting up libflac-dev:amd64 (1.3.2-1) ...
Setting up libsndfile1-dev (1.0.28-4ubuntu0.18.04.2) ...
Processing triggers for libc-bin (2.27-3ubuntu1.2) ...

Setup model menggunakan TFHub

Kita akan mulai dengan mengimpor beberapa library/modul.

import os

import tensorflow as tf
import tensorflow_hub as hub
from wav2vec2 import Wav2Vec2Config

config = Wav2Vec2Config()

print("TF version:", tf.__version__)
TF version: 2.7.0

Pertama, kita akan men-download model kami dari TFHub & akan membungkus tanda tangan model kami dengan hub.KerasLayer untuk dapat menggunakan model ini seperti lapisan lainnya Keras. Untungnya, hub.KerasLayer dapat melakukan keduanya hanya dalam 1 baris.

pretrained_layer = hub.KerasLayer("https://tfhub.dev/vasudevgupta7/wav2vec2/1", trainable=True)

Anda dapat lihat ini naskah dalam kasus Anda tertarik dalam model mengekspor naskah. Objek pretrained_layer adalah versi yang dibekukan dari Wav2Vec2Model . Bobot pra-dilatih tersebut dikonversi dari HuggingFace PyTorch pra-dilatih bobot menggunakan script ini .

Awalnya, wav2vec2 telah dilatih sebelumnya dengan pendekatan pemodelan bahasa bertopeng dengan tujuan untuk mengidentifikasi representasi ucapan laten terkuantisasi yang sebenarnya untuk langkah waktu bertopeng. Anda dapat membaca lebih lanjut tentang tujuan pelatihan dalam kertas wav2vec 2.0: Sebuah Kerangka untuk Belajar Self-Diawasi of Speech Representasi .

Sekarang, kita akan mendefinisikan beberapa konstanta dan hyper-parameter yang akan berguna dalam beberapa sel berikutnya. AUDIO_MAXLEN sengaja diatur untuk 246000 sebagai tanda tangan model yang hanya menerima urutan panjang statis 246000 .

AUDIO_MAXLEN = 246000
LABEL_MAXLEN = 256
BATCH_SIZE = 2

Dalam sel berikut, kita akan membungkus pretrained_layer & lapisan padat (kepala LM) dengan API Fungsional Keras ini .

inputs = tf.keras.Input(shape=(AUDIO_MAXLEN,))
hidden_states = pretrained_layer(inputs)
outputs = tf.keras.layers.Dense(config.vocab_size)(hidden_states)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

Lapisan padat (didefinisikan di atas) adalah memiliki output dimensi vocab_size seperti yang kita inginkan untuk memprediksi probabilitas setiap token dalam kosa kata pada setiap langkah waktu.

Menyiapkan status pelatihan

Dalam TensorFlow, Model bobot yang dibangun hanya ketika model.call atau model.build disebut untuk pertama kalinya, sehingga sel berikut akan membangun bobot model untuk kita. Selanjutnya, kita akan berjalan model.summary() untuk memeriksa jumlah parameter dilatih.

model(tf.random.uniform(shape=(BATCH_SIZE, AUDIO_MAXLEN)))
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 246000)]          0         
                                                                 
 keras_layer (KerasLayer)    (None, 768, 768)          94371712  
                                                                 
 dense (Dense)               (None, 768, 32)           24608     
                                                                 
=================================================================
Total params: 94,396,320
Trainable params: 94,396,320
Non-trainable params: 0
_________________________________________________________________

Sekarang, kita perlu mendefinisikan loss_fn dan optimizer untuk dapat melatih model. Sel berikut akan melakukannya untuk kita. Kami akan menggunakan Adam optimizer untuk kesederhanaan. CTCLoss adalah jenis kerugian umum yang digunakan untuk tugas-tugas (seperti ASR ) di mana masukan sub-bagian tidak dapat dengan mudah disesuaikan dengan output sub-bagian. Anda dapat membaca lebih lanjut tentang CTC-kerugian yang luar biasa ini posting blog .

CTCLoss (dari gsoc-wav2vec2 paket) menerima 3 argumen: config , model_input_shape & division_factor . Jika division_factor=1 , maka kerugian akan hanya mendapatkan dijumlahkan, sehingga lulus division_factor sesuai untuk mendapatkan rata-rata lebih batch.

from wav2vec2 import CTCLoss

LEARNING_RATE = 5e-5

loss_fn = CTCLoss(config, (BATCH_SIZE, AUDIO_MAXLEN), division_factor=BATCH_SIZE)
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

Memuat & Pra-pemrosesan data

Sekarang mari kita men-download dataset LibriSpeech dari website resmi dan mengaturnya.

wget https://www.openslr.org/resources/12/dev-clean.tar.gz -P ./data/train/
tar -xf ./data/train/dev-clean.tar.gz -C ./data/train/
--2021-11-05 11:43:09--  https://www.openslr.org/resources/12/dev-clean.tar.gz
Resolving www.openslr.org (www.openslr.org)... 46.101.158.64
Connecting to www.openslr.org (www.openslr.org)|46.101.158.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 337926286 (322M) [application/x-gzip]
Saving to: ‘./data/train/dev-clean.tar.gz’

dev-clean.tar.gz    100%[===================>] 322.27M  11.6MB/s    in 31s     

2021-11-05 11:43:42 (10.3 MB/s) - ‘./data/train/dev-clean.tar.gz’ saved [337926286/337926286]
ls ./data/train/
LibriSpeech/  dev-clean.tar.gz

Dataset kami terletak di direktori LibriSpeech. Mari kita jelajahi file-file ini.

data_dir = "./data/train/LibriSpeech/dev-clean/2428/83705/"
all_files = os.listdir(data_dir)

flac_files = [f for f in all_files if f.endswith(".flac")]
txt_files = [f for f in all_files if f.endswith(".txt")]

print("Transcription files:", txt_files, "\nSound files:", flac_files)
Transcription files: ['2428-83705.trans.txt'] 
Sound files: ['2428-83705-0015.flac', '2428-83705-0004.flac', '2428-83705-0006.flac', '2428-83705-0026.flac', '2428-83705-0023.flac', '2428-83705-0001.flac', '2428-83705-0005.flac', '2428-83705-0040.flac', '2428-83705-0038.flac', '2428-83705-0042.flac', '2428-83705-0008.flac', '2428-83705-0019.flac', '2428-83705-0021.flac', '2428-83705-0002.flac', '2428-83705-0039.flac', '2428-83705-0034.flac', '2428-83705-0028.flac', '2428-83705-0000.flac', '2428-83705-0029.flac', '2428-83705-0041.flac', '2428-83705-0035.flac', '2428-83705-0032.flac', '2428-83705-0020.flac', '2428-83705-0025.flac', '2428-83705-0010.flac', '2428-83705-0014.flac', '2428-83705-0003.flac', '2428-83705-0031.flac', '2428-83705-0017.flac', '2428-83705-0027.flac', '2428-83705-0012.flac', '2428-83705-0043.flac', '2428-83705-0030.flac', '2428-83705-0022.flac', '2428-83705-0016.flac', '2428-83705-0037.flac', '2428-83705-0011.flac', '2428-83705-0036.flac', '2428-83705-0009.flac', '2428-83705-0013.flac', '2428-83705-0007.flac', '2428-83705-0018.flac', '2428-83705-0024.flac', '2428-83705-0033.flac']

Baiklah, jadi masing-masing sub-direktori memiliki banyak .flac file dan .txt file yang. The .txt file berisi transkripsi teks untuk semua sampel pidato (yaitu .flac file) hadir dalam sub-direktori.

Kita dapat memuat data teks ini sebagai berikut:

def read_txt_file(f):
  with open(f, "r") as f:
    samples = f.read().split("\n")
    samples = {s.split()[0]: " ".join(s.split()[1:]) for s in samples if len(s.split()) > 2}
  return samples

Demikian pula, kita akan mendefinisikan fungsi untuk memuat sampel pidato dari .flac berkas.

REQUIRED_SAMPLE_RATE diatur untuk 16000 sebagai wav2vec2 pra-dilatih dengan 16K frekuensi dan itu dianjurkan untuk menyempurnakan tanpa perubahan besar dalam distribusi data karena frekuensi.

import soundfile as sf

REQUIRED_SAMPLE_RATE = 16000

def read_flac_file(file_path):
  with open(file_path, "rb") as f:
      audio, sample_rate = sf.read(f)
  if sample_rate != REQUIRED_SAMPLE_RATE:
      raise ValueError(
          f"sample rate (={sample_rate}) of your files must be {REQUIRED_SAMPLE_RATE}"
      )
  file_id = os.path.split(file_path)[-1][:-len(".flac")]
  return {file_id: audio}

Sekarang, kami akan memilih beberapa sampel acak & akan mencoba memvisualisasikannya.

from IPython.display import Audio
import random

file_id = random.choice([f[:-len(".flac")] for f in flac_files])
flac_file_path, txt_file_path = os.path.join(data_dir, f"{file_id}.flac"), os.path.join(data_dir, "2428-83705.trans.txt")

print("Text Transcription:", read_txt_file(txt_file_path)[file_id], "\nAudio:")
Audio(filename=flac_file_path)
Text Transcription: HE HAS GIVEN US FREE PASSES ALL THE WAY TO THE END OF OUR JOURNEY AND ALL THE WAY BACK AGAIN AND COUPONS FOR FREE BOARD AND LODGING AT THE HOTEL IT'S A WEDDING PRESENT 
Audio:

Sekarang, kita akan menggabungkan semua contoh pidato & teks dan akan mendefinisikan fungsi (di sel berikutnya) untuk tujuan itu.

def fetch_sound_text_mapping(data_dir):
  all_files = os.listdir(data_dir)

  flac_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".flac")]
  txt_files = [os.path.join(data_dir, f) for f in all_files if f.endswith(".txt")]

  txt_samples = {}
  for f in txt_files:
    txt_samples.update(read_txt_file(f))

  speech_samples = {}
  for f in flac_files:
    speech_samples.update(read_flac_file(f))

  assert len(txt_samples) == len(speech_samples)

  samples = [(speech_samples[file_id], txt_samples[file_id]) for file_id in speech_samples.keys() if len(speech_samples[file_id]) < AUDIO_MAXLEN]
  return samples

Saatnya untuk melihat beberapa sampel ...

samples = fetch_sound_text_mapping(data_dir)
samples[:5]
[(array([ 6.10351562e-05,  9.15527344e-05,  9.15527344e-05, ...,
         -3.05175781e-04, -5.79833984e-04, -8.23974609e-04]),
  'WHEN SHE HEARD OF MY ENGAGEMENT WITH MARY ANN SHE WROTE AND SUGGESTED THAT WE SHOULD SPEND OUR HONEYMOON IN HER COTTAGE OR PIGSTYE AND THAT I SHOULD PAY HER RENT FOR IT'),
 (array([-0.00112915, -0.00131226, -0.00158691, ...,  0.00067139,
          0.00091553,  0.00100708]),
  "IT MIGHT JUST AS WELL BE SOME ONE ELSE'S WEDDING SO UNIMPORTANT IS THE PART WHICH I AM SET TO PLAY IN IT"),
 (array([ 3.05175781e-05, -6.10351562e-05,  2.13623047e-04, ...,
         -5.18798828e-04, -2.13623047e-04, -2.74658203e-04]),
  'THE ACCIDENT IN QUESTION OCCURRED UPON THE SUNDAY EVENING'),
 (array([ 3.05175781e-04,  3.05175781e-05, -1.83105469e-04, ...,
          7.62939453e-04,  6.10351562e-04,  5.79833984e-04]),
  "OF COURSE THERE ARE SOME PEOPLE WITH WHOM YOU CAN'T BE PERFECTLY PLAIN BUT I SHALL BE AS PLAIN AS I CAN THERE'S A WAY AND A MANNER OF DOING THAT KIND OF THING"),
 (array([ 6.10351562e-05, -3.05175781e-05,  0.00000000e+00, ...,
         -3.66210938e-04, -7.93457031e-04, -1.19018555e-03]),
  'I KNOW WHAT MAMMA CAN AFFORD TO GIVE AND I WILL SEE SHE GIVES IT')]

Mari kita pra-proses datanya sekarang!!!

Kami pertama akan menentukan tokenizer & prosesor menggunakan gsoc-wav2vec2 paket. Kemudian, kami akan melakukan pra-pemrosesan yang sangat sederhana. processor akan menormalkan pidato baku wrto frame sumbu dan tokenizer akan mengkonversi output model kami ke dalam string (menggunakan kosakata ditentukan) & akan mengurus penghapusan token khusus (tergantung pada konfigurasi tokenizer Anda).

from wav2vec2 import Wav2Vec2Processor
tokenizer = Wav2Vec2Processor(is_tokenizer=True)
processor = Wav2Vec2Processor(is_tokenizer=False)

def preprocess_text(text):
  label = tokenizer(text)
  return tf.constant(label, dtype=tf.int32)

def preprocess_speech(audio):
  audio = tf.constant(audio, dtype=tf.float32)
  return processor(tf.transpose(audio))
Downloading `vocab.json` from https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/vocab.json ... DONE

Sekarang, kita akan mendefinisikan generator python untuk memanggil fungsi preprocessing yang kita definisikan di sel di atas.

def inputs_generator():
  for speech, text in samples:
    yield preprocess_speech(speech), preprocess_text(text)

Menyiapkan tf.data.Dataset

Berikut sel akan setup tf.data.Dataset objek menggunakan nya .from_generator(...) metode. Kami akan menggunakan generator objek, kita didefinisikan dalam sel di atas.

Anda dapat merujuk ke script ini untuk rincian lebih lanjut tentang cara mengkonversi data LibriSpeech menjadi tfrecords.

output_signature = (
    tf.TensorSpec(shape=(None),  dtype=tf.float32),
    tf.TensorSpec(shape=(None), dtype=tf.int32),
)

dataset = tf.data.Dataset.from_generator(inputs_generator, output_signature=output_signature)
BUFFER_SIZE = len(flac_files)
SEED = 42

dataset = dataset.shuffle(BUFFER_SIZE, seed=SEED)

Kami akan meneruskan dataset menjadi beberapa batch, jadi mari kita siapkan batch di sel berikut. Sekarang, semua urutan dalam satu batch harus diisi dengan panjang yang konstan. Kami akan menggunakan .padded_batch(...) metode untuk tujuan itu.

dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=(AUDIO_MAXLEN, LABEL_MAXLEN), padding_values=(0.0, 0))

Akselerator (seperti GPU/TPU) sangat cepat dan seringkali pemuatan data (& pra-pemrosesan) menjadi hambatan selama pelatihan karena bagian pemuatan data terjadi pada CPU. Hal ini dapat meningkatkan waktu pelatihan secara signifikan terutama bila ada banyak pra-pemrosesan online yang terlibat atau data dialirkan secara online dari bucket GCS. Untuk menangani masalah tersebut, tf.data.Dataset menawarkan .prefetch(...) metode. Metode ini membantu mempersiapkan beberapa batch berikutnya secara paralel (pada CPU) sementara model membuat prediksi (pada GPU/TPU) pada batch saat ini.

dataset = dataset.prefetch(tf.data.AUTOTUNE)

Sejak notebook ini dibuat untuk tujuan demonstrasi, kami akan mengambil pertama num_train_batches dan akan berkinerja melatih lebih hanya itu. Anda dianjurkan untuk melatih seluruh dataset. Demikian pula, kita akan mengevaluasi hanya num_val_batches .

num_train_batches = 10
num_val_batches = 4

train_dataset = dataset.take(num_train_batches)
val_dataset = dataset.skip(num_train_batches).take(num_val_batches)

Pelatihan model

Untuk melatih model kami, kami akan langsung menelepon .fit(...) metode setelah kompilasi model kita dengan .compile(...) .

model.compile(optimizer, loss=loss_fn)

Sel di atas akan mengatur status pelatihan kita. Sekarang kita dapat memulai pelatihan dengan .fit(...) metode.

history = model.fit(train_dataset, validation_data=val_dataset, epochs=3)
history.history
Epoch 1/3
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1447: alias_inplace_add (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_add, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/ctc_ops.py:1430: alias_inplace_update (from tensorflow.python.ops.inplace_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer tf.tensor_scatter_nd_update, which offers the same functionality with well-defined read-write semantics.
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
WARNING:tensorflow:Gradients do not exist for variables ['wav2vec2/masked_spec_embed:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss`argument?
10/10 [==============================] - 32s 2s/step - loss: 649.3215 - val_loss: 315.0721
Epoch 2/3
10/10 [==============================] - 17s 2s/step - loss: 242.1202 - val_loss: 336.5721
Epoch 3/3
10/10 [==============================] - 17s 2s/step - loss: 222.1239 - val_loss: 253.0467
{'loss': [649.321533203125, 242.1201629638672, 222.1239013671875],
 'val_loss': [315.0721435546875, 336.5721130371094, 253.0466766357422]}

Mari kita menyimpan model kita dengan .save(...) metode untuk dapat melakukan inferensi kemudian. Anda juga dapat mengekspor SavedModel ini untuk TFHub dengan mengikuti dokumentasi TFHub .

save_dir = "finetuned-wav2vec2"
model.save(save_dir, include_optimizer=False)
2021-11-05 11:44:54.280793: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 855). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets
INFO:tensorflow:Assets written to: finetuned-wav2vec2/assets

Evaluasi

Sekarang kita akan menghitung Tingkat Kesalahan Kata di atas kumpulan data validasi

Tingkat kesalahan kata (WER) adalah metrik umum untuk mengukur kinerja sistem pengenalan suara otomatis. WER berasal dari jarak Levenshtein, bekerja pada tingkat kata. Tingkat kesalahan kata kemudian dapat dihitung sebagai: WER = (S + D + I) / N = (S + D + I) / (S + D + C) di mana S adalah jumlah penggantian, D adalah jumlah penghapusan , I adalah jumlah penyisipan, C adalah jumlah kata yang benar, N adalah jumlah kata dalam referensi (N=S+D+C). Nilai ini menunjukkan persentase kata yang salah diprediksi.

Anda dapat merujuk ke tulisan ini untuk mempelajari lebih lanjut tentang WER.

Kami akan menggunakan load_metric(...) Fungsi dari HuggingFace dataset perpustakaan. Mari kita pertama menginstal datasets perpustakaan menggunakan pip dan kemudian menentukan metric objek.

!pip3 install -q datasets

from datasets import load_metric
metric = load_metric("wer")
Downloading:   0%|          | 0.00/1.95k [00:00<?, ?B/s]
@tf.function(jit_compile=True)
def eval_fwd(batch):
  logits = model(batch, training=False)
  return tf.argmax(logits, axis=-1)

Saatnya menjalankan evaluasi pada data validasi sekarang.

from tqdm.auto import tqdm

for speech, labels in tqdm(val_dataset, total=num_val_batches):
    predictions  = eval_fwd(speech)
    predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
    references = [tokenizer.decode(label, group_tokens=False) for label in labels.numpy().tolist()]
    metric.add_batch(references=references, predictions=predictions)
0%|          | 0/4 [00:00<?, ?it/s]
2021-11-05 11:45:11.575128: W tensorflow/compiler/tf2xla/kernels/random_ops.cc:57] Warning: Using tf.random.uniform with XLA compilation will ignore seeds; consider using tf.random.stateless_uniform instead if reproducible behavior is desired. model/keras_layer/StatefulPartitionedCall/StatefulPartitionedCall/wav2vec2/encoder/layers/0/stochastic_depth/random_uniform/RandomUniform

Kami menggunakan tokenizer.decode(...) metode untuk decoding prediksi dan label kami kembali ke dalam teks dan akan menambahkannya ke metrik untuk WER perhitungan nanti.

Sekarang, mari kita hitung nilai metrik di sel berikut:

metric.compute()
1.0

Kesimpulan

Sekarang kami puas dengan proses pelatihan & telah disimpan model di save_dir , kita akan melihat bagaimana model ini dapat digunakan untuk inferensi.

Pertama, kita akan memuat model kita menggunakan tf.keras.models.load_model(...) .

finetuned_model = tf.keras.models.load_model(save_dir)
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

Mari unduh beberapa contoh pidato untuk melakukan inferensi. Anda dapat mengganti contoh berikut dengan contoh pidato Anda juga.

wget https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
--2021-11-05 11:45:28--  https://github.com/vasudevgupta7/gsoc-wav2vec2/raw/main/data/SA2.wav
Resolving github.com (github.com)... 13.114.40.48
Connecting to github.com (github.com)|13.114.40.48|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav [following]
--2021-11-05 11:45:28--  https://raw.githubusercontent.com/vasudevgupta7/gsoc-wav2vec2/main/data/SA2.wav
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94252 (92K) [audio/wav]
Saving to: ‘SA2.wav’

SA2.wav             100%[===================>]  92.04K  --.-KB/s    in 0.02s   

2021-11-05 11:45:29 (5.38 MB/s) - ‘SA2.wav’ saved [94252/94252]

Sekarang, kita akan membaca sampel pidato menggunakan soundfile.read(...) dan pad untuk AUDIO_MAXLEN untuk memenuhi tanda tangan model yang. Kemudian kita akan menormalkan bahwa sampel pidato menggunakan Wav2Vec2Processor contoh & akan memberi makan ke dalam model.

import numpy as np

speech, _ = sf.read("SA2.wav")
speech = np.pad(speech, (0, AUDIO_MAXLEN - len(speech)))
speech = tf.expand_dims(processor(tf.constant(speech)), 0)

outputs = finetuned_model(speech)
outputs
<tf.Tensor: shape=(1, 768, 32), dtype=float32, numpy=
array([[[ 5.5087714 , -1.0872856 , -1.0728477 , ..., -1.3125695 ,
         -0.7992846 , -0.94512135],
        [ 5.508977  , -1.0873723 , -1.0727195 , ..., -1.3125291 ,
         -0.79928476, -0.9449429 ],
        [ 5.5091047 , -1.0871643 , -1.0728203 , ..., -1.312533  ,
         -0.7992611 , -0.94483167],
        ...,
        [ 5.5094743 , -1.0874028 , -1.0729864 , ..., -1.3126655 ,
         -0.7994431 , -0.9449925 ],
        [ 5.509465  , -1.0873648 , -1.072943  , ..., -1.3126557 ,
         -0.79943836, -0.94500387],
        [ 5.509408  , -1.0872416 , -1.0728781 , ..., -1.3125473 ,
         -0.7993649 , -0.9449776 ]]], dtype=float32)>

Nomor decode Mari kembali ke urutan teks menggunakan Wav2Vec2tokenizer contoh, kita didefinisikan di atas.

predictions = tf.argmax(outputs, axis=-1)
predictions = [tokenizer.decode(pred) for pred in predictions.numpy().tolist()]
predictions
['']

Prediksi ini cukup acak karena model tidak pernah dilatih pada data besar di notebook ini (karena notebook ini tidak dimaksudkan untuk melakukan pelatihan lengkap). Anda akan mendapatkan prediksi yang baik jika Anda melatih model ini pada dataset LibriSpeech yang lengkap.

Akhirnya, kami telah mencapai akhir dari buku catatan ini. Tapi itu bukan akhir dari belajar TensorFlow untuk tugas-tugas yang berhubungan dengan pidato, ini repositori berisi beberapa tutorial lebih menakjubkan. Dalam kasus Anda temui setiap bug di notebook ini, silakan membuat masalah di sini .