Writing your own callbacks


Introduction

Callbacks are a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. Examples include tf.keras.callbacks.TensorBoard to visualize training progress and results with TensorBoard, or tf.keras.callbacks.ModelCheckpoint to periodically save your model during training.

In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own. We provide a few demos of simple callback applications to get you started.

Setup

import tensorflow as tf
from tensorflow import keras

Overview of Keras callbacks

All callbacks subclass the keras.callbacks.Callback class and override a set of methods called at various stages of training, testing, and predicting. Callbacks are useful to get a view of the model's internal states and statistics during training.

You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:

  • keras.Model.fit()
  • keras.Model.evaluate()
  • keras.Model.predict()
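For example, here is a minimal sketch of attaching a callback to all three methods (model, x, y, and MyCallback are placeholder names for objects defined elsewhere):

# Hypothetical usage: pass the same callback to each method.
model.fit(x, y, epochs=2, callbacks=[MyCallback()])
model.evaluate(x, y, callbacks=[MyCallback()])
model.predict(x, callbacks=[MyCallback()])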

An overview of callback methods

Global methods

on_(train|test|predict)_begin(self, logs=None)

Called at the beginning of fit/evaluate/predict.

on_(train|test|predict)_end(self, logs=None)

Called at the end of fit/evaluate/predict.

Batch-level methods for training/testing/predicting

on_(train|test|predict)_batch_begin(self, batch, logs=None)

Called right before processing a batch during training/testing/predicting.

on_(train|test|predict)_batch_end(self, batch, logs=None)

Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.

Epoch-level methods (training only)

on_epoch_begin(self, epoch, logs=None)

Called at the beginning of an epoch during training.

on_epoch_end(self, epoch, logs=None)

Called at the end of an epoch during training.

A basic example

Let's take a look at a concrete example. To get started, let's import tensorflow and define a simple Sequential Keras model:

# Define the Keras model to add callbacks to
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model

Then, load the MNIST data for training and testing from the Keras datasets API:

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]

Now, define a simple custom callback that logs:

  • When fit/evaluate/predict starts & ends
  • When each epoch starts & ends
  • When each training batch starts & ends
  • When each evaluation (test) batch starts & ends
  • When each inference (prediction) batch starts & ends

class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

Let's try it out:

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
Starting training; got log keys: []
Start epoch 0 of training; got log keys: []
...Training: start of batch 0; got log keys: []
...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 1; got log keys: []
...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 2; got log keys: []
...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Training: start of batch 3; got log keys: []
...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
Start testing; got log keys: []
...Evaluating: start of batch 0; got log keys: []
...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 1; got log keys: []
...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 2; got log keys: []
...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 3; got log keys: []
...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
Stop testing; got log keys: ['loss', 'mean_absolute_error']
End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']
Start testing; got log keys: []
...Evaluating: start of batch 0; got log keys: []
...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 1; got log keys: []
...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 2; got log keys: []
...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 3; got log keys: []
...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 4; got log keys: []
...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 5; got log keys: []
...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 6; got log keys: []
...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error']
...Evaluating: start of batch 7; got log keys: []
...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error']
Stop testing; got log keys: ['loss', 'mean_absolute_error']
Start predicting; got log keys: []
...Predicting: start of batch 0; got log keys: []
...Predicting: end of batch 0; got log keys: ['outputs']
1/8 [==>...........................] - ETA: 0s...Predicting: start of batch 1; got log keys: []
...Predicting: end of batch 1; got log keys: ['outputs']
...Predicting: start of batch 2; got log keys: []
...Predicting: end of batch 2; got log keys: ['outputs']
...Predicting: start of batch 3; got log keys: []
...Predicting: end of batch 3; got log keys: ['outputs']
...Predicting: start of batch 4; got log keys: []
...Predicting: end of batch 4; got log keys: ['outputs']
...Predicting: start of batch 5; got log keys: []
...Predicting: end of batch 5; got log keys: ['outputs']
...Predicting: start of batch 6; got log keys: []
...Predicting: end of batch 6; got log keys: ['outputs']
...Predicting: start of batch 7; got log keys: []
...Predicting: end of batch 7; got log keys: ['outputs']
Stop predicting; got log keys: []
8/8 [==============================] - 0s 2ms/step

Usage of the logs dict

The logs dict contains the loss value, plus all the metrics at the end of a batch or epoch. The example here includes the loss and mean absolute error.

class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print(
            "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
        )

    def on_test_batch_end(self, batch, logs=None):
        print(
            "Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
        )

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)
Up to batch 0, the average loss is   30.85.
Up to batch 1, the average loss is  449.89.
Up to batch 2, the average loss is  307.66.
Up to batch 3, the average loss is  232.61.
Up to batch 4, the average loss is  187.82.
Up to batch 5, the average loss is  157.72.
Up to batch 6, the average loss is  136.10.
Up to batch 7, the average loss is  122.38.
The average loss for epoch 0 is  122.38 and mean absolute error is    6.01.
Up to batch 0, the average loss is    5.25.
Up to batch 1, the average loss is    4.90.
Up to batch 2, the average loss is    4.74.
Up to batch 3, the average loss is    4.52.
Up to batch 4, the average loss is    4.49.
Up to batch 5, the average loss is    4.43.
Up to batch 6, the average loss is    4.41.
Up to batch 7, the average loss is    4.47.
The average loss for epoch 1 is    4.47 and mean absolute error is    1.70.
Up to batch 0, the average loss is    4.60.
Up to batch 1, the average loss is    4.13.
Up to batch 2, the average loss is    4.26.
Up to batch 3, the average loss is    4.23.
Up to batch 4, the average loss is    4.41.
Up to batch 5, the average loss is    4.40.
Up to batch 6, the average loss is    4.38.
Up to batch 7, the average loss is    4.33.

Usage of the self.model attribute

In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: self.model.

Here are a few of the things you can do with self.model in a callback:

  • Set self.model.stop_training = True to immediately interrupt training.
  • Mutate hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
  • Save the model at periodic intervals (see the sketch after this list).
  • Record the output of model.predict() on a few test samples at the end of each epoch, to use as a sanity check during training (also sketched below).
  • Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
  • etc.
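
As a minimal sketch of the saving and sanity-check items above (the save path, interval, and sample inputs are illustrative assumptions, not part of the examples that follow):

class CheckpointAndSanityCheck(keras.callbacks.Callback):
    """Hypothetical callback: periodically saves the model and logs
    predictions on a few fixed samples as a sanity check."""

    def __init__(self, save_every=5, sample_inputs=None):
        super().__init__()
        self.save_every = save_every        # save every N epochs (assumed interval)
        self.sample_inputs = sample_inputs  # a few inputs to sanity-check on

    def on_epoch_end(self, epoch, logs=None):
        # Save the model at periodic intervals.
        if (epoch + 1) % self.save_every == 0:
            self.model.save("model_at_epoch_{}.h5".format(epoch + 1))
        # Record predict() output on a few samples as a sanity check.
        if self.sample_inputs is not None:
            preds = self.model.predict(self.sample_inputs, verbose=0)
            print("Epoch {}: sample predictions: {}".format(epoch, preds[:3].ravel()))

# Example usage: CheckpointAndSanityCheck(save_every=2, sample_inputs=x_test[:3])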

Let's see this in action in a couple of complete examples.

Examples of Keras callback applications

Early stopping at minimum loss

This first example shows the creation of a Callback that stops training when the minimum of the loss has been reached, by setting the attribute self.model.stop_training (boolean). Optionally, you can provide an argument patience to specify how many epochs to wait before stopping after a local minimum has been reached.

tf.keras.callbacks.EarlyStopping provides a more complete and general implementation.
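
For reference, a typical use of the built-in version looks like this (the monitor and patience values are illustrative choices):

# Built-in alternative: stop when val_loss stops improving and restore
# the best weights seen so far.
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=3, restore_best_weights=True
)
model.fit(x_train, y_train, validation_split=0.2, callbacks=[early_stop], verbose=0)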

import numpy as np


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after min has been hit. After this
      number of epochs with no improvement, training stops.
  """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=30,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
Up to batch 0, the average loss is   20.82.
Up to batch 1, the average loss is  455.88.
Up to batch 2, the average loss is  314.20.
Up to batch 3, the average loss is  238.72.
Up to batch 4, the average loss is  192.52.
The average loss for epoch 0 is  192.52 and mean absolute error is    8.37.
Up to batch 0, the average loss is    7.02.
Up to batch 1, the average loss is    5.98.
Up to batch 2, the average loss is    6.46.
Up to batch 3, the average loss is    6.00.
Up to batch 4, the average loss is    6.13.
The average loss for epoch 1 is    6.13 and mean absolute error is    2.08.
Up to batch 0, the average loss is    4.21.
Up to batch 1, the average loss is    4.32.
Up to batch 2, the average loss is    4.47.
Up to batch 3, the average loss is    4.59.
Up to batch 4, the average loss is    4.32.
The average loss for epoch 2 is    4.32 and mean absolute error is    1.68.
Up to batch 0, the average loss is    5.80.
Up to batch 1, the average loss is    6.07.
Up to batch 2, the average loss is    5.62.
Up to batch 3, the average loss is    6.16.
Up to batch 4, the average loss is    6.63.
The average loss for epoch 3 is    6.63 and mean absolute error is    2.05.
Restoring model weights from the end of the best epoch.
Epoch 00004: early stopping
<keras.callbacks.History at 0x7ff9a80529d0>

Learning rate scheduling

This example shows how a custom callback can be used to dynamically change the learning rate of the optimizer during the course of training.

See callbacks.LearningRateScheduler for a more general implementation.
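
For comparison, the built-in version takes the schedule function directly (the decay factor below is only an illustrative choice):

# Built-in alternative: multiply the learning rate by 0.95 every epoch.
lr_callback = keras.callbacks.LearningRateScheduler(lambda epoch, lr: lr * 0.95)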

class CustomLearningRateScheduler(keras.callbacks.Callback):
    """Learning rate scheduler which sets the learning rate according to schedule.

  Arguments:
      schedule: a function that takes an epoch index
          (integer, indexed from 0) and current learning rate
          as inputs and returns a new learning rate as output (float).
  """

    def __init__(self, schedule):
        super(CustomLearningRateScheduler, self).__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "learning_rate"):
            raise ValueError('Optimizer must have a "learning_rate" attribute.')
        # Get the current learning rate from the model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        # Call the schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts.
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, scheduled_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))


LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05),
    (6, 0.01),
    (9, 0.005),
    (12, 0.001),
]


def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)
Epoch 00000: Learning rate is 0.1000.
Up to batch 0, the average loss is   35.59.
Up to batch 1, the average loss is  458.07.
Up to batch 2, the average loss is  313.68.
Up to batch 3, the average loss is  237.50.
Up to batch 4, the average loss is  191.73.
The average loss for epoch 0 is  191.73 and mean absolute error is    8.40.

Epoch 00001: Learning rate is 0.1000.
Up to batch 0, the average loss is    8.22.
Up to batch 1, the average loss is    6.90.
Up to batch 2, the average loss is    6.36.
Up to batch 3, the average loss is    6.18.
Up to batch 4, the average loss is    5.78.
The average loss for epoch 1 is    5.78 and mean absolute error is    1.99.

Epoch 00002: Learning rate is 0.1000.
Up to batch 0, the average loss is    4.69.
Up to batch 1, the average loss is    5.04.
Up to batch 2, the average loss is    4.95.
Up to batch 3, the average loss is    4.64.
Up to batch 4, the average loss is    4.51.
The average loss for epoch 2 is    4.51 and mean absolute error is    1.74.

Epoch 00003: Learning rate is 0.0500.
Up to batch 0, the average loss is    5.31.
Up to batch 1, the average loss is    4.65.
Up to batch 2, the average loss is    4.43.
Up to batch 3, the average loss is    4.23.
Up to batch 4, the average loss is    3.92.
The average loss for epoch 3 is    3.92 and mean absolute error is    1.55.

Epoch 00004: Learning rate is 0.0500.
Up to batch 0, the average loss is    3.17.
Up to batch 1, the average loss is    4.14.
Up to batch 2, the average loss is    3.94.
Up to batch 3, the average loss is    4.06.
Up to batch 4, the average loss is    4.07.
The average loss for epoch 4 is    4.07 and mean absolute error is    1.61.

Epoch 00005: Learning rate is 0.0500.
Up to batch 0, the average loss is    3.21.
Up to batch 1, the average loss is    3.17.
Up to batch 2, the average loss is    3.51.
Up to batch 3, the average loss is    4.25.
Up to batch 4, the average loss is    4.73.
The average loss for epoch 5 is    4.73 and mean absolute error is    1.71.

Epoch 00006: Learning rate is 0.0100.
Up to batch 0, the average loss is    7.54.
Up to batch 1, the average loss is    7.61.
Up to batch 2, the average loss is    6.40.
Up to batch 3, the average loss is    5.98.
Up to batch 4, the average loss is    5.60.
The average loss for epoch 6 is    5.60 and mean absolute error is    1.91.

Epoch 00007: Learning rate is 0.0100.
Up to batch 0, the average loss is    3.46.
Up to batch 1, the average loss is    3.25.
Up to batch 2, the average loss is    3.22.
Up to batch 3, the average loss is    3.30.
Up to batch 4, the average loss is    3.34.
The average loss for epoch 7 is    3.34 and mean absolute error is    1.46.

Epoch 00008: Learning rate is 0.0100.
Up to batch 0, the average loss is    2.77.
Up to batch 1, the average loss is    2.71.
Up to batch 2, the average loss is    3.08.
Up to batch 3, the average loss is    3.07.
Up to batch 4, the average loss is    3.02.
The average loss for epoch 8 is    3.02 and mean absolute error is    1.36.

Epoch 00009: Learning rate is 0.0050.
Up to batch 0, the average loss is    4.04.
Up to batch 1, the average loss is    3.42.
Up to batch 2, the average loss is    2.98.
Up to batch 3, the average loss is    2.88.
Up to batch 4, the average loss is    2.88.
The average loss for epoch 9 is    2.88 and mean absolute error is    1.32.

Epoch 00010: Learning rate is 0.0050.
Up to batch 0, the average loss is    3.90.
Up to batch 1, the average loss is    3.61.
Up to batch 2, the average loss is    3.19.
Up to batch 3, the average loss is    3.22.
Up to batch 4, the average loss is    3.30.
The average loss for epoch 10 is    3.30 and mean absolute error is    1.39.

Epoch 00011: Learning rate is 0.0050.
Up to batch 0, the average loss is    3.95.
Up to batch 1, the average loss is    3.26.
Up to batch 2, the average loss is    3.23.
Up to batch 3, the average loss is    3.35.
Up to batch 4, the average loss is    3.37.
The average loss for epoch 11 is    3.37 and mean absolute error is    1.43.

Epoch 00012: Learning rate is 0.0010.
Up to batch 0, the average loss is    3.68.
Up to batch 1, the average loss is    3.20.
Up to batch 2, the average loss is    3.31.
Up to batch 3, the average loss is    3.25.
Up to batch 4, the average loss is    3.38.
The average loss for epoch 12 is    3.38 and mean absolute error is    1.43.

Epoch 00013: Learning rate is 0.0010.
Up to batch 0, the average loss is    4.35.
Up to batch 1, the average loss is    3.50.
Up to batch 2, the average loss is    3.36.
Up to batch 3, the average loss is    3.35.
Up to batch 4, the average loss is    3.47.
The average loss for epoch 13 is    3.47 and mean absolute error is    1.46.

Epoch 00014: Learning rate is 0.0010.
Up to batch 0, the average loss is    4.99.
Up to batch 1, the average loss is    3.89.
Up to batch 2, the average loss is    3.25.
Up to batch 3, the average loss is    3.16.
Up to batch 4, the average loss is    2.96.
The average loss for epoch 14 is    2.96 and mean absolute error is    1.31.
<keras.callbacks.History at 0x7ff928675a30>

Built-in Keras callbacks

Be sure to check out the existing Keras callbacks by reading the API docs. Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!
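
As a closing sketch, several built-in callbacks can be combined in a single fit() call (the file paths and hyperparameters below are illustrative assumptions):

# Hypothetical combination of built-in callbacks; paths are placeholders.
model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=5,
    verbose=0,
    callbacks=[
        keras.callbacks.CSVLogger("training_log.csv"),  # log metrics to CSV
        keras.callbacks.ModelCheckpoint(
            "best_model.h5", monitor="loss", save_best_only=True
        ),
        keras.callbacks.TensorBoard(log_dir="./logs"),  # visualize in TensorBoard
    ],
)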