Keras Tuner 简介

概述

Keras Tuner 是一个库,可帮助您为 TensorFlow 程序选择最佳的超参数集。为您的机器学习 (ML) 应用选择正确的超参数集,这一过程称为超参数调节超调

超参数是控制训练过程和 ML 模型拓扑的变量。这些变量在训练过程中保持不变,并会直接影响 ML 程序的性能。超参数有两种类型:

  1. 模型超参数:影响模型的选择,例如隐藏层的数量和宽度
  2. 算法超参数:影响学习算法的速度和质量,例如随机梯度下降 (SGD) 的学习率以及 k 近邻 (KNN) 分类器的近邻数

在本教程中,您将使用 Keras Tuner 对图像分类应用执行超调。

设置

import tensorflow as tf
from tensorflow import keras
2023-11-08 00:34:44.972816: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-08 00:34:44.972863: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-08 00:34:44.974559: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

安装并导入 Keras Tuner。

pip install -q -U keras-tuner
import keras_tuner as kt

下载并准备数据集

在本教程中,您将使用 Keras Tuner 为某个对 Fashion MNIST 数据集内的服装图像进行分类的机器学习模型找到最佳超参数。

加载数据。

(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0

定义模型

构建用于超调的模型时,除了模型架构之外,还要定义超参数搜索空间。您为超调设置的模型称为超模型

您可以通过两种方式定义超模型:

  • 使用模型构建工具函数
  • 将 Keras Tuner API 的 HyperModel 类子类化

您还可以将两个预定义的 HyperModelHyperXceptionHyperResNet 用于计算机视觉应用。

在本教程中,您将使用模型构建工具函数来定义图像分类模型。模型构建工具函数将返回已编译的模型,并使用您以内嵌方式定义的超参数对模型进行超调。

def model_builder(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32-512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model

实例化调节器并执行超调

实例化调节器以执行超调。Keras Tuner 提供了四种调节器:RandomSearchHyperbandBayesianOptimizationSklearn。在本教程中,您将使用 Hyperband 调节器。

要实例化 Hyperband 调节器,必须指定超模型、要优化的 objective 和要训练的最大周期数 (max_epochs)。

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory='my_dir',
                     project_name='intro_to_kt')

Hyperband 调节算法使用自适应资源分配和早停法来快速收敛到高性能模型。该过程采用了体育竞技争冠模式的排除法。算法会将大量模型训练多个周期,并仅将性能最高的一半模型送入下一轮训练。Hyperband 通过计算 1 + logfactor(max_epochs) 并将其向上舍入到最接近的整数来确定要训练的模型的数量。

创建回调以在验证损失达到特定值后提前停止训练。

stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

运行超参数搜索。除了上面的回调外,搜索方法的参数也与 tf.keras.model.fit 所用参数相同。

tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])

# Get the optimal hyperparameters
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]

print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
Trial 30 Complete [00h 00m 43s]
val_accuracy: 0.8786666393280029

Best val_accuracy So Far: 0.8955833315849304
Total elapsed time: 00h 09m 03s

The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is 448 and the optimal learning rate for the optimizer
is 0.001.

训练模型

使用从搜索中获得的超参数找到训练模型的最佳周期数。

# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)

