Overview
The Keras Tuner is a library that helps you pick the optimal set of hyperparameters for your TensorFlow program. The process of selecting the right set of hyperparameters for your machine learning (ML) application is called hyperparameter tuning or hypertuning.
Hyperparameters are the variables that govern the training process and the topology of an ML model. These variables remain constant over the training process and directly impact the performance of your ML program. Hyperparameters are of two types:
- Model hyperparameters, which influence model selection, such as the number and width of hidden layers
- Algorithm hyperparameters, which influence the speed and quality of the learning algorithm, such as the learning rate for stochastic gradient descent (SGD) and the number of nearest neighbors for a k-nearest neighbors (KNN) classifier
In this tutorial, you will use the Keras Tuner to perform hypertuning for an image classification application.
Setup
import tensorflow as tf
from tensorflow import keras
Install and import the Keras Tuner.
pip install -q -U keras-tuner
import keras_tuner as kt
Download and prepare the dataset
In this tutorial, you will use the Keras Tuner to find the best hyperparameters for a machine learning model that classifies images of clothing from the Fashion MNIST dataset.
Load the data.
(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0
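The normalization step above maps the uint8 pixel range [0, 255] to floats in [0, 1], which helps gradient-based training. A minimal standalone check of that transformation, using a synthetic array instead of the real dataset:

```python
import numpy as np

# Synthetic stand-in for a batch of grayscale pixel values (dtype matches Fashion MNIST)
fake_imgs = np.array([[0, 128, 255]], dtype=np.uint8)

normalized = fake_imgs.astype('float32') / 255.0
print(normalized.min(), normalized.max())  # 0.0 1.0
```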
Define the model
When you build a model for hypertuning, you also define the hyperparameter search space in addition to the model architecture. The model you set up for hypertuning is called a hypermodel.
You can define a hypermodel through two approaches:
- By using a model builder function
- By subclassing the HyperModel class of the Keras Tuner API
You can also use two pre-defined HyperModel classes, HyperXception and HyperResNet, for computer vision applications.
In this tutorial, you use a model builder function to define the image classification model. The model builder function returns a compiled model and uses hyperparameters you define inline to hypertune the model.
def model_builder(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32-512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model
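Although Hyperband samples this space rather than enumerating it, the search space defined above is small enough to count directly. A minimal sketch, independent of Keras Tuner, of how many distinct configurations the two inline hyperparameters allow:

```python
# Values that hp.Int('units', min_value=32, max_value=512, step=32) can take
units_values = list(range(32, 512 + 1, 32))  # 32, 64, ..., 512

# Values that hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4]) can take
learning_rates = [1e-2, 1e-3, 1e-4]

print(len(units_values))                        # 16
print(len(units_values) * len(learning_rates))  # 48 possible combinations
```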
Instantiate the tuner and perform hypertuning
Instantiate the tuner to perform the hypertuning. The Keras Tuner has four tuners available: RandomSearch, Hyperband, BayesianOptimization, and Sklearn. In this tutorial, you use the Hyperband tuner.
To instantiate the Hyperband tuner, you must specify the hypermodel, the objective to optimize, and the maximum number of epochs to train (max_epochs).
tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory='my_dir',
                     project_name='intro_to_kt')
The Hyperband tuning algorithm uses adaptive resource allocation and early stopping to quickly converge on a high-performing model. This is done using a sports-championship-style bracket: the algorithm trains a large number of models for a few epochs and carries forward only the top-performing half of models to the next round. Hyperband determines the number of models to train in a bracket by computing 1 + log_factor(max_epochs) and rounding it up to the nearest integer.
Create a callback to stop training early once the validation loss has stopped improving (here, after 5 epochs with no improvement).
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
Run the hyperparameter search. The arguments for the search method are the same as those used for tf.keras.Model.fit, in addition to the callback above.
tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])
# Get the optimal hyperparameters
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]
print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
Trial 30 Complete [00h 00m 36s]
val_accuracy: 0.8759999871253967

Best val_accuracy So Far: 0.887333333492279
Total elapsed time: 00h 07m 28s
INFO:tensorflow:Oracle triggered exit

The hyperparameter search is complete. The optimal number of units in the first densely-connected layer is 192 and the optimal learning rate for the optimizer is 0.001.
Train the model
Find the optimal number of epochs to train the model using the hyperparameters obtained from the search.
# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)
val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
Epoch 1/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.5081 - accuracy: 0.8205 - val_loss: 0.3976 - val_accuracy: 0.8615
Epoch 2/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3815 - accuracy: 0.8616 - val_loss: 0.3777 - val_accuracy: 0.8647
...
Epoch 50/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0829 - accuracy: 0.9682 - val_loss: 0.5330 - val_accuracy: 0.8888
Best epoch: 26
Re-instantiate the hypermodel and train it with the optimal number of epochs from above.
hypermodel = tuner.hypermodel.build(best_hps)
# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
Epoch 1/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.5091 - accuracy: 0.8206 - val_loss: 0.4011 - val_accuracy: 0.8560
Epoch 2/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3795 - accuracy: 0.8622 - val_loss: 0.3718 - val_accuracy: 0.8668
...
Epoch 26/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1426 - accuracy: 0.9470 - val_loss: 0.3748 - val_accuracy: 0.8864
<keras.callbacks.History at 0x7f5e2eb57c40>
To finish this tutorial, evaluate the hypermodel on the test data.
eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
313/313 [==============================] - 1s 2ms/step - loss: 0.4196 - accuracy: 0.8807
[test loss, test accuracy]: [0.41957515478134155, 0.8806999921798706]
The my_dir/intro_to_kt directory contains detailed logs and checkpoints for every trial (model configuration) run during the hyperparameter search. If you re-run the hyperparameter search, the Keras Tuner uses the existing state recorded in these logs to resume the search. To disable this behavior, pass an additional overwrite=True argument when instantiating the tuner.
Summary
In this tutorial, you learned how to use the Keras Tuner to tune hyperparameters for a model.
To learn more, also check out the HParams Dashboard in TensorBoard, which lets you tune your model hyperparameters interactively.