
Introduction to the Keras Tuner


Overview

The Keras Tuner is a library that helps you pick the optimal set of hyperparameters for your TensorFlow program. The process of selecting the right set of hyperparameters for your machine learning (ML) application is called hyperparameter tuning or hypertuning.

Hyperparameters are the variables that govern the training process and the topology of an ML model. These variables remain constant over the training process and directly impact the performance of your ML program. Hyperparameters are of two types:

  1. Model hyperparameters, which influence model selection, such as the number and width of hidden layers
  2. Algorithm hyperparameters, which influence the speed and quality of the learning algorithm, such as the learning rate for Stochastic Gradient Descent (SGD) and the number of nearest neighbors for a k-Nearest Neighbors (KNN) classifier

In this tutorial, you will use the Keras Tuner to perform hypertuning for an image classification application.

Setup

import tensorflow as tf
from tensorflow import keras

Install and import the Keras Tuner.

pip install -q -U keras-tuner
import keras_tuner as kt

Download and prepare the dataset

In this tutorial, you will use the Keras Tuner to find the best hyperparameters for a machine learning model that classifies images of clothing from the Fashion MNIST dataset.

Load the data.

(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0

Define the model

When you build a model for hypertuning, you define the hyperparameter search space in addition to the model architecture. The model you set up for hypertuning is called a hypermodel.

You can define a hypermodel through two approaches:

  • By using a model builder function
  • By subclassing the HyperModel class of the Keras Tuner API

You can also use two predefined HyperModel classes, HyperXception and HyperResNet, for computer vision applications.

In this tutorial, you use a model builder function to define the image classification model. The model builder function returns a compiled model and uses hyperparameters you define inline to hypertune the model.

def model_builder(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32 and 512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model
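Because every hyperparameter in model_builder is discrete, you can count the search space directly. The following sketch uses only the Python standard library and copies the ranges from model_builder by hand, without calling the Keras Tuner API. It shows that an exhaustive grid search would need 48 full training runs, which is the cost that tuners such as Hyperband try to cut down.

```python
import itertools

# Hand-copied from model_builder: hp.Int('units', 32, 512, step=32) includes
# both endpoints, and hp.Choice('learning_rate', ...) has three options.
units_values = list(range(32, 512 + 1, 32))
learning_rate_values = [1e-2, 1e-3, 1e-4]

# Every (units, learning_rate) pair is one candidate configuration.
grid = list(itertools.product(units_values, learning_rate_values))
print(len(units_values), len(grid))  # 16 48
```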

Instantiate the tuner and perform hypertuning

Instantiate the tuner to perform the hypertuning. The Keras Tuner has four tuners available: RandomSearch, Hyperband, BayesianOptimization, and Sklearn. In this tutorial, you use the Hyperband tuner.

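To instantiate the Hyperband tuner, you must specify the hypermodel, the objective to optimize, and the maximum number of epochs to train (max_epochs). The example below also sets the reduction factor (factor), along with the directory and project_name under which the tuner saves its results.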

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory='my_dir',
                     project_name='intro_to_kt')

The Hyperband tuning algorithm uses adaptive resource allocation and early stopping to quickly converge on a high-performing model. It does this with a sports-championship-style bracket: the algorithm trains a large number of models for a few epochs and carries forward only the top-performing fraction of them (roughly 1/factor) to the next round. Hyperband determines the number of models to train in a bracket by computing 1 + log_factor(max_epochs) and rounding it up to the nearest integer.
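To make the bracket arithmetic concrete, here is a toy sketch of a single successive-halving bracket, the idea Hyperband builds on. In each round, it keeps the top 1/factor of candidates (a third when factor=3) and multiplies the epoch budget by factor. This is not Keras Tuner's implementation, which runs several brackets and resumes training from checkpoints. The function names and the evaluate(candidate, epochs) callback are hypothetical; the callback stands in for training a model for that many epochs and returning its validation accuracy.

```python
import math


def bracket_size(max_epochs, factor):
    """ceil(1 + log_factor(max_epochs)), the quantity described above."""
    return math.ceil(1 + math.log(max_epochs, factor))


def successive_halving(candidates, evaluate, factor=3, min_epochs=1, max_epochs=10):
    """Toy single bracket: score every survivor at the current epoch budget,
    keep the best 1/factor of them, then multiply the budget by factor."""
    survivors = list(candidates)
    epochs = min_epochs
    while len(survivors) > 1 and epochs <= max_epochs:
        ranked = sorted(survivors, key=lambda c: evaluate(c, epochs), reverse=True)
        survivors = ranked[:max(1, len(ranked) // factor)]
        epochs *= factor
    return survivors


print(bracket_size(max_epochs=10, factor=3))               # 4
# Toy scorer: a candidate's "accuracy" is just its own value.
print(successive_halving(range(27), lambda c, epochs: c))  # [26]
```

With 27 candidates and factor=3, the bracket keeps 9, then 3, then 1 candidate while the epoch budget grows from 1 to 3 to 9. Most of the compute goes to the few configurations that survive the early rounds.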

Create a callback that stops training early once the validation loss has stopped improving for a set number of epochs (patience=5 here).

stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
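The patience rule can be modeled in a few lines. This is a simplified sketch that assumes EarlyStopping's defaults apart from patience: it minimizes the monitored loss and uses min_delta=0. The early_stopping_epoch helper is hypothetical and is not part of Keras. The sample values are validation losses copied from this tutorial's 50-epoch training log.

```python
def early_stopping_epoch(val_losses, patience=5):
    """Return the 1-based epoch after which training would stop, or None.

    Simplified model of EarlyStopping(monitor='val_loss', patience=patience)
    with min_delta=0: stop once `patience` consecutive epochs pass without
    a new lowest validation loss.
    """
    best = float('inf')
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# val_loss for epochs 1-14 of the 50-epoch training log in this tutorial
logged_val_loss = [0.3984, 0.3585, 0.3729, 0.3604, 0.3466, 0.3207, 0.3203,
                   0.3094, 0.3101, 0.3281, 0.3210, 0.3219, 0.3163, 0.3533]
print(early_stopping_epoch(logged_val_loss, patience=5))  # 13
```

The lowest loss occurs at epoch 8, so training would stop after epoch 13, five epochs later. By default, EarlyStopping does not roll the weights back to the best epoch (restore_best_weights=False).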

Run the hyperparameter search. The search method takes the same arguments as tf.keras.Model.fit; here you also pass the early-stopping callback defined above.

tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])

# Get the optimal hyperparameters
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
Trial 30 Complete [00h 00m 39s]
val_accuracy: 0.8883333206176758

Best val_accuracy So Far: 0.8901666402816772
Total elapsed time: 00h 08m 13s
INFO:tensorflow:Oracle triggered exit

The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is 288 and the optimal learning rate for the optimizer
is 0.001.

Train the model

Find the optimal number of epochs to train the model with the hyperparameters obtained from the search.

# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)

val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
Epoch 1/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.5038 - accuracy: 0.8215 - val_loss: 0.3984 - val_accuracy: 0.8597
Epoch 2/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3729 - accuracy: 0.8632 - val_loss: 0.3585 - val_accuracy: 0.8723
Epoch 3/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3323 - accuracy: 0.8773 - val_loss: 0.3729 - val_accuracy: 0.8637
Epoch 4/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3079 - accuracy: 0.8869 - val_loss: 0.3604 - val_accuracy: 0.8698
Epoch 5/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2919 - accuracy: 0.8911 - val_loss: 0.3466 - val_accuracy: 0.8786
Epoch 6/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2741 - accuracy: 0.8976 - val_loss: 0.3207 - val_accuracy: 0.8847
Epoch 7/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2589 - accuracy: 0.9040 - val_loss: 0.3203 - val_accuracy: 0.8887
Epoch 8/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2474 - accuracy: 0.9085 - val_loss: 0.3094 - val_accuracy: 0.8891
Epoch 9/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2372 - accuracy: 0.9120 - val_loss: 0.3101 - val_accuracy: 0.8918
Epoch 10/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2288 - accuracy: 0.9150 - val_loss: 0.3281 - val_accuracy: 0.8862
Epoch 11/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2190 - accuracy: 0.9186 - val_loss: 0.3210 - val_accuracy: 0.8892
Epoch 12/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2116 - accuracy: 0.9202 - val_loss: 0.3219 - val_accuracy: 0.8864
Epoch 13/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2011 - accuracy: 0.9260 - val_loss: 0.3163 - val_accuracy: 0.8939
Epoch 14/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1969 - accuracy: 0.9264 - val_loss: 0.3533 - val_accuracy: 0.8886
Epoch 15/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1900 - accuracy: 0.9289 - val_loss: 0.3135 - val_accuracy: 0.8970
Epoch 16/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1857 - accuracy: 0.9299 - val_loss: 0.3227 - val_accuracy: 0.8937
Epoch 17/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1787 - accuracy: 0.9343 - val_loss: 0.3268 - val_accuracy: 0.8932
Epoch 18/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1722 - accuracy: 0.9360 - val_loss: 0.3658 - val_accuracy: 0.8832
Epoch 19/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1670 - accuracy: 0.9374 - val_loss: 0.3479 - val_accuracy: 0.8947
Epoch 20/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1615 - accuracy: 0.9396 - val_loss: 0.3638 - val_accuracy: 0.8888
Epoch 21/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1571 - accuracy: 0.9413 - val_loss: 0.3411 - val_accuracy: 0.8980
Epoch 22/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1535 - accuracy: 0.9413 - val_loss: 0.3458 - val_accuracy: 0.8933
Epoch 23/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1492 - accuracy: 0.9457 - val_loss: 0.3653 - val_accuracy: 0.8919
Epoch 24/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1428 - accuracy: 0.9454 - val_loss: 0.4007 - val_accuracy: 0.8816
Epoch 25/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1411 - accuracy: 0.9469 - val_loss: 0.3806 - val_accuracy: 0.8917
Epoch 26/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1368 - accuracy: 0.9488 - val_loss: 0.3644 - val_accuracy: 0.8955
Epoch 27/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1308 - accuracy: 0.9508 - val_loss: 0.3870 - val_accuracy: 0.8889
Epoch 28/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1293 - accuracy: 0.9519 - val_loss: 0.3786 - val_accuracy: 0.8949
Epoch 29/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1246 - accuracy: 0.9532 - val_loss: 0.4124 - val_accuracy: 0.8873
Epoch 30/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1246 - accuracy: 0.9535 - val_loss: 0.3897 - val_accuracy: 0.8950
Epoch 31/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1191 - accuracy: 0.9550 - val_loss: 0.4175 - val_accuracy: 0.8882
Epoch 32/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1171 - accuracy: 0.9565 - val_loss: 0.4166 - val_accuracy: 0.8927
Epoch 33/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1120 - accuracy: 0.9585 - val_loss: 0.4328 - val_accuracy: 0.8921
Epoch 34/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1095 - accuracy: 0.9599 - val_loss: 0.4163 - val_accuracy: 0.8942
Epoch 35/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1079 - accuracy: 0.9594 - val_loss: 0.4387 - val_accuracy: 0.8890
Epoch 36/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1056 - accuracy: 0.9600 - val_loss: 0.4339 - val_accuracy: 0.8924
Epoch 37/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1021 - accuracy: 0.9621 - val_loss: 0.4474 - val_accuracy: 0.8911
Epoch 38/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0980 - accuracy: 0.9630 - val_loss: 0.4563 - val_accuracy: 0.8892
Epoch 39/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0996 - accuracy: 0.9626 - val_loss: 0.4329 - val_accuracy: 0.8937
Epoch 40/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0962 - accuracy: 0.9644 - val_loss: 0.4605 - val_accuracy: 0.8925
Epoch 41/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0929 - accuracy: 0.9648 - val_loss: 0.4717 - val_accuracy: 0.8913
Epoch 42/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0949 - accuracy: 0.9644 - val_loss: 0.4650 - val_accuracy: 0.8888
Epoch 43/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0861 - accuracy: 0.9679 - val_loss: 0.4899 - val_accuracy: 0.8948
Epoch 44/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0882 - accuracy: 0.9665 - val_loss: 0.4779 - val_accuracy: 0.8929
Epoch 45/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0870 - accuracy: 0.9666 - val_loss: 0.4873 - val_accuracy: 0.8878
Epoch 46/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0842 - accuracy: 0.9685 - val_loss: 0.4929 - val_accuracy: 0.8953
Epoch 47/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0829 - accuracy: 0.9690 - val_loss: 0.5250 - val_accuracy: 0.8905
Epoch 48/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0796 - accuracy: 0.9705 - val_loss: 0.4885 - val_accuracy: 0.8967
Epoch 49/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0781 - accuracy: 0.9703 - val_loss: 0.5235 - val_accuracy: 0.8878
Epoch 50/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0786 - accuracy: 0.9696 - val_loss: 0.5510 - val_accuracy: 0.8874
Best epoch: 21
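The epoch selection above is an argmax over the per-epoch metric, plus one because Keras counts epochs from 1. The sketch below applies it to the first 21 epochs of the log above through a hypothetical best_epoch helper. It then contrasts that choice with picking the epoch that has the lowest validation loss.

```python
def best_epoch(values, mode='max'):
    """1-based epoch of the best value; mode='max' for accuracy, 'min' for loss.

    list.index returns the first match, so ties go to the earliest epoch.
    """
    pick = max if mode == 'max' else min
    return values.index(pick(values)) + 1


# Epochs 1-21 of the 50-epoch training log above
val_accuracy = [0.8597, 0.8723, 0.8637, 0.8698, 0.8786, 0.8847, 0.8887,
                0.8891, 0.8918, 0.8862, 0.8892, 0.8864, 0.8939, 0.8886,
                0.8970, 0.8937, 0.8932, 0.8832, 0.8947, 0.8888, 0.8980]
val_loss = [0.3984, 0.3585, 0.3729, 0.3604, 0.3466, 0.3207, 0.3203,
            0.3094, 0.3101, 0.3281, 0.3210, 0.3219, 0.3163, 0.3533,
            0.3135, 0.3227, 0.3268, 0.3658, 0.3479, 0.3638, 0.3411]

print(best_epoch(val_accuracy, mode='max'))  # 21
print(best_epoch(val_loss, mode='min'))      # 8
```

Validation loss bottoms out at epoch 8 and generally trends upward afterward, while validation accuracy peaks at epoch 21, so the two criteria disagree here. This tutorial selects by val_accuracy, the same metric the tuner used as its objective.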

Re-instantiate the hypermodel and train it with the optimal number of epochs from above.

hypermodel = tuner.hypermodel.build(best_hps)

# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
Epoch 1/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.4994 - accuracy: 0.8230 - val_loss: 0.4522 - val_accuracy: 0.8298
Epoch 2/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3687 - accuracy: 0.8662 - val_loss: 0.3484 - val_accuracy: 0.8746
Epoch 3/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3322 - accuracy: 0.8766 - val_loss: 0.3518 - val_accuracy: 0.8705
Epoch 4/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3064 - accuracy: 0.8869 - val_loss: 0.3297 - val_accuracy: 0.8784
Epoch 5/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2864 - accuracy: 0.8929 - val_loss: 0.3257 - val_accuracy: 0.8808
Epoch 6/21
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2745 - accuracy: 0.8981 - val_loss: 0.3108 - val_accuracy: 0.8867
Epoch 7/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2596 - accuracy: 0.9034 - val_loss: 0.3166 - val_accuracy: 0.8898
Epoch 8/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2476 - accuracy: 0.9081 - val_loss: 0.3186 - val_accuracy: 0.8896
Epoch 9/21
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2375 - accuracy: 0.9105 - val_loss: 0.3137 - val_accuracy: 0.8887
Epoch 10/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2298 - accuracy: 0.9143 - val_loss: 0.3172 - val_accuracy: 0.8888
Epoch 11/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2186 - accuracy: 0.9183 - val_loss: 0.3237 - val_accuracy: 0.8914
Epoch 12/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2119 - accuracy: 0.9206 - val_loss: 0.3147 - val_accuracy: 0.8894
Epoch 13/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2026 - accuracy: 0.9251 - val_loss: 0.3270 - val_accuracy: 0.8925
Epoch 14/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1957 - accuracy: 0.9265 - val_loss: 0.3337 - val_accuracy: 0.8857
Epoch 15/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1897 - accuracy: 0.9284 - val_loss: 0.3502 - val_accuracy: 0.8862
Epoch 16/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1825 - accuracy: 0.9317 - val_loss: 0.3265 - val_accuracy: 0.8938
Epoch 17/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1760 - accuracy: 0.9333 - val_loss: 0.3308 - val_accuracy: 0.8932
Epoch 18/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1692 - accuracy: 0.9367 - val_loss: 0.3246 - val_accuracy: 0.8969
Epoch 19/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1648 - accuracy: 0.9386 - val_loss: 0.3394 - val_accuracy: 0.8949
Epoch 20/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1598 - accuracy: 0.9393 - val_loss: 0.3543 - val_accuracy: 0.8926
Epoch 21/21
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1522 - accuracy: 0.9421 - val_loss: 0.3691 - val_accuracy: 0.8895

To finish this tutorial, evaluate the hypermodel on the test data.

eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
313/313 [==============================] - 1s 2ms/step - loss: 0.4042 - accuracy: 0.8819
[test loss, test accuracy]: [0.4041593372821808, 0.8819000124931335]
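The step counts in the progress bars (1500/1500 while training, 313/313 during evaluation) follow from validation_split=0.2 and the default batch_size of 32 used by Model.fit and Model.evaluate. Keras holds out the last fraction of the samples, before any shuffling, as validation data. The sketch below reproduces that arithmetic with the standard library. The split_sizes and steps_per_epoch helpers are hypothetical; they approximate Keras's splitting logic rather than call it.

```python
import math


def split_sizes(num_samples, validation_split):
    """Approximate (train, validation) sizes for fit(validation_split=...)."""
    num_train = math.floor(num_samples * (1.0 - validation_split))
    return num_train, num_samples - num_train


def steps_per_epoch(num_samples, batch_size=32):
    """Number of batches per pass; the final batch may be partial."""
    return math.ceil(num_samples / batch_size)


# Fashion MNIST has 60,000 training images and 10,000 test images.
num_train, num_val = split_sizes(60_000, 0.2)
print(num_train, num_val)          # 48000 12000
print(steps_per_epoch(num_train))  # 1500
print(steps_per_epoch(10_000))     # 313
```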

The my_dir/intro_to_kt directory contains detailed logs and checkpoints for every trial (model configuration) run during the hyperparameter search. If you re-run the hyperparameter search, the Keras Tuner uses the existing state from these logs to resume the search. To disable this behavior, pass the additional argument overwrite=True when instantiating the tuner.

Summary

In this tutorial, you learned how to use the Keras Tuner to tune hyperparameters for a model. To learn more about the Keras Tuner, check out these additional resources:

Also check out the HParams Dashboard in TensorBoard to interactively tune your model hyperparameters.