val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
Epoch 1/50
1500/1500 [==============================] - 5s 3ms/step - loss: 0.4930 - accuracy: 0.8246 - val_loss: 0.3828 - val_accuracy: 0.8637
Epoch 2/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.3712 - accuracy: 0.8640 - val_loss: 0.3673 - val_accuracy: 0.8663
Epoch 3/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.3302 - accuracy: 0.8771 - val_loss: 0.3450 - val_accuracy: 0.8763
Epoch 4/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.3057 - accuracy: 0.8874 - val_loss: 0.3340 - val_accuracy: 0.8823
Epoch 5/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2855 - accuracy: 0.8940 - val_loss: 0.3199 - val_accuracy: 0.8861
Epoch 6/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2708 - accuracy: 0.8999 - val_loss: 0.3114 - val_accuracy: 0.8890
Epoch 7/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2584 - accuracy: 0.9046 - val_loss: 0.3261 - val_accuracy: 0.8852
Epoch 8/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2443 - accuracy: 0.9094 - val_loss: 0.3379 - val_accuracy: 0.8788
Epoch 9/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2331 - accuracy: 0.9119 - val_loss: 0.3133 - val_accuracy: 0.8894
Epoch 10/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2248 - accuracy: 0.9158 - val_loss: 0.3240 - val_accuracy: 0.8886
Epoch 11/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2142 - accuracy: 0.9200 - val_loss: 0.3291 - val_accuracy: 0.8891
Epoch 12/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2079 - accuracy: 0.9227 - val_loss: 0.3122 - val_accuracy: 0.8942
Epoch 13/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1980 - accuracy: 0.9258 - val_loss: 0.3313 - val_accuracy: 0.8852
Epoch 14/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1919 - accuracy: 0.9278 - val_loss: 0.3244 - val_accuracy: 0.8930
Epoch 15/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1852 - accuracy: 0.9301 - val_loss: 0.3487 - val_accuracy: 0.8875
Epoch 16/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1802 - accuracy: 0.9319 - val_loss: 0.3368 - val_accuracy: 0.8927
Epoch 17/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1740 - accuracy: 0.9343 - val_loss: 0.3056 - val_accuracy: 0.8973
Epoch 18/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1663 - accuracy: 0.9379 - val_loss: 0.3311 - val_accuracy: 0.8967
Epoch 19/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1626 - accuracy: 0.9388 - val_loss: 0.3390 - val_accuracy: 0.8933
Epoch 20/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1562 - accuracy: 0.9417 - val_loss: 0.3752 - val_accuracy: 0.8853
Epoch 21/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1507 - accuracy: 0.9435 - val_loss: 0.3467 - val_accuracy: 0.8933
Epoch 22/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1504 - accuracy: 0.9433 - val_loss: 0.3458 - val_accuracy: 0.8991
Epoch 23/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1419 - accuracy: 0.9470 - val_loss: 0.3440 - val_accuracy: 0.8942
Epoch 24/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1418 - accuracy: 0.9471 - val_loss: 0.3936 - val_accuracy: 0.8887
Epoch 25/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1351 - accuracy: 0.9489 - val_loss: 0.3691 - val_accuracy: 0.8937
Epoch 26/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1338 - accuracy: 0.9496 - val_loss: 0.3733 - val_accuracy: 0.8947
Epoch 27/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1290 - accuracy: 0.9509 - val_loss: 0.3894 - val_accuracy: 0.8938
Epoch 28/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1234 - accuracy: 0.9538 - val_loss: 0.3997 - val_accuracy: 0.8886
Epoch 29/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1203 - accuracy: 0.9542 - val_loss: 0.3820 - val_accuracy: 0.8971
Epoch 30/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1168 - accuracy: 0.9565 - val_loss: 0.3945 - val_accuracy: 0.8955
Epoch 31/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1173 - accuracy: 0.9554 - val_loss: 0.4014 - val_accuracy: 0.8967
Epoch 32/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1121 - accuracy: 0.9580 - val_loss: 0.4028 - val_accuracy: 0.8915
Epoch 33/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1096 - accuracy: 0.9580 - val_loss: 0.4551 - val_accuracy: 0.8954
Epoch 34/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1077 - accuracy: 0.9591 - val_loss: 0.4140 - val_accuracy: 0.8919
Epoch 35/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1042 - accuracy: 0.9614 - val_loss: 0.4150 - val_accuracy: 0.8972
Epoch 36/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1043 - accuracy: 0.9614 - val_loss: 0.4508 - val_accuracy: 0.8942
Epoch 37/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0973 - accuracy: 0.9630 - val_loss: 0.4470 - val_accuracy: 0.8930
Epoch 38/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0996 - accuracy: 0.9619 - val_loss: 0.4453 - val_accuracy: 0.8966
Epoch 39/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0929 - accuracy: 0.9655 - val_loss: 0.4483 - val_accuracy: 0.8944
Epoch 40/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0941 - accuracy: 0.9646 - val_loss: 0.4605 - val_accuracy: 0.8967
Epoch 41/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0897 - accuracy: 0.9658 - val_loss: 0.4561 - val_accuracy: 0.8975
Epoch 42/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0905 - accuracy: 0.9671 - val_loss: 0.4831 - val_accuracy: 0.8962
Epoch 43/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0857 - accuracy: 0.9677 - val_loss: 0.4745 - val_accuracy: 0.8955
Epoch 44/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0910 - accuracy: 0.9664 - val_loss: 0.4529 - val_accuracy: 0.8947
Epoch 45/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0851 - accuracy: 0.9683 - val_loss: 0.4733 - val_accuracy: 0.8954
Epoch 46/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0772 - accuracy: 0.9710 - val_loss: 0.4811 - val_accuracy: 0.8967
Epoch 47/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0821 - accuracy: 0.9691 - val_loss: 0.5194 - val_accuracy: 0.8911
Epoch 48/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0799 - accuracy: 0.9713 - val_loss: 0.5141 - val_accuracy: 0.8944
Epoch 49/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0788 - accuracy: 0.9705 - val_loss: 0.4839 - val_accuracy: 0.8959
Epoch 50/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0766 - accuracy: 0.9709 - val_loss: 0.5130 - val_accuracy: 0.8944
Best epoch: 22

