Overview

The Keras Tuner is a library that helps you pick the optimal set of hyperparameters for your TensorFlow program. The process of selecting the right set of hyperparameters for your machine learning (ML) application is called hyperparameter tuning or hypertuning.

Hyperparameters are the variables that govern the training process and the topology of an ML model. These variables remain constant over the training process and directly impact the performance of your ML program. Hyperparameters are of two types:

- Model hyperparameters, which influence model selection, such as the number and width of hidden layers
- Algorithm hyperparameters, which influence the speed and quality of the learning algorithm, such as the learning rate for stochastic gradient descent (SGD) and the number of nearest neighbors for a k-nearest neighbors (KNN) classifier

In this tutorial, you will use the Keras Tuner to perform hypertuning for an image classification application.
Setup
import tensorflow as tf
from tensorflow import keras
Install and import the Keras Tuner.
pip install -q -U keras-tuner
import keras_tuner as kt
Download and prepare the dataset

In this tutorial, you will use the Keras Tuner to find the best hyperparameters for a machine learning model that classifies images of clothing from the Fashion MNIST dataset.

Load the data.
(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0
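The division by 255.0 maps each 8-bit pixel intensity into the [0, 1] range. As a quick pure-Python sanity check of that step (no TensorFlow needed):

```python
# Scale raw 8-bit pixel intensities (0-255) into the [0, 1] range,
# mirroring the astype('float32') / 255.0 step above.
raw_pixels = [0, 51, 128, 255]
scaled = [p / 255.0 for p in raw_pixels]

print(scaled)  # 0 maps to 0.0 and 255 maps to 1.0
```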
Define the model
When you build a model for hypertuning, you also define the hyperparameter search space in addition to the model architecture. The model you set up for hypertuning is called a hypermodel.

You can define a hypermodel through two approaches:

- By using a model builder function
- By subclassing the HyperModel class of the Keras Tuner API

You can also use two pre-defined HyperModel classes, HyperXception and HyperResNet, for computer vision applications.

In this tutorial, you use a model builder function to define the image classification model. The model builder function returns a compiled model and uses hyperparameters you define inline to hypertune the model.
def model_builder(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32-512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model
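Together, hp.Int('units', 32, 512, step=32) and hp.Choice('learning_rate', ...) define a discrete search space: 16 candidate layer widths times 3 learning rates. A small pure-Python sketch of that space (illustrative only; the tuner samples it rather than enumerating it exhaustively):

```python
# Enumerate the discrete search space implied by hp.Int and hp.Choice above.
unit_values = list(range(32, 512 + 1, 32))   # 32, 64, ..., 512
learning_rates = [1e-2, 1e-3, 1e-4]

search_space = [(u, lr) for u in unit_values for lr in learning_rates]

print(len(unit_values))    # 16 candidate widths
print(len(search_space))   # 48 (units, learning_rate) combinations
```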
Instantiate the tuner and perform hypertuning
Instantiate the tuner to perform the hypertuning. The Keras Tuner has four tuners available: RandomSearch, Hyperband, BayesianOptimization, and Sklearn. In this tutorial, you use the Hyperband tuner.

To instantiate the Hyperband tuner, you must specify the hypermodel, the objective to optimize, and the maximum number of epochs to train (max_epochs).
tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory='my_dir',
                     project_name='intro_to_kt')
The Hyperband tuning algorithm uses adaptive resource allocation and early stopping to quickly converge on a high-performing model. This is done using a sports-championship-style bracket: the algorithm trains a large number of models for a few epochs and carries forward only the top-performing half of models to the next round. Hyperband determines the number of models to train in a bracket by computing 1 + log_factor(max_epochs) and rounding it up to the nearest integer.
Create a callback to stop training early after reaching a certain value for the validation loss.
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
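With patience=5, EarlyStopping halts training once val_loss has gone five consecutive epochs without improving on its best value so far. A pure-Python emulation of that rule, for illustration only (not the Keras implementation):

```python
def epochs_until_stop(val_losses, patience):
    """Return the 1-based epoch at which early stopping would halt training."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:      # improvement: record it and reset the counter
            best = loss
            wait = 0
        else:                # no improvement this epoch
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)   # never triggered: all epochs run

# Best loss at epoch 2, then five non-improving epochs -> stop at epoch 7
losses = [0.50, 0.40, 0.42, 0.43, 0.44, 0.45, 0.46, 0.38]
print(epochs_until_stop(losses, patience=5))  # 7
```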
Run the hyperparameter search. The arguments for the search method are the same as those used for tf.keras.Model.fit, in addition to the callback above.
tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])
# Get the optimal hyperparameters
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
Trial 30 Complete [00h 00m 43s]
val_accuracy: 0.8786666393280029

Best val_accuracy So Far: 0.8955833315849304
Total elapsed time: 00h 09m 03s

The hyperparameter search is complete. The optimal number of units in the first densely-connected layer is 448 and the optimal learning rate for the optimizer is 0.001.
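get_best_hyperparameters returns the completed trials ranked by the objective, so taking [0] yields the winner. Conceptually that ranking is just a maximum over (hyperparameters, val_accuracy) pairs, as this plain-Python sketch with hypothetical trial results shows:

```python
# Hypothetical (hyperparameters, val_accuracy) results from a few trials
trials = [
    ({'units': 128, 'learning_rate': 1e-2}, 0.871),
    ({'units': 448, 'learning_rate': 1e-3}, 0.896),
    ({'units': 320, 'learning_rate': 1e-4}, 0.879),
]

# Pick the trial with the highest objective value (val_accuracy)
best_hps, best_score = max(trials, key=lambda t: t[1])
print(best_hps['units'], best_hps['learning_rate'])  # 448 0.001
```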
Train the model

Find the optimal number of epochs to train the model with the hyperparameters obtained from the search.
# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)
val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
Epoch 1/50
1500/1500 [==============================] - 5s 3ms/step - loss: 0.4930 - accuracy: 0.8246 - val_loss: 0.3828 - val_accuracy: 0.8637
...
Epoch 22/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1504 - accuracy: 0.9433 - val_loss: 0.3458 - val_accuracy: 0.8991
...
Epoch 50/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0766 - accuracy: 0.9709 - val_loss: 0.5130 - val_accuracy: 0.8944
Best epoch: 22
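The best_epoch computation above is a plain argmax over the per-epoch validation accuracies, shifted to 1-based epoch numbering:

```python
# Per-epoch validation accuracies (a short illustrative list)
val_acc_per_epoch = [0.8637, 0.8663, 0.8763, 0.8991, 0.8942]

# index() is 0-based, while epochs are numbered from 1, hence the +1
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % best_epoch)  # Best epoch: 4
```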
Re-instantiate the hypermodel and train it with the optimal number of epochs from above.
hypermodel = tuner.hypermodel.build(best_hps)
# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
Epoch 1/22
1500/1500 [==============================] - 5s 3ms/step - loss: 0.4890 - accuracy: 0.8271 - val_loss: 0.4094 - val_accuracy: 0.8528
...
Epoch 22/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1476 - accuracy: 0.9442 - val_loss: 0.3635 - val_accuracy: 0.8942
<keras.src.callbacks.History at 0x7fd98036f400>
To finish this tutorial, evaluate the hypermodel on the test data.
eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
313/313 [==============================] - 1s 2ms/step - loss: 0.3894 - accuracy: 0.8890
[test loss, test accuracy]: [0.38939228653907776, 0.8889999985694885]
The my_dir/intro_to_kt directory contains detailed logs and checkpoints for every trial (model configuration) run during the hyperparameter search. If you re-run the hyperparameter search, the Keras Tuner uses the existing state recorded in these logs to resume the search. To disable this behavior, pass an additional overwrite=True argument while instantiating the tuner.
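For instance, to start a fresh search instead of resuming, the tuner can be re-instantiated with the same arguments as before plus overwrite (a sketch; it assumes keras_tuner is installed and model_builder is defined as above):

```python
import keras_tuner as kt

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     overwrite=True,          # discard any previously saved results
                     directory='my_dir',
                     project_name='intro_to_kt')
```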
Summary

In this tutorial, you learned how to use the Keras Tuner to tune hyperparameters for a model. To learn more about the Keras Tuner, check out these additional resources:

Also check out the HParams Dashboard in TensorBoard to interactively tune your model hyperparameters.