Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset and compiles the model using XLA.

You will load and normalize the dataset using the TensorFlow Datasets (TFDS) API. First, install/upgrade TensorFlow and TFDS:

pip install -U -q tensorflow tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name()
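
Alternatively, tf.config.list_physical_devices gives a more informative check (a supplementary snippet, not part of the original notebook):

gpus = tf.config.list_physical_devices('GPU')
print('GPUs available:', gpus)  # e.g. [PhysicalDevice(name='/physical_device:GPU:0', ...)]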

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  # batch_size=-1 loads each split as a single batch of tensors.
  result = tfds.load('cifar10', batch_size=-1)
  (x_train, y_train) = result['train']['image'], result['train']['label']
  (x_test, y_test) = result['test']['image'], result['test']['label']

  # Scale pixel values from [0, 255] into [0, 1).
  x_train = x_train.numpy().astype('float32') / 256
  x_test = x_test.numpy().astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
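
As a quick sanity check (supplementary, not in the original tutorial), you can verify the expected CIFAR-10 shapes: 50,000 training and 10,000 test images of 32x32x3 pixels, with one-hot labels:

# CIFAR-10: 50,000 train / 10,000 test images, 32x32 RGB, 10 classes.
print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)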

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
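
To inspect the layer stack and parameter counts before training, the standard Keras call (supplementary to the tutorial) is:

model.summary()  # Prints each layer's output shape and parameter count.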

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)
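
Because the labels were one-hot encoded in load_data, categorical_crossentropy is the matching loss. For reference, if you kept the integer labels from TFDS instead, the equivalent setup (a variant, not used in this tutorial) would be:

# Variant: with integer class labels (no to_categorical), use the sparse loss.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
              metrics=['accuracy'])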

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we don't want to measure compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
2024-02-03 12:09:26.468614: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1706962168.898568   11354 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
196/196 [==============================] - 10s 27ms/step - loss: 2.1226 - accuracy: 0.2118 - val_loss: 1.9594 - val_accuracy: 0.2991
Epoch 1/25
196/196 [==============================] - 4s 22ms/step - loss: 2.1502 - accuracy: 0.1988 - val_loss: 1.9611 - val_accuracy: 0.2884
Epoch 2/25
196/196 [==============================] - 4s 20ms/step - loss: 1.8868 - accuracy: 0.3186 - val_loss: 1.7768 - val_accuracy: 0.3798
Epoch 3/25
196/196 [==============================] - 4s 20ms/step - loss: 1.7488 - accuracy: 0.3703 - val_loss: 1.6610 - val_accuracy: 0.4168
Epoch 4/25
196/196 [==============================] - 4s 20ms/step - loss: 1.6683 - accuracy: 0.3992 - val_loss: 1.5988 - val_accuracy: 0.4304
Epoch 5/25
196/196 [==============================] - 4s 20ms/step - loss: 1.6053 - accuracy: 0.4227 - val_loss: 1.5062 - val_accuracy: 0.4607
Epoch 6/25
196/196 [==============================] - 4s 21ms/step - loss: 1.5541 - accuracy: 0.4422 - val_loss: 1.5750 - val_accuracy: 0.4431
Epoch 7/25
196/196 [==============================] - 4s 21ms/step - loss: 1.5025 - accuracy: 0.4623 - val_loss: 1.4528 - val_accuracy: 0.4763
Epoch 8/25
196/196 [==============================] - 4s 21ms/step - loss: 1.4576 - accuracy: 0.4784 - val_loss: 1.3644 - val_accuracy: 0.5173
Epoch 9/25
196/196 [==============================] - 4s 21ms/step - loss: 1.4211 - accuracy: 0.4912 - val_loss: 1.3749 - val_accuracy: 0.5028
Epoch 10/25
196/196 [==============================] - 4s 21ms/step - loss: 1.3876 - accuracy: 0.5028 - val_loss: 1.3015 - val_accuracy: 0.5364
Epoch 11/25
196/196 [==============================] - 4s 21ms/step - loss: 1.3526 - accuracy: 0.5168 - val_loss: 1.2784 - val_accuracy: 0.5530
Epoch 12/25
196/196 [==============================] - 4s 21ms/step - loss: 1.3231 - accuracy: 0.5252 - val_loss: 1.2540 - val_accuracy: 0.5572
Epoch 13/25
196/196 [==============================] - 4s 21ms/step - loss: 1.2983 - accuracy: 0.5370 - val_loss: 1.2257 - val_accuracy: 0.5723
Epoch 14/25
196/196 [==============================] - 4s 21ms/step - loss: 1.2715 - accuracy: 0.5491 - val_loss: 1.2324 - val_accuracy: 0.5622
Epoch 15/25
196/196 [==============================] - 4s 21ms/step - loss: 1.2420 - accuracy: 0.5603 - val_loss: 1.1655 - val_accuracy: 0.5888
Epoch 16/25
196/196 [==============================] - 4s 21ms/step - loss: 1.2230 - accuracy: 0.5673 - val_loss: 1.1756 - val_accuracy: 0.5814
Epoch 17/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1984 - accuracy: 0.5773 - val_loss: 1.1747 - val_accuracy: 0.5832
Epoch 18/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1712 - accuracy: 0.5880 - val_loss: 1.1315 - val_accuracy: 0.5977
Epoch 19/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1567 - accuracy: 0.5921 - val_loss: 1.1624 - val_accuracy: 0.5857
Epoch 20/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1348 - accuracy: 0.6004 - val_loss: 1.1426 - val_accuracy: 0.5957
Epoch 21/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1175 - accuracy: 0.6066 - val_loss: 1.1944 - val_accuracy: 0.5799
Epoch 22/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1007 - accuracy: 0.6111 - val_loss: 1.0837 - val_accuracy: 0.6237
Epoch 23/25
196/196 [==============================] - 4s 21ms/step - loss: 1.0834 - accuracy: 0.6183 - val_loss: 1.0232 - val_accuracy: 0.6417
Epoch 24/25
196/196 [==============================] - 4s 21ms/step - loss: 1.0661 - accuracy: 0.6264 - val_loss: 1.0622 - val_accuracy: 0.6340
Epoch 25/25
196/196 [==============================] - 4s 21ms/step - loss: 1.0513 - accuracy: 0.6337 - val_loss: 1.0031 - val_accuracy: 0.6522
CPU times: user 1min 28s, sys: 7.5 s, total: 1min 36s
Wall time: 1min 43s
313/313 [==============================] - 1s 2ms/step - loss: 1.0031 - accuracy: 0.6522
Test loss: 1.0031086206436157
Test accuracy: 0.6521999835968018

Now let's train the model again, this time using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
2024-02-03 12:11:27.684163: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
196/196 [==============================] - ETA: 0s - loss: 2.0665 - accuracy: 0.2299
W0000 00:00:1706962297.665462   11344 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
196/196 [==============================] - 11s 26ms/step - loss: 2.0665 - accuracy: 0.2299 - val_loss: 1.8395 - val_accuracy: 0.3634
Epoch 1/25
196/196 [==============================] - 8s 41ms/step - loss: 2.1082 - accuracy: 0.2092 - val_loss: 1.8596 - val_accuracy: 0.3510
Epoch 2/25
 10/196 [>.............................] - ETA: 3s - loss: 1.8823 - accuracy: 0.3086
W0000 00:00:1706962307.649981   11347 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update
196/196 [==============================] - 4s 18ms/step - loss: 1.8077 - accuracy: 0.3446 - val_loss: 1.8004 - val_accuracy: 0.3508
Epoch 3/25
196/196 [==============================] - 4s 18ms/step - loss: 1.6915 - accuracy: 0.3898 - val_loss: 1.5908 - val_accuracy: 0.4312
Epoch 4/25
196/196 [==============================] - 4s 18ms/step - loss: 1.6194 - accuracy: 0.4138 - val_loss: 1.5250 - val_accuracy: 0.4498
Epoch 5/25
196/196 [==============================] - 4s 19ms/step - loss: 1.5676 - accuracy: 0.4337 - val_loss: 1.4679 - val_accuracy: 0.4681
Epoch 6/25
196/196 [==============================] - 4s 18ms/step - loss: 1.5195 - accuracy: 0.4513 - val_loss: 1.4242 - val_accuracy: 0.4834
Epoch 7/25
196/196 [==============================] - 4s 18ms/step - loss: 1.4791 - accuracy: 0.4650 - val_loss: 1.4116 - val_accuracy: 0.4899
Epoch 8/25
196/196 [==============================] - 4s 18ms/step - loss: 1.4394 - accuracy: 0.4815 - val_loss: 1.3532 - val_accuracy: 0.5117
Epoch 9/25
196/196 [==============================] - 4s 18ms/step - loss: 1.4037 - accuracy: 0.4945 - val_loss: 1.3709 - val_accuracy: 0.5084
Epoch 10/25
196/196 [==============================] - 4s 18ms/step - loss: 1.3781 - accuracy: 0.5076 - val_loss: 1.3205 - val_accuracy: 0.5250
Epoch 11/25
196/196 [==============================] - 4s 18ms/step - loss: 1.3456 - accuracy: 0.5163 - val_loss: 1.3173 - val_accuracy: 0.5320
Epoch 12/25
196/196 [==============================] - 4s 18ms/step - loss: 1.3203 - accuracy: 0.5282 - val_loss: 1.2930 - val_accuracy: 0.5412
Epoch 13/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2967 - accuracy: 0.5367 - val_loss: 1.2573 - val_accuracy: 0.5501
Epoch 14/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2742 - accuracy: 0.5455 - val_loss: 1.1872 - val_accuracy: 0.5811
Epoch 15/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2500 - accuracy: 0.5586 - val_loss: 1.1979 - val_accuracy: 0.5739
Epoch 16/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2302 - accuracy: 0.5633 - val_loss: 1.1803 - val_accuracy: 0.5804
Epoch 17/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2060 - accuracy: 0.5733 - val_loss: 1.1340 - val_accuracy: 0.6020
Epoch 18/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1860 - accuracy: 0.5840 - val_loss: 1.1127 - val_accuracy: 0.6107
Epoch 19/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1683 - accuracy: 0.5890 - val_loss: 1.0988 - val_accuracy: 0.6149
Epoch 20/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1488 - accuracy: 0.5953 - val_loss: 1.1171 - val_accuracy: 0.6107
Epoch 21/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1294 - accuracy: 0.6020 - val_loss: 1.0697 - val_accuracy: 0.6267
Epoch 22/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1118 - accuracy: 0.6109 - val_loss: 1.0354 - val_accuracy: 0.6417
Epoch 23/25
196/196 [==============================] - 4s 18ms/step - loss: 1.0981 - accuracy: 0.6128 - val_loss: 1.0400 - val_accuracy: 0.6375
Epoch 24/25
196/196 [==============================] - 4s 18ms/step - loss: 1.0794 - accuracy: 0.6191 - val_loss: 1.0348 - val_accuracy: 0.6398
Epoch 25/25
196/196 [==============================] - 4s 18ms/step - loss: 1.0645 - accuracy: 0.6256 - val_loss: 1.0463 - val_accuracy: 0.6402
CPU times: user 43.4 s, sys: 8.37 s, total: 51.8 s
Wall time: 1min 36s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
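
Instead of the global tf.config.optimizer.set_jit flag, XLA can also be enabled per function with tf.function(jit_compile=True). A minimal sketch (the function name and inputs here are illustrative, not from the tutorial):

@tf.function(jit_compile=True)
def dense_relu(x, w, b):
  # XLA compiles this function on its first call for a given input
  # signature; subsequent calls reuse the compiled executable.
  return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal([256, 512])
w = tf.random.normal([512, 10])
b = tf.zeros([10])
print(dense_relu(x, w, b).shape)  # (256, 10)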