TensorFlow Addons Callbacks: TimeStopping


This notebook demonstrates how to use the TimeStopping callback from TensorFlow Addons.
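TimeStopping ends training once a time budget (its `seconds` argument) has elapsed. The budget appears to be checked only when an epoch finishes, so the current epoch always runs to completion and training can overrun the budget by up to one epoch. The pure-Python sketch below simulates that rule. The function name `simulate_time_stopping` and the epoch durations are hypothetical, chosen for illustration; this is not the tfa source.

```python
def simulate_time_stopping(epoch_durations, seconds):
    """Return (epochs_run, elapsed) when a deadline is checked only at epoch end.

    epoch_durations: hypothetical wall-clock seconds taken by each epoch.
    seconds: the time budget, playing the role of TimeStopping's `seconds`.
    """
    elapsed = 0.0  # training starts at t = 0
    for epoch, duration in enumerate(epoch_durations, start=1):
        elapsed += duration      # the whole epoch runs without interruption
        if elapsed >= seconds:   # the budget is checked once the epoch finishes
            return epoch, elapsed
    # The budget was never reached: all epochs ran.
    return len(epoch_durations), elapsed


# Hypothetical timings: four 2-second epochs with a 5-second budget.
print(simulate_time_stopping([2.0, 2.0, 2.0, 2.0], seconds=5))  # → (3, 6.0)
```

With these timings, training stops after the third epoch at 6 seconds, one second past the budget, because the check after the second epoch (4 seconds) still falls under the limit.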


pip install -U tensorflow-addons
import tensorflow_addons as tfa

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize data
x_train, x_test = x_train / 255.0, x_test / 255.0
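The MNIST training split loaded above holds 60,000 images of 28x28 pixels, and the test split holds 10,000. The short calculation below shows how the number of steps per epoch follows from the batch size. A batch size of 64 is an assumption here, chosen because it is consistent with the `938/938` step counter in the training log.

```python
import math

NUM_TRAIN = 60_000  # size of the standard MNIST training split
BATCH_SIZE = 64     # assumed batch size; matches the 938 steps in the log

# The last batch is partial, so round up.
steps_per_epoch = math.ceil(NUM_TRAIN / BATCH_SIZE)
last_batch = NUM_TRAIN - (steps_per_epoch - 1) * BATCH_SIZE

print(steps_per_epoch, last_batch)  # → 938 32

# Dividing by 255.0 maps uint8 pixel values 0..255 to floats in [0.0, 1.0].
print(0 / 255.0, 255 / 255.0)  # → 0.0 1.0
```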
Build a simple MNIST model (a small fully connected network)

# build the model using the Sequential API
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation='relu'))
model.add(Dense(10, activation='softmax'))

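model.compile(optimizer='adam',  # optimizer assumed; it is missing from the original cell
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])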
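Because the network has only two trainable layers, its size is easy to work out by hand. Each `Dense` layer holds `n_in * n_out` weights plus one bias per output unit, and `Flatten` has no parameters. The sketch below does the arithmetic. The helper `dense_params` is hypothetical, and the total should agree with what `model.summary()` reports for this model.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


flatten_out = 28 * 28                    # Flatten turns each 28x28 image into 784 features
hidden = dense_params(flatten_out, 128)  # 784 * 128 + 128
output = dense_params(128, 10)           # 128 * 10 + 10

print(hidden, output, hidden + output)   # → 100480 1290 101770
```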
Simple TimeStopping usage

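# initialize the TimeStopping callback: stop training once
# 5 seconds have elapsed (checked at the end of each epoch)
time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)

# train the model with the TimeStopping callback; epochs is set
# high on purpose so that the time limit, not the epoch count,
# ends training
model.fit(x_train, y_train,
          batch_size=64,
          epochs=100,
          callbacks=[time_stopping_callback],
          validation_data=(x_test, y_test))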
Epoch 1/100
938/938 [==============================] - 3s 2ms/step - loss: 0.3375 - accuracy: 0.9032 - val_loss: 0.1676 - val_accuracy: 0.9500
Epoch 2/100
938/938 [==============================] - 2s 2ms/step - loss: 0.1608 - accuracy: 0.9532 - val_loss: 0.1146 - val_accuracy: 0.9639
Epoch 3/100
938/938 [==============================] - 2s 2ms/step - loss: 0.1212 - accuracy: 0.9638 - val_loss: 0.0947 - val_accuracy: 0.9726
Timed stopping at epoch 3 after training for 0:00:05
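The `0:00:05` in the final log line matches Python's `datetime.timedelta` string form for 5 seconds. The message appears to report the configured `seconds` budget rather than a measured wall-clock time. This is an assumption worth checking against the tfa source, but it would explain why the message says 5 seconds even though the per-epoch times in the log add up to roughly 7. The snippet below only shows how such budgets render as strings.

```python
import datetime

# Assumption: the stop message formats the configured budget as a timedelta.
print(str(datetime.timedelta(seconds=5)))      # → 0:00:05
print(str(datetime.timedelta(seconds=3600)))   # → 1:00:00
print(str(datetime.timedelta(seconds=86400)))  # → 1 day, 0:00:00
```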
