TensorFlow.org で表示 | Google Colab で実行 | Google Colab で実行 | ノートブックをダウンロード |
このチュートリアルでは、語彙を変更してテキストのセンチメント分類を行う際に、tf.keras.utils.warmstart_embedding_matrix
API を使用してトレーニングを「ウォームスタート」する方法を説明します。
まず、基本語彙を使って単純な Keras モデルをトレーニングすることから始め、語彙を更新した後に、モデルのトレーニングを続行します。これは「ウォームスタート」と呼ばれる方法で、新しい語彙に合わせてテキスト埋め込み行列をマッピングし直す必要があります。
埋め込み行列
埋め込みは、類似する語彙トークンに類似するエンコーディングのある、効率的な密の表現を使用するための手法です。トレーナブルなパラメータです(重みは、モデルが高密度レイヤーの重みを学習するのと同じように、トレーニング中にモデルによって学習されます)。小さなデータセットでは、8 次元の埋め込みがあるのが一般的で、大規模なデータセットを操作する場合には、最大 1024 次元にもなります。高次元の埋め込みであるほど、粒度の高い単語関係をキャプチャできますが、学習にはより多くのデータが必要となります。
語彙
一意の単語のセットは語彙と呼ばれます。テキストモデルを構築するには、固定の語彙を選択する必要があります。語彙は、データセット内の最も共通する単語からビルドするのが一般的です。語彙を使用することで、各テキストを、埋め込み行列でルックアップできる一連の ID で表現することができます。語彙では、各テキストをテキストに出現する特定の単語で表現することができます。
埋め込み行列をウォームスタートする理由
モデルは、特定の語彙を表現する埋め込みのセットでトレーニングされます。モデルを更新または改善する必要がある場合、前回のランの重みを再利用することで、トレーニングを収束する時間が短縮されます。前回のランの埋め込み行列を使用するのは、より困難です。語彙に何らかの変更があると、単語と ID のマッピングが無効になってしまうのが問題です。
tf.keras.utils.warmstart_embedding_matrix
は、基本語彙の埋め込み行列から新しい語彙の埋め込み行列を作成することで、この問題を解決します。単語が両方の語彙に存在する場合、基本の埋め込みベクトルは新しい埋め込み行列の正しい位置にコピーされます。このため、語彙のサイズまたは順序が変更された後にトレーニングをウォームスタートすることが可能です。
セットアップ
pip install --pre -U "tensorflow>2.10" # Requires 2.11
import io
import numpy as np
import os
import re
import shutil
import string
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D
from tensorflow.keras.layers import TextVectorization
2024-01-11 18:43:24.681174: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-01-11 18:43:24.681214: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-01-11 18:43:24.682822: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
データセットを読み込む
チュートリアルでは、Large Movie Review Dataset を使用します。このデータセットでセンチメント分類器モデルをトレーニングし、その過程で、ゼロから埋め込みを学習します。詳細については、テキストの読み込みチュートリアルをご覧ください。
Keras ファイルユーティリティを使用してデータセットをダウンロードし、ディレクトリを確認します。
url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
dataset = tf.keras.utils.get_file(
"aclImdb_v1.tar.gz", url, untar=True, cache_dir=".", cache_subdir=""
)
dataset_dir = os.path.join(os.path.dirname(dataset), "aclImdb")
os.listdir(dataset_dir)
Downloading data from https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz 84125825/84125825 [==============================] - 5s 0us/step ['train', 'test', 'imdb.vocab', 'imdbEr.txt', 'README']
train/
ディレクトリには pos
フォルダと neg
フォルダがあり、それぞれに、positive と negative としてラベル付けされた映画レビューが含まれます。pos
フォルダと neg
フォルダのレビューを使用して、二項分類モデルをトレーニングします。
train_dir = os.path.join(dataset_dir, "train")
os.listdir(train_dir)
['neg', 'pos', 'unsup', 'urls_pos.txt', 'labeledBow.feat', 'unsupBow.feat', 'urls_neg.txt', 'urls_unsup.txt']
train
には、トレーニングセットを作成する前に削除する必要のある他のフォルダも含まれています。
remove_dir = os.path.join(train_dir, "unsup")
shutil.rmtree(remove_dir)
次に、tf.keras.utils.text_dataset_from_directory
を使用して、tf.data.Dataset
を作成します。このユーティリティの使用についての詳細は、こちらのテキスト分類チュートリアルをご覧ください。
train
ディレクトリを使用して、トレーニングセットと検証セットを作成します。検証の分割は 20% とします。
batch_size = 1024
seed = 123
train_ds = tf.keras.utils.text_dataset_from_directory(
"aclImdb/train",
batch_size=batch_size,
validation_split=0.2,
subset="training",
seed=seed,
)
val_ds = tf.keras.utils.text_dataset_from_directory(
"aclImdb/train",
batch_size=batch_size,
validation_split=0.2,
subset="validation",
seed=seed,
)
Found 25000 files belonging to 2 classes. Using 20000 files for training. Found 25000 files belonging to 2 classes. Using 5000 files for validation.
データセットを構成してパフォーマンスを改善する
Dataset.cache
と Dataset.prefetch
、またデータをディスクにキャッシュする方法については、データパフォーマンスガイドをご覧ください。
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
テキストの前処理
次に、センチメント分類モデルに必要なデータセットの前処理ステップを定義します。layers.TextVectorization
レイヤーを、映画レビューをベクトル化する任意のパラメータで初期化します。このレイヤーの使用方法については、テキスト分類チュートリアルをご覧ください。
# Create a custom standardization function to strip HTML break tags '<br />'.
def custom_standardization(input_data):
lowercase = tf.strings.lower(input_data)
stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
return tf.strings.regex_replace(
stripped_html, "[%s]" % re.escape(string.punctuation), ""
)
# Vocabulary size and number of words in a sequence.
vocab_size = 10000
sequence_length = 100
# Use the text vectorization layer to normalize, split, and map strings to
# integers. Note that the layer uses the custom standardization defined above.
# Set maximum_sequence length as all samples are not of the same length.
vectorize_layer = TextVectorization(
standardize=custom_standardization,
max_tokens=vocab_size,
output_mode="int",
output_sequence_length=sequence_length,
)
# Make a text-only dataset (no labels) and call `Dataset.adapt` to build the
# vocabulary.
text_ds = train_ds.map(lambda x, y: x)
vectorize_layer.adapt(text_ds)
分類モデルを作成する
Keras Sequential API を使用して、センチメント分類モデルを定義します。
embedding_dim = 16
text_embedding = Embedding(vocab_size, embedding_dim, name="embedding")
text_input = tf.keras.Sequential(
[vectorize_layer, text_embedding], name="text_input"
)
classifier_head = tf.keras.Sequential(
[GlobalAveragePooling1D(), Dense(16, activation="relu"), Dense(1)],
name="classifier_head",
)
model = tf.keras.Sequential([text_input, classifier_head])
モデルをコンパイルしてトレーニングする
損失と精度を含む指標の可視化には、TensorBoard を使用します。tf.keras.callbacks.TensorBoard
を作成しましょう。
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="logs")
Adam
オプティマイザと BinaryCrossentropy
損失を使用して、モデルをコンパイルし、トレーニングします。
model.compile(
optimizer="adam",
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=["accuracy"],
)
model.fit(
train_ds,
validation_data=val_ds,
epochs=15,
callbacks=[tensorboard_callback],
)
Epoch 1/15 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1704998640.773595 92700 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 20/20 [==============================] - 6s 200ms/step - loss: 0.6921 - accuracy: 0.5028 - val_loss: 0.6905 - val_accuracy: 0.4886 Epoch 2/15 20/20 [==============================] - 1s 52ms/step - loss: 0.6876 - accuracy: 0.5028 - val_loss: 0.6849 - val_accuracy: 0.4886 Epoch 3/15 20/20 [==============================] - 1s 54ms/step - loss: 0.6800 - accuracy: 0.5028 - val_loss: 0.6759 - val_accuracy: 0.4886 Epoch 4/15 20/20 [==============================] - 1s 55ms/step - loss: 0.6684 - accuracy: 0.5028 - val_loss: 0.6629 - val_accuracy: 0.4886 Epoch 5/15 20/20 [==============================] - 1s 54ms/step - loss: 0.6520 - accuracy: 0.5028 - val_loss: 0.6452 - val_accuracy: 0.4886 Epoch 6/15 20/20 [==============================] - 1s 55ms/step - loss: 0.6304 - accuracy: 0.5028 - val_loss: 0.6231 - val_accuracy: 0.4888 Epoch 7/15 20/20 [==============================] - 1s 53ms/step - loss: 0.6039 - accuracy: 0.5299 - val_loss: 0.5974 - val_accuracy: 0.5604 Epoch 8/15 20/20 [==============================] - 1s 53ms/step - loss: 0.5738 - accuracy: 0.6198 - val_loss: 0.5698 - val_accuracy: 0.6246 Epoch 9/15 20/20 [==============================] - 1s 54ms/step - loss: 0.5418 - accuracy: 0.6837 - val_loss: 0.5420 - val_accuracy: 0.6694 Epoch 10/15 20/20 [==============================] - 1s 53ms/step - loss: 0.5098 - accuracy: 0.7314 - val_loss: 0.5157 - val_accuracy: 0.7052 Epoch 11/15 20/20 [==============================] - 1s 53ms/step - loss: 0.4794 - accuracy: 0.7636 - val_loss: 0.4919 - val_accuracy: 0.7344 Epoch 12/15 20/20 [==============================] - 1s 52ms/step - loss: 0.4514 - accuracy: 0.7874 - val_loss: 0.4712 - val_accuracy: 0.7530 Epoch 13/15 20/20 [==============================] - 1s 54ms/step - loss: 0.4262 - accuracy: 0.8051 - val_loss: 0.4536 - val_accuracy: 0.7638 Epoch 14/15 20/20 [==============================] - 1s 54ms/step - loss: 0.4037 - accuracy: 0.8195 - val_loss: 0.4387 - val_accuracy: 0.7716 Epoch 15/15 20/20 [==============================] - 1s 52ms/step - loss: 0.3837 - accuracy: 0.8324 - val_loss: 0.4263 - val_accuracy: 0.7824 <keras.src.callbacks.History at 0x7f86dd4af220>
このアプローチでは、モデルは約 85% の検証精度を達成します。
注意: 結果は、埋め込みレイヤーをトレーニングする前に、どのようにして重みがランダムに初期化されたかによって、多少異なる可能性があります。
モデルの要約を見ると、モデルの各レイヤーについて知ることができます。
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= text_input (Sequential) (None, 100, 16) 160000 classifier_head (Sequentia (None, 1) 289 l) ================================================================= Total params: 160289 (626.13 KB) Trainable params: 160289 (626.13 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
TensorBoard でモデルのメトリクスを可視化します。
# docs_infra: no_execute
%load_ext tensorboard
%tensorboard --logdir logs
語彙の再マッピング
では、語彙を更新し、ウォームスタートでトレーニングを続けることにしましょう。
まず、基本語彙と埋め込み行列を取得します。
embedding_weights_base = (
model.get_layer("text_input").get_layer("embedding").get_weights()[0]
)
vocab_base = vectorize_layer.get_vocabulary()
新しいより大きな語彙を生成するように、新しいベクトル化レイヤーを定義します。
# Vocabulary size and number of words in a sequence.
vocab_size_new = 10200
sequence_length = 100
vectorize_layer_new = TextVectorization(
standardize=custom_standardization,
max_tokens=vocab_size_new,
output_mode="int",
output_sequence_length=sequence_length,
)
# Make a text-only dataset (no labels) and call adapt to build the vocabulary.
text_ds = train_ds.map(lambda x, y: x)
vectorize_layer_new.adapt(text_ds)
# Get the new vocabulary
vocab_new = vectorize_layer_new.get_vocabulary()
# View the new vocabulary tokens that weren't in `vocab_base`
set(vocab_base) ^ set(vocab_new)
{'bullying', 'bumps', 'canvas', 'carole', 'chains', 'chairman', 'checks', 'coarse', 'competitive', 'component', 'compound', 'confirm', 'contemplate', 'coping', 'corporations', 'costuming', 'counterpart', 'crop', 'custody', 'cyborgs', 'daft', 'danced', 'daphne', 'darkest', 'davids', 'december', 'declared', 'defence', 'delve', 'demonstration', 'dense', 'denver', 'devilish', 'devious', 'dickinson', 'digs', 'directorwriter', 'download', 'effortless', 'electricity', 'elliot', 'enlightenment', 'erratic', 'exceedingly', 'eyeballs', 'fearless', 'fenton', 'fiennes', 'filter', 'fireworks', 'flipping', 'float', 'foggy', 'forgivable', 'framework', 'fulllength', 'funds', 'gamut', 'geeks', 'glee', 'goo', 'gripe', 'hardest', 'harmony', 'henchman', 'heritage', 'hg', 'hi', 'hightech', 'homework', 'houston', 'howards', 'hunger', 'imho', 'immigrants', 'improvised', 'impulse', 'inch', 'interpret', 'intimidating', 'iowa', 'jaffar', 'jeep', 'jock', 'kriemhild', 'kristofferson', 'lassie', 'laughoutloud', 'lennon', 'librarian', 'liza', 'locker', 'lommel', 'loren', 'lowered', 'marital', 'martins', 'mastroianni', 'megan', 'melt', 'mischievous', 'monstrosity', 'monumental', 'morse', 'mostel', 'muddy', 'noah', 'noirs', 'nostril', 'numbing', 'occupation', 'oceans', 'onesided', 'opus', 'organ', 'osullivan', 'otoole', 'overnight', 'parisian', 'partial', 'patriotism', 'pbs', 'penchant', 'penguin', 'plotted', 'powerfully', 'pows', 'practicing', 'prehistoric', 'prestigious', 'prevalent', 'prevents', 'profits', 'promotion', 'puke', 'pulse', 'punchline', 'quarters', 'rainer', 'ranting', 'rapists', 'rapture', 'rarity', 'rays', 'recommending', 'redeemed', 'refuge', 'refugee', 'relates', 'religions', 'remaking', 'renee', 'reply', 'restoration', 'resurrection', 'retreat', 'retro', 'rockets', 'romano', 'rooker', 'rooted', 'runtime', 'sap', 'scarred', 'secluded', 'selfabsorbed', 'separation', 'shattered', 'shenanigans', 'shootings', 'shue', 'silk', 'sm', 'soooo', 'spoton', 'sr', 'staple', 'stepfather', 'stoic', 'stud', 'suite', 'swanson', 'sweetness', 'sybil', 'tease', 'technological', 'tensions', 'theft', 'therapist', 'threats', 'tin', 'towel', 'transform', 'travelling', 'troupe', 'unremarkable', 'unsatisfied', 'untrue', 'vertigo', 'vic'}
keras.utils.warmstart_embedding_matrix
util を使用して、更新された埋め込みを生成します。
# Generate the updated embedding matrix
updated_embedding = tf.keras.utils.warmstart_embedding_matrix(
base_vocabulary=vocab_base,
new_vocabulary=vocab_new,
base_embeddings=embedding_weights_base,
new_embeddings_initializer="uniform",
)
# Update the model variable
updated_embedding_variable = tf.Variable(updated_embedding)
または
新しい埋め込み行列の初期化に使用したい埋め込み行列がある場合は、keras.initializers.Constant
を new_embeddings イニシャライザとして使用します。以下のブロックをコードセルにコピーして、試して組みましょう。こうすると、語彙に新しい単語があり、より優れた埋め込み行列初期化が必要な場合に便利です。
# generate updated embedding matrix
new_embedding = np.random.rand(len(vocab_new), 16)
updated_embedding = tf.keras.utils.warmstart_embedding_matrix(
base_vocabulary=vocab_base,
new_vocabulary=vocab_new,
base_embeddings=embedding_weights_base,
new_embeddings_initializer=tf.keras.initializers.Constant(
new_embedding
)
)
# update model variable
updated_embedding_variable = tf.Variable(updated_embedding)
埋め込み行列の形状が新しい語彙に合わせて変更されたかを検証します。
updated_embedding_variable.shape
TensorShape([10200, 16])
埋め込み行列が更新されたため、次は、レイヤーの重みを更新しましょう。
text_embedding_layer_new = Embedding(
vectorize_layer_new.vocabulary_size(), embedding_dim, name="embedding"
)
text_embedding_layer_new.build(input_shape=[None])
text_embedding_layer_new.embeddings.assign(updated_embedding)
text_input_new = tf.keras.Sequential(
[vectorize_layer_new, text_embedding_layer_new], name="text_input_new"
)
text_input_new.summary()
# Verify the shape of updated weights
# The new weights shape should reflect the new vocabulary size
text_input_new.get_layer("embedding").get_weights()[0].shape
Model: "text_input_new" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= text_vectorization_1 (Text (None, 100) 0 Vectorization) embedding (Embedding) (None, 100, 16) 163200 ================================================================= Total params: 163200 (637.50 KB) Trainable params: 163200 (637.50 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________ (10200, 16)
新しいテキストベクトル化レイヤーを使用するように、モデルのアーキテクチャを変更します。
以下のように、モデルをチェックポイントから読み込んで、モデルのアーキテクチャを更新することもできます。
warm_started_model = tf.keras.Sequential([text_input_new, classifier_head])
warm_started_model.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= text_input_new (Sequential (None, 100, 16) 163200 ) classifier_head (Sequentia (None, 1) 289 l) ================================================================= Total params: 163489 (638.63 KB) Trainable params: 163489 (638.63 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
新しい語彙を受け入れるようにモデルを正しく更新しました。埋め込みレイヤーは、古い語彙の単語を古い埋め込みにマッピングして、学習の必要がある新しい語彙の埋め込みを初期化するように更新されています。モデルの残りの学習済重みは変更されません。モデルがウォームスタートされ、前回中断された場所からトレーニングが再開します。
再マッピングがうまく行われたかを検証することができます。基本語彙と新しい語彙の両方に存在する語彙の単語「the」のインデックスを取得し、埋め込みの値を比較しましょう。同じであるはずです。
# New vocab words
base_vocab_index = vectorize_layer("the")[0]
new_vocab_index = vectorize_layer_new("the")[0]
print(
warm_started_model.get_layer("text_input_new").get_layer("embedding")(
new_vocab_index
)
== embedding_weights_base[base_vocab_index]
)
tf.Tensor( [ True True True True True True True True True True True True True True True True], shape=(16,), dtype=bool)
ウォームスタートされたトレーニングを続ける
トレーニングがどのようにウォームスタートされたかに注目してください。最初のエポックの精度は、約 85% で、前回のトレーニングが終了したときの精度に近似しています。
model.compile(
optimizer="adam",
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=["accuracy"],
)
model.fit(
train_ds,
validation_data=val_ds,
epochs=15,
callbacks=[tensorboard_callback],
)
Epoch 1/15 20/20 [==============================] - 4s 137ms/step - loss: 0.3679 - accuracy: 0.8410 - val_loss: 0.4172 - val_accuracy: 0.7920 Epoch 2/15 20/20 [==============================] - 1s 54ms/step - loss: 0.3534 - accuracy: 0.8475 - val_loss: 0.4094 - val_accuracy: 0.7960 Epoch 3/15 20/20 [==============================] - 1s 52ms/step - loss: 0.3407 - accuracy: 0.8528 - val_loss: 0.4026 - val_accuracy: 0.8016 Epoch 4/15 20/20 [==============================] - 1s 52ms/step - loss: 0.3289 - accuracy: 0.8585 - val_loss: 0.3969 - val_accuracy: 0.8044 Epoch 5/15 20/20 [==============================] - 1s 53ms/step - loss: 0.3179 - accuracy: 0.8634 - val_loss: 0.3922 - val_accuracy: 0.8072 Epoch 6/15 20/20 [==============================] - 1s 52ms/step - loss: 0.3077 - accuracy: 0.8687 - val_loss: 0.3883 - val_accuracy: 0.8112 Epoch 7/15 20/20 [==============================] - 1s 52ms/step - loss: 0.2981 - accuracy: 0.8740 - val_loss: 0.3851 - val_accuracy: 0.8122 Epoch 8/15 20/20 [==============================] - 1s 52ms/step - loss: 0.2891 - accuracy: 0.8790 - val_loss: 0.3826 - val_accuracy: 0.8144 Epoch 9/15 20/20 [==============================] - 1s 52ms/step - loss: 0.2805 - accuracy: 0.8827 - val_loss: 0.3807 - val_accuracy: 0.8168 Epoch 10/15 20/20 [==============================] - 1s 55ms/step - loss: 0.2724 - accuracy: 0.8862 - val_loss: 0.3793 - val_accuracy: 0.8170 Epoch 11/15 20/20 [==============================] - 1s 52ms/step - loss: 0.2647 - accuracy: 0.8896 - val_loss: 0.3785 - val_accuracy: 0.8198 Epoch 12/15 20/20 [==============================] - 1s 52ms/step - loss: 0.2574 - accuracy: 0.8939 - val_loss: 0.3781 - val_accuracy: 0.8218 Epoch 13/15 20/20 [==============================] - 1s 52ms/step - loss: 0.2504 - accuracy: 0.8972 - val_loss: 0.3782 - val_accuracy: 0.8214 Epoch 14/15 20/20 [==============================] - 1s 52ms/step - loss: 0.2437 - accuracy: 0.9000 - val_loss: 0.3787 - val_accuracy: 0.8232 Epoch 15/15 20/20 [==============================] - 1s 52ms/step - loss: 0.2373 - accuracy: 0.9033 - val_loss: 0.3795 - val_accuracy: 0.8236 <keras.src.callbacks.History at 0x7f86d853bb20>
ウォームスタートされたトレーニングを可視化する
# docs_infra: no_execute
%reload_ext tensorboard
%tensorboard --logdir logs
次のステップ
このチュートリアルでは、以下の内容を学習しました。
- 小さな語彙データセットで、センチメント分類モデルをゼロからトレーニングする
- 語彙サイズが変化したら、モデルのアーキテクチャを更新し、埋め込み行列をウォームスタートする
- データセットを拡大し、モデルの精度を絶えず改善する
埋め込みについての詳細は、Word2Vec と 言語を理解するためのトランスフォーマモデルチュートリアルをご覧ください。