Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, then compiles the same model with XLA to compare training speed.

You will load and normalize the dataset using the TensorFlow Datasets (TFDS) API. First, install/upgrade TensorFlow and TFDS:

pip install -U -q tensorflow tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name()

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.
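Besides the global session flag used here, XLA can also be requested for a single function with `tf.function(jit_compile=True)` (an aside, not part of this tutorial's flow); a minimal sketch:

```python
import tensorflow as tf

# jit_compile=True asks XLA to compile this function's graph.
@tf.function(jit_compile=True)
def scaled_sum(x, y):
  return tf.reduce_sum(x * 2.0 + y)

result = float(scaled_sum(tf.ones((4,)), tf.ones((4,))))
print(result)  # 12.0
```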

def load_data():
  result = tfds.load('cifar10', batch_size=-1)
  (x_train, y_train) = result['train']['image'],result['train']['label']
  (x_test, y_test) = result['test']['image'],result['test']['label']

  x_train = x_train.numpy().astype('float32') / 256
  x_test = x_test.numpy().astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
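As a quick illustration of the `to_categorical` step above (a toy example, separate from the dataset pipeline):

```python
import numpy as np
import tensorflow as tf

# Integer class labels become one-hot rows:
labels = np.array([0, 2, 1])
one_hot = tf.keras.utils.to_categorical(labels, num_classes=3)
print(one_hot)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```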

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
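A dummy batch can sanity-check the architecture's output shape. The sketch below uses an abbreviated version of the layer stack (assuming CIFAR-10's 32x32x3 inputs), just to confirm the model ends in a per-image probability row:

```python
import numpy as np
import tensorflow as tf

# Abbreviated stack; the full model above follows the same Conv -> Pool ->
# Flatten -> Dense -> softmax pattern.
demo = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=(32, 32, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax'),
])
probs = demo(np.zeros((4, 32, 32, 3), dtype='float32')).numpy()
print(probs.shape)  # (4, 10): one probability row per image
```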

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)
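The `categorical_crossentropy` loss compiled in above compares one-hot targets against predicted probabilities; a toy example:

```python
import math
import tensorflow as tf

# The loss is -log of the probability assigned to the true class.
y_true = tf.constant([[0., 1., 0.]])
y_pred = tf.constant([[0.1, 0.8, 0.1]])
loss = float(tf.keras.losses.categorical_crossentropy(y_true, y_pred)[0])
print(loss)  # ~0.223, i.e. -log(0.8)
```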

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs,
            validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we do not want to measure compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)
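The warm-up trick relies on `get_weights`/`set_weights` round-tripping the parameters exactly, so the timed run starts from the same initial state; a minimal sketch with a hypothetical tiny model:

```python
import numpy as np
import tensorflow as tf

# Snapshot, perturb, restore: the weights round-trip exactly.
m = tf.keras.Sequential([tf.keras.layers.Dense(2)])
m.build((None, 3))
initial = m.get_weights()
m.set_weights([w + 1.0 for w in initial])  # pretend training changed them
m.set_weights(initial)                     # restore the snapshot
restored = m.get_weights()
ok = all(np.array_equal(a, b) for a, b in zip(initial, restored))
print(ok)  # True
```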

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 9s 16ms/step - loss: 2.1645 - accuracy: 0.1917 - val_loss: 1.9695 - val_accuracy: 0.2969
Epoch 1/25
196/196 [==============================] - 3s 15ms/step - loss: 2.1978 - accuracy: 0.1731 - val_loss: 1.9897 - val_accuracy: 0.2836
Epoch 2/25
196/196 [==============================] - 3s 14ms/step - loss: 1.9248 - accuracy: 0.2987 - val_loss: 1.7995 - val_accuracy: 0.3703
Epoch 3/25
196/196 [==============================] - 3s 14ms/step - loss: 1.7647 - accuracy: 0.3615 - val_loss: 1.6578 - val_accuracy: 0.4068
Epoch 4/25
196/196 [==============================] - 3s 14ms/step - loss: 1.6663 - accuracy: 0.3971 - val_loss: 1.5586 - val_accuracy: 0.4476
Epoch 5/25
196/196 [==============================] - 3s 14ms/step - loss: 1.5925 - accuracy: 0.4213 - val_loss: 1.5188 - val_accuracy: 0.4545
Epoch 6/25
196/196 [==============================] - 3s 14ms/step - loss: 1.5364 - accuracy: 0.4425 - val_loss: 1.4577 - val_accuracy: 0.4715
Epoch 7/25
196/196 [==============================] - 3s 14ms/step - loss: 1.4824 - accuracy: 0.4637 - val_loss: 1.4128 - val_accuracy: 0.4902
Epoch 8/25
196/196 [==============================] - 3s 14ms/step - loss: 1.4393 - accuracy: 0.4805 - val_loss: 1.3698 - val_accuracy: 0.5051
Epoch 9/25
196/196 [==============================] - 3s 14ms/step - loss: 1.4070 - accuracy: 0.4954 - val_loss: 1.4210 - val_accuracy: 0.4957
Epoch 10/25
196/196 [==============================] - 3s 14ms/step - loss: 1.3745 - accuracy: 0.5083 - val_loss: 1.2968 - val_accuracy: 0.5332
Epoch 11/25
196/196 [==============================] - 3s 14ms/step - loss: 1.3462 - accuracy: 0.5167 - val_loss: 1.2945 - val_accuracy: 0.5352
Epoch 12/25
196/196 [==============================] - 3s 14ms/step - loss: 1.3152 - accuracy: 0.5308 - val_loss: 1.2307 - val_accuracy: 0.5668
Epoch 13/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2906 - accuracy: 0.5405 - val_loss: 1.2417 - val_accuracy: 0.5639
Epoch 14/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2654 - accuracy: 0.5512 - val_loss: 1.2161 - val_accuracy: 0.5693
Epoch 15/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2450 - accuracy: 0.5566 - val_loss: 1.1675 - val_accuracy: 0.5898
Epoch 16/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2223 - accuracy: 0.5665 - val_loss: 1.2303 - val_accuracy: 0.5673
Epoch 17/25
196/196 [==============================] - 3s 14ms/step - loss: 1.2006 - accuracy: 0.5772 - val_loss: 1.1869 - val_accuracy: 0.5848
Epoch 18/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1762 - accuracy: 0.5839 - val_loss: 1.1188 - val_accuracy: 0.6013
Epoch 19/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1646 - accuracy: 0.5894 - val_loss: 1.0905 - val_accuracy: 0.6201
Epoch 20/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1462 - accuracy: 0.5964 - val_loss: 1.0872 - val_accuracy: 0.6234
Epoch 21/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1294 - accuracy: 0.6040 - val_loss: 1.1073 - val_accuracy: 0.6132
Epoch 22/25
196/196 [==============================] - 3s 14ms/step - loss: 1.1062 - accuracy: 0.6086 - val_loss: 1.0405 - val_accuracy: 0.6390
Epoch 23/25
196/196 [==============================] - 3s 14ms/step - loss: 1.0887 - accuracy: 0.6185 - val_loss: 1.0190 - val_accuracy: 0.6468
Epoch 24/25
196/196 [==============================] - 3s 14ms/step - loss: 1.0731 - accuracy: 0.6224 - val_loss: 1.0398 - val_accuracy: 0.6400
Epoch 25/25
196/196 [==============================] - 3s 14ms/step - loss: 1.0551 - accuracy: 0.6291 - val_loss: 1.0322 - val_accuracy: 0.6400
CPU times: user 1min 13s, sys: 8.08 s, total: 1min 21s
Wall time: 1min 9s
313/313 [==============================] - 1s 3ms/step - loss: 1.0322 - accuracy: 0.6400
Test loss: 1.0322308540344238
Test accuracy: 0.6399999856948853

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 7s 18ms/step - loss: 2.1110 - accuracy: 0.2162 - val_loss: 1.8555 - val_accuracy: 0.3400
Epoch 1/25
196/196 [==============================] - 5s 24ms/step - loss: 2.1736 - accuracy: 0.1900 - val_loss: 1.9028 - val_accuracy: 0.3324
Epoch 2/25
196/196 [==============================] - 2s 12ms/step - loss: 1.8217 - accuracy: 0.3384 - val_loss: 1.7167 - val_accuracy: 0.3816
Epoch 3/25
196/196 [==============================] - 2s 12ms/step - loss: 1.7032 - accuracy: 0.3824 - val_loss: 1.6026 - val_accuracy: 0.4151
Epoch 4/25
196/196 [==============================] - 2s 12ms/step - loss: 1.6385 - accuracy: 0.4032 - val_loss: 1.5448 - val_accuracy: 0.4400
Epoch 5/25
196/196 [==============================] - 2s 12ms/step - loss: 1.5951 - accuracy: 0.4194 - val_loss: 1.5036 - val_accuracy: 0.4539
Epoch 6/25
196/196 [==============================] - 2s 12ms/step - loss: 1.5498 - accuracy: 0.4378 - val_loss: 1.5210 - val_accuracy: 0.4475
Epoch 7/25
196/196 [==============================] - 2s 12ms/step - loss: 1.5144 - accuracy: 0.4514 - val_loss: 1.4251 - val_accuracy: 0.4830
Epoch 8/25
196/196 [==============================] - 2s 12ms/step - loss: 1.4753 - accuracy: 0.4685 - val_loss: 1.3891 - val_accuracy: 0.4954
Epoch 9/25
196/196 [==============================] - 2s 12ms/step - loss: 1.4427 - accuracy: 0.4789 - val_loss: 1.3516 - val_accuracy: 0.5130
Epoch 10/25
196/196 [==============================] - 2s 12ms/step - loss: 1.4068 - accuracy: 0.4936 - val_loss: 1.3314 - val_accuracy: 0.5233
Epoch 11/25
196/196 [==============================] - 2s 12ms/step - loss: 1.3762 - accuracy: 0.5080 - val_loss: 1.3036 - val_accuracy: 0.5315
Epoch 12/25
196/196 [==============================] - 2s 12ms/step - loss: 1.3466 - accuracy: 0.5189 - val_loss: 1.2882 - val_accuracy: 0.5463
Epoch 13/25
196/196 [==============================] - 2s 12ms/step - loss: 1.3187 - accuracy: 0.5296 - val_loss: 1.2779 - val_accuracy: 0.5408
Epoch 14/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2950 - accuracy: 0.5374 - val_loss: 1.2581 - val_accuracy: 0.5628
Epoch 15/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2726 - accuracy: 0.5454 - val_loss: 1.2607 - val_accuracy: 0.5515
Epoch 16/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2480 - accuracy: 0.5536 - val_loss: 1.1792 - val_accuracy: 0.5839
Epoch 17/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2278 - accuracy: 0.5642 - val_loss: 1.1702 - val_accuracy: 0.5875
Epoch 18/25
196/196 [==============================] - 2s 12ms/step - loss: 1.2059 - accuracy: 0.5738 - val_loss: 1.1653 - val_accuracy: 0.5904
Epoch 19/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1876 - accuracy: 0.5805 - val_loss: 1.1801 - val_accuracy: 0.5800
Epoch 20/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1673 - accuracy: 0.5883 - val_loss: 1.1035 - val_accuracy: 0.6141
Epoch 21/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1482 - accuracy: 0.5971 - val_loss: 1.1362 - val_accuracy: 0.6029
Epoch 22/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1304 - accuracy: 0.6047 - val_loss: 1.0972 - val_accuracy: 0.6184
Epoch 23/25
196/196 [==============================] - 2s 12ms/step - loss: 1.1083 - accuracy: 0.6120 - val_loss: 1.0889 - val_accuracy: 0.6203
Epoch 24/25
196/196 [==============================] - 2s 12ms/step - loss: 1.0955 - accuracy: 0.6156 - val_loss: 1.1731 - val_accuracy: 0.5996
Epoch 25/25
196/196 [==============================] - 2s 12ms/step - loss: 1.0793 - accuracy: 0.6211 - val_loss: 1.0527 - val_accuracy: 0.6337
CPU times: user 42.5 s, sys: 7.7 s, total: 50.2 s
Wall time: 1min 1s
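For reference, the wall-clock numbers printed above give the speedup observed on this particular run (simple arithmetic; the ratio is machine-dependent):

```python
# Wall times from the two runs above, in seconds.
without_xla = 69.0  # 1min 9s
with_xla = 61.0     # 1min 1s
print(round(without_xla / with_xla, 2))  # 1.13
```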

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.