Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, then trains the same model compiled with XLA and compares the training times.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.config.list_physical_devices('GPU')

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  # Scale pixel values to the [0, 1) range.
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 3s 0us/step
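
As a quick sanity check (this snippet is an addition, not part of the original notebook), the loaded arrays should have the standard CIFAR-10 shapes: 50,000 training and 10,000 test images of 32x32x3 pixels, with one-hot labels over 10 classes.

print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)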

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
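
Before compiling, it can be useful to verify the layer stack; calling model.summary() (an optional check added here, not in the original tutorial) prints each layer's output shape and parameter count:

model.summary()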

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs,
            validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT: we do not want to measure compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 2s 11ms/step - loss: 2.0193 - accuracy: 0.2559 - val_loss: 1.8110 - val_accuracy: 0.3626
Epoch 1/25
196/196 [==============================] - 2s 9ms/step - loss: 2.0840 - accuracy: 0.2270 - val_loss: 1.8265 - val_accuracy: 0.3573
Epoch 2/25
196/196 [==============================] - 2s 9ms/step - loss: 1.7767 - accuracy: 0.3575 - val_loss: 1.6495 - val_accuracy: 0.4136
Epoch 3/25
196/196 [==============================] - 2s 10ms/step - loss: 1.6641 - accuracy: 0.3971 - val_loss: 1.5735 - val_accuracy: 0.4327
Epoch 4/25
196/196 [==============================] - 2s 10ms/step - loss: 1.5903 - accuracy: 0.4249 - val_loss: 1.5449 - val_accuracy: 0.4448
Epoch 5/25
196/196 [==============================] - 2s 9ms/step - loss: 1.5285 - accuracy: 0.4479 - val_loss: 1.4528 - val_accuracy: 0.4774
Epoch 6/25
196/196 [==============================] - 2s 10ms/step - loss: 1.4749 - accuracy: 0.4652 - val_loss: 1.3693 - val_accuracy: 0.5074
Epoch 7/25
196/196 [==============================] - 2s 9ms/step - loss: 1.4283 - accuracy: 0.4853 - val_loss: 1.3747 - val_accuracy: 0.5166
Epoch 8/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3936 - accuracy: 0.5001 - val_loss: 1.3233 - val_accuracy: 0.5285
Epoch 9/25
196/196 [==============================] - 2s 10ms/step - loss: 1.3559 - accuracy: 0.5130 - val_loss: 1.3176 - val_accuracy: 0.5255
Epoch 10/25
196/196 [==============================] - 2s 10ms/step - loss: 1.3232 - accuracy: 0.5260 - val_loss: 1.2480 - val_accuracy: 0.5586
Epoch 11/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2973 - accuracy: 0.5380 - val_loss: 1.2204 - val_accuracy: 0.5689
Epoch 12/25
196/196 [==============================] - 2s 10ms/step - loss: 1.2711 - accuracy: 0.5465 - val_loss: 1.2408 - val_accuracy: 0.5650
Epoch 13/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2496 - accuracy: 0.5536 - val_loss: 1.2451 - val_accuracy: 0.5643
Epoch 14/25
196/196 [==============================] - 2s 10ms/step - loss: 1.2233 - accuracy: 0.5675 - val_loss: 1.2038 - val_accuracy: 0.5764
Epoch 15/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2014 - accuracy: 0.5734 - val_loss: 1.1627 - val_accuracy: 0.5937
Epoch 16/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1780 - accuracy: 0.5840 - val_loss: 1.1585 - val_accuracy: 0.5869
Epoch 17/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1595 - accuracy: 0.5911 - val_loss: 1.1312 - val_accuracy: 0.6018
Epoch 18/25
196/196 [==============================] - 2s 10ms/step - loss: 1.1393 - accuracy: 0.5978 - val_loss: 1.0902 - val_accuracy: 0.6195
Epoch 19/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1168 - accuracy: 0.6059 - val_loss: 1.0546 - val_accuracy: 0.6313
Epoch 20/25
196/196 [==============================] - 2s 10ms/step - loss: 1.1001 - accuracy: 0.6129 - val_loss: 1.0592 - val_accuracy: 0.6301
Epoch 21/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0825 - accuracy: 0.6183 - val_loss: 1.0206 - val_accuracy: 0.6450
Epoch 22/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0663 - accuracy: 0.6243 - val_loss: 1.0352 - val_accuracy: 0.6385
Epoch 23/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0490 - accuracy: 0.6297 - val_loss: 1.0142 - val_accuracy: 0.6472
Epoch 24/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0333 - accuracy: 0.6356 - val_loss: 0.9798 - val_accuracy: 0.6592
Epoch 25/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0183 - accuracy: 0.6430 - val_loss: 1.0495 - val_accuracy: 0.6347
CPU times: user 49.9 s, sys: 13 s, total: 1min 2s
Wall time: 47.8 s
313/313 [==============================] - 1s 2ms/step - loss: 1.0495 - accuracy: 0.6347
Test loss: 1.049464225769043
Test accuracy: 0.6347000002861023
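
Aside: the %time magic above only works in IPython/Colab. In a plain Python script, a minimal equivalent (an addition to the original notebook) is to time the call with time.perf_counter:

import time

start = time.perf_counter()  # Wall-clock start.
train_model(model, x_train, y_train, x_test, y_test)
print('Wall time: %.1f s' % (time.perf_counter() - start))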

Now let's train the model again, this time using the XLA compiler. To enable the compiler in the middle of the program, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 3s 14ms/step - loss: 2.0731 - accuracy: 0.2372 - val_loss: 1.8684 - val_accuracy: 0.3402
Epoch 1/25
196/196 [==============================] - 4s 20ms/step - loss: 2.1427 - accuracy: 0.2041 - val_loss: 1.8909 - val_accuracy: 0.3531
Epoch 2/25
196/196 [==============================] - 2s 9ms/step - loss: 1.8071 - accuracy: 0.3521 - val_loss: 1.6863 - val_accuracy: 0.4039
Epoch 3/25
196/196 [==============================] - 2s 9ms/step - loss: 1.6757 - accuracy: 0.3988 - val_loss: 1.5640 - val_accuracy: 0.4415
Epoch 4/25
196/196 [==============================] - 2s 9ms/step - loss: 1.5942 - accuracy: 0.4277 - val_loss: 1.4942 - val_accuracy: 0.4728
Epoch 5/25
196/196 [==============================] - 2s 9ms/step - loss: 1.5202 - accuracy: 0.4543 - val_loss: 1.5023 - val_accuracy: 0.4714
Epoch 6/25
196/196 [==============================] - 2s 9ms/step - loss: 1.4680 - accuracy: 0.4733 - val_loss: 1.3882 - val_accuracy: 0.5049
Epoch 7/25
196/196 [==============================] - 2s 9ms/step - loss: 1.4194 - accuracy: 0.4918 - val_loss: 1.3543 - val_accuracy: 0.5186
Epoch 8/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3859 - accuracy: 0.5044 - val_loss: 1.3078 - val_accuracy: 0.5310
Epoch 9/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3481 - accuracy: 0.5204 - val_loss: 1.2878 - val_accuracy: 0.5474
Epoch 10/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3190 - accuracy: 0.5306 - val_loss: 1.2528 - val_accuracy: 0.5604
Epoch 11/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2866 - accuracy: 0.5440 - val_loss: 1.2522 - val_accuracy: 0.5699
Epoch 12/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2591 - accuracy: 0.5533 - val_loss: 1.1970 - val_accuracy: 0.5863
Epoch 13/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2324 - accuracy: 0.5638 - val_loss: 1.1471 - val_accuracy: 0.5974
Epoch 14/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2113 - accuracy: 0.5722 - val_loss: 1.1189 - val_accuracy: 0.6052
Epoch 15/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1818 - accuracy: 0.5816 - val_loss: 1.1503 - val_accuracy: 0.5895
Epoch 16/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1627 - accuracy: 0.5906 - val_loss: 1.0991 - val_accuracy: 0.6144
Epoch 17/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1380 - accuracy: 0.6007 - val_loss: 1.0769 - val_accuracy: 0.6208
Epoch 18/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1201 - accuracy: 0.6058 - val_loss: 1.0661 - val_accuracy: 0.6291
Epoch 19/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1032 - accuracy: 0.6135 - val_loss: 1.0816 - val_accuracy: 0.6156
Epoch 20/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0831 - accuracy: 0.6202 - val_loss: 1.0150 - val_accuracy: 0.6443
Epoch 21/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0657 - accuracy: 0.6269 - val_loss: 1.0540 - val_accuracy: 0.6288
Epoch 22/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0475 - accuracy: 0.6328 - val_loss: 1.0075 - val_accuracy: 0.6454
Epoch 23/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0331 - accuracy: 0.6373 - val_loss: 0.9896 - val_accuracy: 0.6562
Epoch 24/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0168 - accuracy: 0.6448 - val_loss: 0.9541 - val_accuracy: 0.6706
Epoch 25/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0016 - accuracy: 0.6531 - val_loss: 1.0273 - val_accuracy: 0.6409
CPU times: user 47.5 s, sys: 9.32 s, total: 56.8 s
Wall time: 47.1 s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
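
Note that tf.config.optimizer.set_jit(True) enables XLA globally for the session. In recent TensorFlow releases, XLA can also be enabled for a single function through tf.function's jit_compile argument (earlier releases spelled it experimental_compile); a minimal sketch:

@tf.function(jit_compile=True)  # Ask XLA to compile this function.
def dense_layer(x, w, b):
  return tf.nn.relu(tf.matmul(x, w) + b)

The first call traces and compiles the function; subsequent calls with the same input shapes reuse the compiled executable.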