Overview
The Keras Tuner is a library that helps you pick the optimal set of hyperparameters for your TensorFlow program. The process of selecting the right set of hyperparameters for your machine learning (ML) application is called hyperparameter tuning or hypertuning.
Hyperparameters are the variables that govern the training process and the topology of an ML model. These variables remain constant over the training process and directly impact the performance of your ML program. Hyperparameters are of two types:
- Model hyperparameters, which influence model selection, such as the number and width of hidden layers
- Algorithm hyperparameters, which influence the speed and quality of the learning algorithm, such as the learning rate for Stochastic Gradient Descent (SGD) and the number of nearest neighbors for a k-Nearest Neighbors (KNN) classifier
In this tutorial, you will use the Keras Tuner to perform hypertuning for an image classification application.
Setup
import tensorflow as tf
from tensorflow import keras
Install and import the Keras Tuner.
pip install -q -U keras-tuner
import keras_tuner as kt
Using TensorFlow backend
Download and prepare the dataset
In this tutorial, you will use the Keras Tuner to find the best hyperparameters for a machine learning model that classifies images of clothing from the Fashion MNIST dataset.
Load the data.
(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0
Define the model
When you build a model for hypertuning, you also define the hyperparameter search space in addition to the model architecture. The model you set up for hypertuning is called a hypermodel.
You can define a hypermodel through two approaches:
- By using a model builder function
- By subclassing the HyperModel class of the Keras Tuner API
You can also use two pre-defined HyperModel classes, HyperXception and HyperResNet, for computer vision applications.
In this tutorial, you use a model builder function to define the image classification model. The model builder function returns a compiled model and uses hyperparameters you define inline to hypertune the model.
def model_builder(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32-512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model
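As a rough illustration of how large this search space is, the sketch below enumerates the candidate values in plain Python. It does not call the Keras Tuner; it assumes that hp.Int includes both min_value and max_value, and the variable names are hypothetical.

```python
from itertools import product

# Candidate values for 'units': 32, 64, ..., 512
# (assumption: both endpoints of the hp.Int range are included).
unit_choices = list(range(32, 512 + 1, 32))

# Candidate values for 'learning_rate', as passed to hp.Choice above.
learning_rate_choices = [1e-2, 1e-3, 1e-4]

# Every (units, learning_rate) pair the tuner could try.
search_space = list(product(unit_choices, learning_rate_choices))

print(len(unit_choices))           # number of distinct 'units' values
print(len(learning_rate_choices))  # number of distinct learning rates
print(len(search_space))           # total number of combinations
```

Under these assumptions there are 16 unit values and 3 learning rates, so 48 combinations in total. A tuner does not necessarily evaluate all of them.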
Instantiate the tuner and perform hypertuning
Instantiate the tuner to perform the hypertuning. The Keras Tuner has four tuners available: RandomSearch, Hyperband, BayesianOptimization, and Sklearn. In this tutorial, you use the Hyperband tuner.
To instantiate the Hyperband tuner, you must specify the hypermodel, the objective to optimize, and the maximum number of epochs to train (max_epochs).
tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory='my_dir',
                     project_name='intro_to_kt')
The Hyperband tuning algorithm uses adaptive resource allocation and early stopping to quickly converge on a high-performing model. This is done using a sports-championship-style bracket. The algorithm trains a large number of models for a few epochs and carries forward only the top-performing fraction of models (roughly 1/factor, so about a third with factor=3) to the next round. Hyperband determines the number of models to train in a bracket by computing 1 + log_factor(max_epochs), where factor is the base of the logarithm, and rounding it up to the nearest integer.
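To make the formula concrete, here is a minimal sketch that evaluates it for the values used in this tutorial (max_epochs=10, factor=3). The helper name bracket_count is hypothetical, and the function only evaluates the expression as stated in the text; it is not taken from the Keras Tuner source.

```python
import math


def bracket_count(max_epochs: int, factor: int) -> int:
    """Evaluate 1 + log_factor(max_epochs), rounded up to the nearest integer.

    Hypothetical helper for illustration only. Floating-point rounding can
    affect the result when max_epochs is an exact power of factor.
    """
    return math.ceil(1 + math.log(max_epochs, factor))


# Values used when instantiating the Hyperband tuner above.
print(bracket_count(max_epochs=10, factor=3))
```

For these values, log base 3 of 10 is about 2.1, so the expression is about 3.1 and rounds up to 4.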
Create a callback that stops training early when the validation loss has not improved for a set number of epochs (here, patience=5).
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
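The patience rule can be sketched in plain Python. This is a simplified model of the behavior, not the Keras implementation: it assumes lower values are better, ignores options such as min_delta, and uses a hypothetical function name and made-up loss values.

```python
from typing import Optional, Sequence


def stopping_epoch(val_losses: Sequence[float], patience: int) -> Optional[int]:
    """Return the 1-indexed epoch at which training would stop, or None.

    Simplified sketch: stop once the validation loss has failed to improve
    on its best value for `patience` consecutive epochs.
    """
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Made-up validation losses: the best value occurs at epoch 2.
losses = [0.50, 0.40, 0.41, 0.42, 0.43, 0.44, 0.45, 0.30]
print(stopping_epoch(losses, patience=5))
```

In this example the loss stops improving after epoch 2, so five non-improving epochs have elapsed by epoch 7 and training stops there, before the lower value at epoch 8 is ever reached.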
Run the hyperparameter search. The arguments for the search method are the same as those used for tf.keras.Model.fit, in addition to the callback above.
tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])
# Get the optimal hyperparameters
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
Trial 30 Complete [00h 00m 43s]
val_accuracy: 0.878250002861023
Best val_accuracy So Far: 0.890916645526886
Total elapsed time: 00h 09m 27s
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is 416 and the optimal learning rate for the optimizer
is 0.001.
Train the model
Find the optimal number of epochs to train the model with the hyperparameters obtained from the search.
# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)
val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
Epoch 1/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.5015 - accuracy: 0.8232 - val_loss: 0.4154 - val_accuracy: 0.8482
Epoch 2/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.3713 - accuracy: 0.8648 - val_loss: 0.3429 - val_accuracy: 0.8768
Epoch 3/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.3324 - accuracy: 0.8776 - val_loss: 0.3480 - val_accuracy: 0.8764
Epoch 4/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.3053 - accuracy: 0.8878 - val_loss: 0.3395 - val_accuracy: 0.8793
Epoch 5/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2845 - accuracy: 0.8960 - val_loss: 0.3207 - val_accuracy: 0.8854
Epoch 6/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2711 - accuracy: 0.8991 - val_loss: 0.3303 - val_accuracy: 0.8804
Epoch 7/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2552 - accuracy: 0.9057 - val_loss: 0.3281 - val_accuracy: 0.8837
Epoch 8/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2440 - accuracy: 0.9091 - val_loss: 0.3024 - val_accuracy: 0.8920
Epoch 9/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2353 - accuracy: 0.9120 - val_loss: 0.3106 - val_accuracy: 0.8926
Epoch 10/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2265 - accuracy: 0.9150 - val_loss: 0.3229 - val_accuracy: 0.8887
Epoch 11/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2169 - accuracy: 0.9181 - val_loss: 0.3253 - val_accuracy: 0.8820
Epoch 12/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2073 - accuracy: 0.9210 - val_loss: 0.3306 - val_accuracy: 0.8910
Epoch 13/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2013 - accuracy: 0.9251 - val_loss: 0.3131 - val_accuracy: 0.8945
Epoch 14/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1927 - accuracy: 0.9263 - val_loss: 0.3253 - val_accuracy: 0.8910
Epoch 15/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1854 - accuracy: 0.9307 - val_loss: 0.3469 - val_accuracy: 0.8848
Epoch 16/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1787 - accuracy: 0.9333 - val_loss: 0.3410 - val_accuracy: 0.8910
Epoch 17/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1753 - accuracy: 0.9340 - val_loss: 0.3412 - val_accuracy: 0.8932
Epoch 18/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1676 - accuracy: 0.9358 - val_loss: 0.3349 - val_accuracy: 0.8967
Epoch 19/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1604 - accuracy: 0.9406 - val_loss: 0.3527 - val_accuracy: 0.8890
Epoch 20/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1590 - accuracy: 0.9406 - val_loss: 0.3512 - val_accuracy: 0.8924
Epoch 21/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1512 - accuracy: 0.9430 - val_loss: 0.3380 - val_accuracy: 0.8944
Epoch 22/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1464 - accuracy: 0.9441 - val_loss: 0.3741 - val_accuracy: 0.8950
Epoch 23/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1441 - accuracy: 0.9461 - val_loss: 0.3800 - val_accuracy: 0.8921
Epoch 24/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1366 - accuracy: 0.9485 - val_loss: 0.4073 - val_accuracy: 0.8826
Epoch 25/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1377 - accuracy: 0.9471 - val_loss: 0.4091 - val_accuracy: 0.8846
Epoch 26/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1329 - accuracy: 0.9501 - val_loss: 0.3760 - val_accuracy: 0.8923
Epoch 27/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1270 - accuracy: 0.9519 - val_loss: 0.3927 - val_accuracy: 0.8915
Epoch 28/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1247 - accuracy: 0.9528 - val_loss: 0.3965 - val_accuracy: 0.8917
Epoch 29/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1210 - accuracy: 0.9548 - val_loss: 0.4010 - val_accuracy: 0.8926
Epoch 30/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1172 - accuracy: 0.9565 - val_loss: 0.3914 - val_accuracy: 0.8957
Epoch 31/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1165 - accuracy: 0.9562 - val_loss: 0.4254 - val_accuracy: 0.8948
Epoch 32/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1113 - accuracy: 0.9590 - val_loss: 0.4485 - val_accuracy: 0.8901
Epoch 33/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1098 - accuracy: 0.9587 - val_loss: 0.4156 - val_accuracy: 0.8934
Epoch 34/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1058 - accuracy: 0.9609 - val_loss: 0.4374 - val_accuracy: 0.8932
Epoch 35/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1028 - accuracy: 0.9624 - val_loss: 0.4520 - val_accuracy: 0.8925
Epoch 36/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0995 - accuracy: 0.9631 - val_loss: 0.4732 - val_accuracy: 0.8913
Epoch 37/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0987 - accuracy: 0.9625 - val_loss: 0.4268 - val_accuracy: 0.8972
Epoch 38/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0952 - accuracy: 0.9644 - val_loss: 0.4506 - val_accuracy: 0.8953
Epoch 39/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0930 - accuracy: 0.9656 - val_loss: 0.4726 - val_accuracy: 0.8932
Epoch 40/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0899 - accuracy: 0.9650 - val_loss: 0.4744 - val_accuracy: 0.8917
Epoch 41/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0887 - accuracy: 0.9671 - val_loss: 0.5256 - val_accuracy: 0.8913
Epoch 42/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0883 - accuracy: 0.9670 - val_loss: 0.4868 - val_accuracy: 0.8953
Epoch 43/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0848 - accuracy: 0.9689 - val_loss: 0.4819 - val_accuracy: 0.8953
Epoch 44/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0870 - accuracy: 0.9672 - val_loss: 0.5638 - val_accuracy: 0.8848
Epoch 45/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0817 - accuracy: 0.9698 - val_loss: 0.5141 - val_accuracy: 0.8905
Epoch 46/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0794 - accuracy: 0.9704 - val_loss: 0.5411 - val_accuracy: 0.8902
Epoch 47/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0828 - accuracy: 0.9694 - val_loss: 0.5208 - val_accuracy: 0.8920
Epoch 48/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0763 - accuracy: 0.9707 - val_loss: 0.5354 - val_accuracy: 0.8883
Epoch 49/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0749 - accuracy: 0.9717 - val_loss: 0.5483 - val_accuracy: 0.8907
Epoch 50/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0744 - accuracy: 0.9725 - val_loss: 0.5422 - val_accuracy: 0.8936
Best epoch: 37
Re-instantiate the hypermodel and train it with the optimal number of epochs from above.
hypermodel = tuner.hypermodel.build(best_hps)
# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
Epoch 1/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.4886 - accuracy: 0.8290 - val_loss: 0.4179 - val_accuracy: 0.8546
Epoch 2/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.3672 - accuracy: 0.8651 - val_loss: 0.3462 - val_accuracy: 0.8750
Epoch 3/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.3309 - accuracy: 0.8789 - val_loss: 0.3543 - val_accuracy: 0.8720
Epoch 4/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.3057 - accuracy: 0.8861 - val_loss: 0.3413 - val_accuracy: 0.8767
Epoch 5/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2852 - accuracy: 0.8934 - val_loss: 0.3232 - val_accuracy: 0.8855
Epoch 6/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2694 - accuracy: 0.8992 - val_loss: 0.3517 - val_accuracy: 0.8782
Epoch 7/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2580 - accuracy: 0.9041 - val_loss: 0.3315 - val_accuracy: 0.8846
Epoch 8/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2439 - accuracy: 0.9084 - val_loss: 0.3281 - val_accuracy: 0.8870
Epoch 9/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2340 - accuracy: 0.9125 - val_loss: 0.3325 - val_accuracy: 0.8851
Epoch 10/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2220 - accuracy: 0.9163 - val_loss: 0.3119 - val_accuracy: 0.8917
Epoch 11/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2148 - accuracy: 0.9189 - val_loss: 0.3411 - val_accuracy: 0.8831
Epoch 12/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.2081 - accuracy: 0.9210 - val_loss: 0.3186 - val_accuracy: 0.8894
Epoch 13/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1970 - accuracy: 0.9252 - val_loss: 0.3442 - val_accuracy: 0.8882
Epoch 14/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1911 - accuracy: 0.9280 - val_loss: 0.3057 - val_accuracy: 0.8994
Epoch 15/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1865 - accuracy: 0.9287 - val_loss: 0.3203 - val_accuracy: 0.8938
Epoch 16/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1770 - accuracy: 0.9333 - val_loss: 0.3222 - val_accuracy: 0.8953
Epoch 17/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1723 - accuracy: 0.9352 - val_loss: 0.3336 - val_accuracy: 0.8944
Epoch 18/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1672 - accuracy: 0.9368 - val_loss: 0.3268 - val_accuracy: 0.8963
Epoch 19/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1596 - accuracy: 0.9413 - val_loss: 0.3511 - val_accuracy: 0.8930
Epoch 20/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1599 - accuracy: 0.9404 - val_loss: 0.3686 - val_accuracy: 0.8907
Epoch 21/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1506 - accuracy: 0.9420 - val_loss: 0.3596 - val_accuracy: 0.8914
Epoch 22/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1460 - accuracy: 0.9452 - val_loss: 0.3511 - val_accuracy: 0.8975
Epoch 23/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1445 - accuracy: 0.9460 - val_loss: 0.4118 - val_accuracy: 0.8825
Epoch 24/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1384 - accuracy: 0.9477 - val_loss: 0.3831 - val_accuracy: 0.8900
Epoch 25/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1340 - accuracy: 0.9499 - val_loss: 0.3846 - val_accuracy: 0.8949
Epoch 26/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1309 - accuracy: 0.9511 - val_loss: 0.3845 - val_accuracy: 0.8950
Epoch 27/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1268 - accuracy: 0.9521 - val_loss: 0.3950 - val_accuracy: 0.8967
Epoch 28/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1242 - accuracy: 0.9530 - val_loss: 0.3944 - val_accuracy: 0.8941
Epoch 29/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1196 - accuracy: 0.9549 - val_loss: 0.4600 - val_accuracy: 0.8852
Epoch 30/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1161 - accuracy: 0.9557 - val_loss: 0.4099 - val_accuracy: 0.8901
Epoch 31/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1137 - accuracy: 0.9572 - val_loss: 0.4176 - val_accuracy: 0.8960
Epoch 32/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1115 - accuracy: 0.9578 - val_loss: 0.4383 - val_accuracy: 0.8903
Epoch 33/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1063 - accuracy: 0.9603 - val_loss: 0.4602 - val_accuracy: 0.8877
Epoch 34/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1097 - accuracy: 0.9593 - val_loss: 0.4323 - val_accuracy: 0.8947
Epoch 35/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1048 - accuracy: 0.9601 - val_loss: 0.4370 - val_accuracy: 0.8925
Epoch 36/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.0995 - accuracy: 0.9624 - val_loss: 0.4267 - val_accuracy: 0.8924
Epoch 37/37 1500/1500 [==============================] - 5s 3ms/step - loss: 0.1005 - accuracy: 0.9619 - val_loss: 0.4724 - val_accuracy: 0.8898
<keras.src.callbacks.History at 0x7f3b94fc5dc0>
To finish this tutorial, evaluate the hypermodel on the test data.
eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
313/313 [==============================] - 1s 2ms/step - loss: 0.5181 - accuracy: 0.8854
[test loss, test accuracy]: [0.5180676579475403, 0.8853999972343445]
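The step counts shown in the logs (1500/1500 during training, 313/313 during evaluation) can be reproduced with a little arithmetic. The sketch below assumes the standard Fashion MNIST split of 60,000 training and 10,000 test images and the Keras default batch size of 32, since no batch_size is passed in this tutorial. The variable names are hypothetical.

```python
import math

TRAIN_IMAGES = 60_000   # assumption: standard Fashion MNIST training set size
TEST_IMAGES = 10_000    # assumption: standard Fashion MNIST test set size
BATCH_SIZE = 32         # assumption: Keras default, as batch_size is not passed
VALIDATION_SPLIT = 0.2  # value passed to fit() in this tutorial

# Images held out for validation, and images left for training.
validation_images = int(TRAIN_IMAGES * VALIDATION_SPLIT)
fit_images = TRAIN_IMAGES - validation_images

# One step processes one batch; a final partial batch still counts as a step.
train_steps = math.ceil(fit_images / BATCH_SIZE)
test_steps = math.ceil(TEST_IMAGES / BATCH_SIZE)

print(train_steps)
print(test_steps)
```

With these assumptions, 48,000 images remain for training, which gives 1500 steps per epoch, and the 10,000 test images give 313 evaluation steps.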
The my_dir/intro_to_kt directory contains detailed logs and checkpoints for every trial (model configuration) run during the hyperparameter search. If you re-run the hyperparameter search, the Keras Tuner uses the existing state from these logs to resume the search. To disable this behavior, pass an additional overwrite=True argument while instantiating the tuner.
Summary
In this tutorial, you learned how to use the Keras Tuner to tune hyperparameters for a model. To learn more about the Keras Tuner, see its official documentation.
Also check out the HParams Dashboard in TensorBoard to interactively tune your model hyperparameters.