重新实例化超模型并使用上面的最佳周期数对其进行训练。

hypermodel = tuner.hypermodel.build(best_hps)

# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
Epoch 1/22
1500/1500 [==============================] - 5s 3ms/step - loss: 0.4890 - accuracy: 0.8271 - val_loss: 0.4094 - val_accuracy: 0.8528
Epoch 2/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.3694 - accuracy: 0.8661 - val_loss: 0.3573 - val_accuracy: 0.8658
Epoch 3/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.3321 - accuracy: 0.8784 - val_loss: 0.3685 - val_accuracy: 0.8646
Epoch 4/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.3048 - accuracy: 0.8867 - val_loss: 0.3363 - val_accuracy: 0.8767
Epoch 5/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2875 - accuracy: 0.8927 - val_loss: 0.3549 - val_accuracy: 0.8754
Epoch 6/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2714 - accuracy: 0.8988 - val_loss: 0.3112 - val_accuracy: 0.8876
Epoch 7/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2579 - accuracy: 0.9040 - val_loss: 0.3201 - val_accuracy: 0.8858
Epoch 8/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2451 - accuracy: 0.9088 - val_loss: 0.3104 - val_accuracy: 0.8873
Epoch 9/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2330 - accuracy: 0.9135 - val_loss: 0.3134 - val_accuracy: 0.8934
Epoch 10/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2249 - accuracy: 0.9148 - val_loss: 0.3214 - val_accuracy: 0.8878
Epoch 11/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2158 - accuracy: 0.9186 - val_loss: 0.3415 - val_accuracy: 0.8825
Epoch 12/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2088 - accuracy: 0.9208 - val_loss: 0.3341 - val_accuracy: 0.8859
Epoch 13/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1998 - accuracy: 0.9250 - val_loss: 0.3628 - val_accuracy: 0.8788
Epoch 14/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1927 - accuracy: 0.9271 - val_loss: 0.3349 - val_accuracy: 0.8921
Epoch 15/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1833 - accuracy: 0.9309 - val_loss: 0.3353 - val_accuracy: 0.8920
Epoch 16/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1804 - accuracy: 0.9305 - val_loss: 0.3479 - val_accuracy: 0.8906
Epoch 17/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1723 - accuracy: 0.9351 - val_loss: 0.3435 - val_accuracy: 0.8903
Epoch 18/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1647 - accuracy: 0.9376 - val_loss: 0.3408 - val_accuracy: 0.8926
Epoch 19/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1635 - accuracy: 0.9377 - val_loss: 0.3771 - val_accuracy: 0.8845
Epoch 20/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1589 - accuracy: 0.9400 - val_loss: 0.3478 - val_accuracy: 0.8914
Epoch 21/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1522 - accuracy: 0.9438 - val_loss: 0.3609 - val_accuracy: 0.8873
Epoch 22/22
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1476 - accuracy: 0.9442 - val_loss: 0.3635 - val_accuracy: 0.8942
<keras.src.callbacks.History at 0x7fd98036f400>

要完成本教程,请在测试数据上评估超模型。

eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
313/313 [==============================] - 1s 2ms/step - loss: 0.3894 - accuracy: 0.8890
[test loss, test accuracy]: [0.38939228653907776, 0.8889999985694885]

my_dir/intro_to_kt 目录中包含了在超参数搜索期间每次试验(模型配置)运行的详细日志和检查点。如果重新运行超参数搜索,Keras Tuner 将使用这些日志中记录的现有状态来继续搜索。要停用此行为,请在实例化调节器时传递一个附加的 overwrite = True 参数。

总结

在本教程中,您学习了如何使用 Keras Tuner 调节模型的超参数。要详细了解 Keras Tuner,请查看以下其他资源:

另请查看 TensorBoard 中的 HParams Dashboard,以交互方式调节模型超参数。