
Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, and compiles the model using XLA.
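
This tutorial enables XLA globally, by flipping TensorFlow's JIT flag (tf.config.optimizer.set_jit, used below). XLA can also compile a single function. A minimal sketch, assuming a recent TensorFlow 2.x release in which tf.function accepts the jit_compile argument (older 2.x releases spell it experimental_compile):

import tensorflow as tf

@tf.function(jit_compile=True)  # Ask XLA to compile just this function.
def dense_relu(x, w):
  return tf.nn.relu(tf.matmul(x, w))

x = tf.random.normal((4, 8))
w = tf.random.normal((8, 2))
print(dense_relu(x, w).shape)  # (4, 2)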

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that a GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.config.list_physical_devices('GPU')

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
 
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step
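
As a quick sanity check on what load_data returns (the shapes follow from the CIFAR-10 specification and the to_categorical call above):

print(x_train.shape, x_train.dtype)  # (50000, 32, 32, 3) float32
print(y_train.shape)                 # (50000, 10) -- one-hot labels
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)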

We define the model, based on the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
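
To inspect the resulting architecture, you can print a Keras summary; for this network it reports roughly 1.25 million trainable parameters, most of them in the Dense(512) layer:

model.summary()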
 

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT, we do not wish to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
 
196/196 [==============================] - 2s 12ms/step - loss: 2.0389 - accuracy: 0.2518 - val_loss: 1.8420 - val_accuracy: 0.3480
Epoch 1/25
196/196 [==============================] - 2s 10ms/step - loss: 2.0936 - accuracy: 0.2252 - val_loss: 1.8670 - val_accuracy: 0.3368
Epoch 2/25
196/196 [==============================] - 2s 10ms/step - loss: 1.7934 - accuracy: 0.3543 - val_loss: 1.6530 - val_accuracy: 0.4200
Epoch 3/25
196/196 [==============================] - 2s 10ms/step - loss: 1.6627 - accuracy: 0.4002 - val_loss: 1.5638 - val_accuracy: 0.4393
Epoch 4/25
196/196 [==============================] - 2s 10ms/step - loss: 1.5752 - accuracy: 0.4295 - val_loss: 1.4852 - val_accuracy: 0.4641
Epoch 5/25
196/196 [==============================] - 2s 11ms/step - loss: 1.5126 - accuracy: 0.4522 - val_loss: 1.4569 - val_accuracy: 0.4721
Epoch 6/25
196/196 [==============================] - 2s 11ms/step - loss: 1.4577 - accuracy: 0.4760 - val_loss: 1.3810 - val_accuracy: 0.5080
Epoch 7/25
196/196 [==============================] - 2s 11ms/step - loss: 1.4126 - accuracy: 0.4942 - val_loss: 1.3214 - val_accuracy: 0.5319
Epoch 8/25
196/196 [==============================] - 2s 11ms/step - loss: 1.3776 - accuracy: 0.5074 - val_loss: 1.3421 - val_accuracy: 0.5242
Epoch 9/25
196/196 [==============================] - 2s 10ms/step - loss: 1.3430 - accuracy: 0.5210 - val_loss: 1.2628 - val_accuracy: 0.5538
Epoch 10/25
196/196 [==============================] - 2s 10ms/step - loss: 1.3062 - accuracy: 0.5352 - val_loss: 1.2366 - val_accuracy: 0.5605
Epoch 11/25
196/196 [==============================] - 2s 10ms/step - loss: 1.2808 - accuracy: 0.5451 - val_loss: 1.2469 - val_accuracy: 0.5559
Epoch 12/25
196/196 [==============================] - 2s 10ms/step - loss: 1.2557 - accuracy: 0.5557 - val_loss: 1.2332 - val_accuracy: 0.5578
Epoch 13/25
196/196 [==============================] - 2s 10ms/step - loss: 1.2255 - accuracy: 0.5687 - val_loss: 1.1787 - val_accuracy: 0.5853
Epoch 14/25
196/196 [==============================] - 2s 10ms/step - loss: 1.2016 - accuracy: 0.5752 - val_loss: 1.1687 - val_accuracy: 0.5915
Epoch 15/25
196/196 [==============================] - 2s 10ms/step - loss: 1.1779 - accuracy: 0.5844 - val_loss: 1.1018 - val_accuracy: 0.6178
Epoch 16/25
196/196 [==============================] - 2s 10ms/step - loss: 1.1563 - accuracy: 0.5925 - val_loss: 1.1064 - val_accuracy: 0.6095
Epoch 17/25
196/196 [==============================] - 2s 10ms/step - loss: 1.1338 - accuracy: 0.6008 - val_loss: 1.1121 - val_accuracy: 0.6082
Epoch 18/25
196/196 [==============================] - 2s 10ms/step - loss: 1.1155 - accuracy: 0.6073 - val_loss: 1.0650 - val_accuracy: 0.6271
Epoch 19/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0934 - accuracy: 0.6173 - val_loss: 1.0252 - val_accuracy: 0.6433
Epoch 20/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0769 - accuracy: 0.6201 - val_loss: 1.0635 - val_accuracy: 0.6239
Epoch 21/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0628 - accuracy: 0.6260 - val_loss: 1.0418 - val_accuracy: 0.6323
Epoch 22/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0421 - accuracy: 0.6360 - val_loss: 0.9927 - val_accuracy: 0.6501
Epoch 23/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0247 - accuracy: 0.6405 - val_loss: 0.9640 - val_accuracy: 0.6639
Epoch 24/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0087 - accuracy: 0.6469 - val_loss: 0.9746 - val_accuracy: 0.6593
Epoch 25/25
196/196 [==============================] - 2s 10ms/step - loss: 0.9942 - accuracy: 0.6550 - val_loss: 0.9331 - val_accuracy: 0.6755
CPU times: user 53.4 s, sys: 13.7 s, total: 1min 7s
Wall time: 50.1 s
313/313 [==============================] - 1s 2ms/step - loss: 0.9331 - accuracy: 0.6755
Test loss: 0.9330735206604004
Test accuracy: 0.6754999756813049

Now let's train the model again, this time using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
 
196/196 [==============================] - 3s 14ms/step - loss: 2.0468 - accuracy: 0.2478 - val_loss: 1.8286 - val_accuracy: 0.3477
Epoch 1/25
196/196 [==============================] - 4s 20ms/step - loss: 2.1118 - accuracy: 0.2173 - val_loss: 1.8766 - val_accuracy: 0.3271
Epoch 2/25
196/196 [==============================] - 2s 9ms/step - loss: 1.7911 - accuracy: 0.3535 - val_loss: 1.7646 - val_accuracy: 0.3578
Epoch 3/25
196/196 [==============================] - 2s 9ms/step - loss: 1.6516 - accuracy: 0.4039 - val_loss: 1.5610 - val_accuracy: 0.4476
Epoch 4/25
196/196 [==============================] - 2s 9ms/step - loss: 1.5630 - accuracy: 0.4367 - val_loss: 1.4555 - val_accuracy: 0.4805
Epoch 5/25
196/196 [==============================] - 2s 9ms/step - loss: 1.4951 - accuracy: 0.4619 - val_loss: 1.4407 - val_accuracy: 0.4915
Epoch 6/25
196/196 [==============================] - 2s 9ms/step - loss: 1.4416 - accuracy: 0.4816 - val_loss: 1.3517 - val_accuracy: 0.5167
Epoch 7/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3954 - accuracy: 0.5016 - val_loss: 1.3714 - val_accuracy: 0.5078
Epoch 8/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3613 - accuracy: 0.5115 - val_loss: 1.3856 - val_accuracy: 0.5067
Epoch 9/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3253 - accuracy: 0.5278 - val_loss: 1.2802 - val_accuracy: 0.5500
Epoch 10/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2961 - accuracy: 0.5413 - val_loss: 1.2180 - val_accuracy: 0.5723
Epoch 11/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2669 - accuracy: 0.5504 - val_loss: 1.1723 - val_accuracy: 0.5852
Epoch 12/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2369 - accuracy: 0.5615 - val_loss: 1.2802 - val_accuracy: 0.5529
Epoch 13/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2134 - accuracy: 0.5730 - val_loss: 1.1564 - val_accuracy: 0.5916
Epoch 14/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1893 - accuracy: 0.5812 - val_loss: 1.1650 - val_accuracy: 0.5876
Epoch 15/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1673 - accuracy: 0.5891 - val_loss: 1.1040 - val_accuracy: 0.6118
Epoch 16/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1486 - accuracy: 0.5947 - val_loss: 1.1335 - val_accuracy: 0.6024
Epoch 17/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1260 - accuracy: 0.6048 - val_loss: 1.0717 - val_accuracy: 0.6218
Epoch 18/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1079 - accuracy: 0.6134 - val_loss: 1.0747 - val_accuracy: 0.6197
Epoch 19/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0933 - accuracy: 0.6170 - val_loss: 1.0317 - val_accuracy: 0.6406
Epoch 20/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0749 - accuracy: 0.6246 - val_loss: 1.0081 - val_accuracy: 0.6489
Epoch 21/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0566 - accuracy: 0.6312 - val_loss: 1.0412 - val_accuracy: 0.6382
Epoch 22/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0456 - accuracy: 0.6338 - val_loss: 1.0483 - val_accuracy: 0.6316
Epoch 23/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0289 - accuracy: 0.6401 - val_loss: 0.9587 - val_accuracy: 0.6687
Epoch 24/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0142 - accuracy: 0.6466 - val_loss: 0.9975 - val_accuracy: 0.6502
Epoch 25/25
196/196 [==============================] - 2s 9ms/step - loss: 0.9995 - accuracy: 0.6510 - val_loss: 0.9480 - val_accuracy: 0.6708
CPU times: user 49.2 s, sys: 9.27 s, total: 58.5 s
Wall time: 48 s
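
If you also want test-set metrics for the XLA-compiled model, the same evaluation calls as in the first run apply:

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])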

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is about 1.17x. (The runs captured above show a smaller gap, 50.1 s versus 48 s of wall time; the exact gain varies with hardware and model.)