はじめに
コールバックは、トレーニング、評価、推論の間に Keras モデルの動作をカスタマイズするための強力なツールです。例には、TensorBoard でトレーニングの進捗状況や結果を可視化できる tf.keras.callbacks.TensorBoard
や、トレーニング中にモデルを定期的に保存できる tf.keras.callbacks.ModelCheckpoint
などを含みます。
このガイドでは、Keras コールバックとは何か、それができること、そして独自のコールバックを構築する方法を学ぶことができます。まずは、簡単なコールバックアプリケーションのデモをいくつか紹介します。
Setup
import tensorflow as tf
from tensorflow import keras
2022-12-14 21:33:44.444160: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 21:33:44.444264: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 21:33:44.444275: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
Keras コールバックの概要
全てのコールバックは keras.callbacks.Callbacks.Callback
クラスをサブクラス化し、トレーニング、テスト、予測のさまざまな段階で呼び出される一連のメソッドをオーバーライドします。コールバックは、トレーニング中にモデルの内部状態や統計上のビューを取得するのに有用です。
以下のモデルメソッドには、(キーワード引数 callbacks
として)コールバックのリストを渡すことができます。
コールバックメソッドの概要
グローバルメソッド
on_(train|test|predict)_begin(self, logs=None)
fit
/evaluate
/predict
の先頭で呼び出されます。
on_(train|test|predict)_end(self, logs=None)
fit
/evaluate
/predict
の最後に呼び出されます。
トレーニング/テスト/予測のためのバッチレベルのメソッド
on_(train|test|predict)_batch_begin(self, batch, logs=None)
トレーニング/テスト/予測中に、バッチを処理する直前に呼び出されます。
on_(train|test|predict)_batch_end(self, batch, logs=None)
バッチのトレーニング/テスト/予測の終了時に呼び出されます。このメソッド内では、logs
はメトリクスの結果を含むディクショナリです。
エポックレベルのメソッド(トレーニングのみ)
on_epoch_begin(self, epoch, logs=None)
トレーニング中に、エポックの最初に呼び出されます。
on_epoch_end(self, epoch, logs=None)
トレーニング中、エポックの最後に呼び出されます。
基本的な例
具体的な例を見てみましょう。まず最初に、TensorFlow をインポートして単純な Sequential Keras モデルを定義してみます。
# Define the Keras model to add callbacks to
def get_model():
model = keras.Sequential()
model.add(keras.layers.Dense(1, input_dim=784))
model.compile(
optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
loss="mean_squared_error",
metrics=["mean_absolute_error"],
)
return model
次に、Keras データセット API からトレーニングとテスト用の MNIST データを読み込みます。
# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
今度は、以下のログを記録する単純なカスタムコールバックを定義します。
- When
fit
/evaluate
/predict
starts & ends - When each epoch starts & ends
- 各トレーニングバッチの開始時と終了時
- 各評価(テスト)バッチの開始時と終了時
- 各推論(予測)バッチの開始時と終了時
class CustomCallback(keras.callbacks.Callback):
def on_train_begin(self, logs=None):
keys = list(logs.keys())
print("Starting training; got log keys: {}".format(keys))
def on_train_end(self, logs=None):
keys = list(logs.keys())
print("Stop training; got log keys: {}".format(keys))
def on_epoch_begin(self, epoch, logs=None):
keys = list(logs.keys())
print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
def on_epoch_end(self, epoch, logs=None):
keys = list(logs.keys())
print("End epoch {} of training; got log keys: {}".format(epoch, keys))
def on_test_begin(self, logs=None):
keys = list(logs.keys())
print("Start testing; got log keys: {}".format(keys))
def on_test_end(self, logs=None):
keys = list(logs.keys())
print("Stop testing; got log keys: {}".format(keys))
def on_predict_begin(self, logs=None):
keys = list(logs.keys())
print("Start predicting; got log keys: {}".format(keys))
def on_predict_end(self, logs=None):
keys = list(logs.keys())
print("Stop predicting; got log keys: {}".format(keys))
def on_train_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
def on_train_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
def on_test_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
def on_test_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
def on_predict_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))
def on_predict_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
試してみましょう。
model = get_model()
model.fit(
x_train,
y_train,
batch_size=128,
epochs=1,
verbose=0,
validation_split=0.5,
callbacks=[CustomCallback()],
)
res = model.evaluate(
x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)
res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Starting training; got log keys: [] Start epoch 0 of training; got log keys: [] ...Training: start of batch 0; got log keys: [] ...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 1; got log keys: [] ...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 2; got log keys: [] ...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 3; got log keys: [] ...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] Start testing; got log keys: [] ...Evaluating: start of batch 0; got log keys: [] ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 1; got log keys: [] ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 2; got log keys: [] ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 3; got log keys: [] ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] Stop testing; got log keys: ['loss', 'mean_absolute_error'] End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error'] Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error'] Start testing; got log keys: [] ...Evaluating: start of batch 0; got log keys: [] ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 1; got log keys: [] ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 2; got log keys: [] ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 3; got log keys: [] ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 4; got log keys: [] ...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 5; got log keys: [] ...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 6; got log keys: [] ...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 7; got log keys: [] ...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error'] Stop testing; got log keys: ['loss', 'mean_absolute_error'] Start predicting; got log keys: [] ...Predicting: start of batch 0; got log keys: [] ...Predicting: end of batch 0; got log keys: ['outputs'] 1/8 [==>...........................] - ETA: 0s...Predicting: start of batch 1; got log keys: [] ...Predicting: end of batch 1; got log keys: ['outputs'] ...Predicting: start of batch 2; got log keys: [] ...Predicting: end of batch 2; got log keys: ['outputs'] ...Predicting: start of batch 3; got log keys: [] ...Predicting: end of batch 3; got log keys: ['outputs'] ...Predicting: start of batch 4; got log keys: [] ...Predicting: end of batch 4; got log keys: ['outputs'] ...Predicting: start of batch 5; got log keys: [] ...Predicting: end of batch 5; got log keys: ['outputs'] ...Predicting: start of batch 6; got log keys: [] ...Predicting: end of batch 6; got log keys: ['outputs'] ...Predicting: start of batch 7; got log keys: [] ...Predicting: end of batch 7; got log keys: ['outputs'] Stop predicting; got log keys: [] 8/8 [==============================] - 0s 2ms/step
logs
ディクショナリを使用する
logs
ディクショナリは、バッチまたはエポックの最後の損失値と全てのメトリクスを含みます。次の例は、損失値と平均絶対誤差を含んでいます。
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
def on_train_batch_end(self, batch, logs=None):
print(
"Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
)
def on_test_batch_end(self, batch, logs=None):
print(
"Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
)
def on_epoch_end(self, epoch, logs=None):
print(
"The average loss for epoch {} is {:7.2f} "
"and mean absolute error is {:7.2f}.".format(
epoch, logs["loss"], logs["mean_absolute_error"]
)
)
model = get_model()
model.fit(
x_train,
y_train,
batch_size=128,
epochs=2,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()],
)
res = model.evaluate(
x_test,
y_test,
batch_size=128,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()],
)
Up to batch 0, the average loss is 24.54. Up to batch 1, the average loss is 469.66. Up to batch 2, the average loss is 322.26. Up to batch 3, the average loss is 243.81. Up to batch 4, the average loss is 196.38. Up to batch 5, the average loss is 164.77. Up to batch 6, the average loss is 142.26. Up to batch 7, the average loss is 128.22. The average loss for epoch 0 is 128.22 and mean absolute error is 6.06. Up to batch 0, the average loss is 6.47. Up to batch 1, the average loss is 5.57. Up to batch 2, the average loss is 5.40. Up to batch 3, the average loss is 5.09. Up to batch 4, the average loss is 4.89. Up to batch 5, the average loss is 4.75. Up to batch 6, the average loss is 4.74. Up to batch 7, the average loss is 4.74. The average loss for epoch 1 is 4.74 and mean absolute error is 1.75. Up to batch 0, the average loss is 6.22. Up to batch 1, the average loss is 5.95. Up to batch 2, the average loss is 5.94. Up to batch 3, the average loss is 5.91. Up to batch 4, the average loss is 6.01. Up to batch 5, the average loss is 6.14. Up to batch 6, the average loss is 6.10. Up to batch 7, the average loss is 6.07.
self.model
属性を使用する
コールバックは、そのメソッドの 1 つが呼び出された時にログ情報を受け取ることに加え、現在のトレーニング/評価/推論のラウンドに関連付けられたモデルに、self.model
でアクセスすることができます。
コールバックで self.model
を使用してできることを幾つか次に示します。
self.model.stop_training = True
を設定して直ちにトレーニングを中断する。self.model.optimizer.learning_rate
など、オプティマイザ(self.model.optimizer
として使用可能)のハイパーパラメータを変化させる。- 一定間隔でモデルを保存する。
- 各エポックの終了時に幾つかのテストサンプルの
model.predict()
の出力を記録し、トレーニング中にサ二ティーチェックとして使用する。 - 各エポックの終了時に中間特徴の可視化を抽出して、モデルが何を学習しているかを経時的に監視する。
- など
これを確認するために、2 つの例で見てみましょう。
Keras コールバックアプリケーションの例
最小損失で Early stopping する
この最初の例は、属性 self.model.stop_training
(ブール)を設定して、損失の最小値に達した時点でトレーニングを停止する Callback
を作成しています。オプションで、ローカル最小値に到達した後、実際に停止するまでに幾つのエポックを待つべきか、引数 patience
で指定することが可能です。
tf.keras.callbacks.EarlyStopping
は、より完全で一般的な実装を提供します。
import numpy as np
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
"""Stop training when the loss is at its min, i.e. the loss stops decreasing.
Arguments:
patience: Number of epochs to wait after min has been hit. After this
number of no improvement, training stops.
"""
def __init__(self, patience=0):
super(EarlyStoppingAtMinLoss, self).__init__()
self.patience = patience
# best_weights to store the weights at which the minimum loss occurs.
self.best_weights = None
def on_train_begin(self, logs=None):
# The number of epoch it has waited when loss is no longer minimum.
self.wait = 0
# The epoch the training stops at.
self.stopped_epoch = 0
# Initialize the best as infinity.
self.best = np.inf
def on_epoch_end(self, epoch, logs=None):
current = logs.get("loss")
if np.less(current, self.best):
self.best = current
self.wait = 0
# Record the best weights if current results is better (less).
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
print("Restoring model weights from the end of the best epoch.")
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0:
print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
model = get_model()
model.fit(
x_train,
y_train,
batch_size=64,
steps_per_epoch=5,
epochs=30,
verbose=0,
callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
Up to batch 0, the average loss is 28.01. Up to batch 1, the average loss is 478.97. Up to batch 2, the average loss is 326.88. Up to batch 3, the average loss is 247.41. Up to batch 4, the average loss is 199.61. The average loss for epoch 0 is 199.61 and mean absolute error is 8.38. Up to batch 0, the average loss is 8.27. Up to batch 1, the average loss is 7.42. Up to batch 2, the average loss is 6.64. Up to batch 3, the average loss is 6.49. Up to batch 4, the average loss is 6.43. The average loss for epoch 1 is 6.43 and mean absolute error is 2.08. Up to batch 0, the average loss is 3.78. Up to batch 1, the average loss is 4.07. Up to batch 2, the average loss is 4.41. Up to batch 3, the average loss is 4.35. Up to batch 4, the average loss is 4.22. The average loss for epoch 2 is 4.22 and mean absolute error is 1.66. Up to batch 0, the average loss is 4.30. Up to batch 1, the average loss is 6.80. Up to batch 2, the average loss is 7.85. Up to batch 3, the average loss is 11.09. Up to batch 4, the average loss is 17.11. The average loss for epoch 3 is 17.11 and mean absolute error is 3.40. Restoring model weights from the end of the best epoch. Epoch 00004: early stopping <keras.callbacks.History at 0x7f2820238310>
学習率をスケジューリングする
この例では、トレーニングの過程でカスタムコールバックを使用して、オプティマイザの学習率を動的に変更する方法を示します。
より一般的な実装については、callbacks.LearningRateScheduler
をご覧ください。
class CustomLearningRateScheduler(keras.callbacks.Callback):
"""Learning rate scheduler which sets the learning rate according to schedule.
Arguments:
schedule: a function that takes an epoch index
(integer, indexed from 0) and current learning rate
as inputs and returns a new learning rate as output (float).
"""
def __init__(self, schedule):
super(CustomLearningRateScheduler, self).__init__()
self.schedule = schedule
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, "lr"):
raise ValueError('Optimizer must have a "lr" attribute.')
# Get the current learning rate from model's optimizer.
lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
# Call schedule function to get the scheduled learning rate.
scheduled_lr = self.schedule(epoch, lr)
# Set the value back to the optimizer before this epoch starts
tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))
LR_SCHEDULE = [
# (epoch to start, learning rate) tuples
(3, 0.05),
(6, 0.01),
(9, 0.005),
(12, 0.001),
]
def lr_schedule(epoch, lr):
"""Helper function to retrieve the scheduled learning rate based on epoch."""
if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
return lr
for i in range(len(LR_SCHEDULE)):
if epoch == LR_SCHEDULE[i][0]:
return LR_SCHEDULE[i][1]
return lr
model = get_model()
model.fit(
x_train,
y_train,
batch_size=64,
steps_per_epoch=5,
epochs=15,
verbose=0,
callbacks=[
LossAndErrorPrintingCallback(),
CustomLearningRateScheduler(lr_schedule),
],
)
Epoch 00000: Learning rate is 0.1000. Up to batch 0, the average loss is 31.39. Up to batch 1, the average loss is 474.76. Up to batch 2, the average loss is 323.83. Up to batch 3, the average loss is 245.54. Up to batch 4, the average loss is 198.48. The average loss for epoch 0 is 198.48 and mean absolute error is 8.41. Epoch 00001: Learning rate is 0.1000. Up to batch 0, the average loss is 8.15. Up to batch 1, the average loss is 6.27. Up to batch 2, the average loss is 6.25. Up to batch 3, the average loss is 6.16. Up to batch 4, the average loss is 6.00. The average loss for epoch 1 is 6.00 and mean absolute error is 2.00. Epoch 00002: Learning rate is 0.1000. Up to batch 0, the average loss is 4.04. Up to batch 1, the average loss is 4.12. Up to batch 2, the average loss is 3.87. Up to batch 3, the average loss is 4.20. Up to batch 4, the average loss is 4.33. The average loss for epoch 2 is 4.33 and mean absolute error is 1.68. Epoch 00003: Learning rate is 0.0500. Up to batch 0, the average loss is 4.64. Up to batch 1, the average loss is 4.56. Up to batch 2, the average loss is 4.37. Up to batch 3, the average loss is 4.59. Up to batch 4, the average loss is 4.29. The average loss for epoch 3 is 4.29 and mean absolute error is 1.65. Epoch 00004: Learning rate is 0.0500. Up to batch 0, the average loss is 3.17. Up to batch 1, the average loss is 3.19. Up to batch 2, the average loss is 3.39. Up to batch 3, the average loss is 3.22. Up to batch 4, the average loss is 3.31. The average loss for epoch 4 is 3.31 and mean absolute error is 1.44. Epoch 00005: Learning rate is 0.0500. Up to batch 0, the average loss is 5.49. Up to batch 1, the average loss is 6.32. Up to batch 2, the average loss is 6.98. Up to batch 3, the average loss is 7.44. Up to batch 4, the average loss is 7.11. The average loss for epoch 5 is 7.11 and mean absolute error is 2.17. Epoch 00006: Learning rate is 0.0100. Up to batch 0, the average loss is 4.03. Up to batch 1, the average loss is 4.08. Up to batch 2, the average loss is 3.89. Up to batch 3, the average loss is 3.79. Up to batch 4, the average loss is 3.78. The average loss for epoch 6 is 3.78 and mean absolute error is 1.56. Epoch 00007: Learning rate is 0.0100. Up to batch 0, the average loss is 4.43. Up to batch 1, the average loss is 4.00. Up to batch 2, the average loss is 3.72. Up to batch 3, the average loss is 3.64. Up to batch 4, the average loss is 3.34. The average loss for epoch 7 is 3.34 and mean absolute error is 1.44. Epoch 00008: Learning rate is 0.0100. Up to batch 0, the average loss is 4.40. Up to batch 1, the average loss is 3.69. Up to batch 2, the average loss is 3.62. Up to batch 3, the average loss is 3.55. Up to batch 4, the average loss is 3.31. The average loss for epoch 8 is 3.31 and mean absolute error is 1.44. Epoch 00009: Learning rate is 0.0050. Up to batch 0, the average loss is 2.58. Up to batch 1, the average loss is 3.41. Up to batch 2, the average loss is 3.80. Up to batch 3, the average loss is 3.87. Up to batch 4, the average loss is 3.59. The average loss for epoch 9 is 3.59 and mean absolute error is 1.46. Epoch 00010: Learning rate is 0.0050. Up to batch 0, the average loss is 2.16. Up to batch 1, the average loss is 3.18. Up to batch 2, the average loss is 2.96. Up to batch 3, the average loss is 3.33. Up to batch 4, the average loss is 3.48. The average loss for epoch 10 is 3.48 and mean absolute error is 1.42. Epoch 00011: Learning rate is 0.0050. Up to batch 0, the average loss is 3.29. Up to batch 1, the average loss is 3.03. Up to batch 2, the average loss is 3.27. Up to batch 3, the average loss is 3.11. Up to batch 4, the average loss is 3.03. The average loss for epoch 11 is 3.03 and mean absolute error is 1.38. Epoch 00012: Learning rate is 0.0010. Up to batch 0, the average loss is 3.55. Up to batch 1, the average loss is 3.31. Up to batch 2, the average loss is 3.36. Up to batch 3, the average loss is 3.34. Up to batch 4, the average loss is 3.26. The average loss for epoch 12 is 3.26 and mean absolute error is 1.41. Epoch 00013: Learning rate is 0.0010. Up to batch 0, the average loss is 4.06. Up to batch 1, the average loss is 3.39. Up to batch 2, the average loss is 3.44. Up to batch 3, the average loss is 3.71. Up to batch 4, the average loss is 3.55. The average loss for epoch 13 is 3.55 and mean absolute error is 1.48. Epoch 00014: Learning rate is 0.0010. Up to batch 0, the average loss is 3.41. Up to batch 1, the average loss is 3.62. Up to batch 2, the average loss is 3.40. Up to batch 3, the average loss is 3.38. Up to batch 4, the average loss is 3.22. The average loss for epoch 14 is 3.22 and mean absolute error is 1.38. <keras.callbacks.History at 0x7f28201d5c70>
組み込みの Keras コールバック
既存の Keras コールバックについては、API ドキュメントを読んで必ず確認してください。アプリケーションには、CSV へのロギング、モデルの保存、TensorBoard でのメトリクスの可視化、その他多数があります。