Classifying CIFAR-10 with XLA

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset and compiles it using XLA.

You will load and normalize the dataset using the TensorFlow Datasets (TFDS) API. First, install/upgrade TensorFlow and TFDS:

pip install -U -q tensorflow tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
2025-02-05 12:08:59.954246: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1738757339.975741   12814 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738757339.982371   12814 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
# An explicit check is used instead of `assert`, which is skipped entirely
# when Python runs with optimizations enabled (`python -O`).
if not tf.test.gpu_device_name():
  raise SystemError('GPU device not found')

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  """Load CIFAR-10 through TFDS and return normalized, one-hot-encoded splits.

  Returns:
    ((x_train, y_train), (x_test, y_test)), where the images are float32
    NumPy arrays scaled by 1/255 and the labels are one-hot encoded over
    10 classes.
  """
  # batch_size=-1 asks TFDS for each split as a single full-size batch.
  result = tfds.load('cifar10', batch_size=-1)

  def _prepare(split):
    # Scale and one-hot-encode one split ({'image': ..., 'label': ...}).
    # Assumes 8-bit pixels (0..255): dividing by 255 rather than 256 maps
    # the brightest pixel to exactly 1.0, as in the Keras CIFAR-10 example.
    images = split['image'].numpy().astype('float32') / 255
    # Convert class vectors to binary class matrices.
    labels = tf.keras.utils.to_categorical(split['label'], num_classes=10)
    return images, labels

  return (_prepare(result['train']), _prepare(result['test']))

# Load the dataset into module-level globals; generate_model() reads
# x_train.shape to size the network's input.
(x_train, y_train), (x_test, y_test) = load_data()
I0000 00:00:1738757344.732080   12814 gpu_device.cc:2022] Created device /device:GPU:0 with 13638 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
I0000 00:00:1738757344.734322   12814 gpu_device.cc:2022] Created device /device:GPU:1 with 13756 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:06.0, compute capability: 7.5
I0000 00:00:1738757344.736427   12814 gpu_device.cc:2022] Created device /device:GPU:2 with 13756 MB memory:  -> device: 2, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5
I0000 00:00:1738757344.738769   12814 gpu_device.cc:2022] Created device /device:GPU:3 with 13756 MB memory:  -> device: 3, name: Tesla T4, pci bus id: 0000:00:08.0, compute capability: 7.5
I0000 00:00:1738757345.885857   12814 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13638 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
I0000 00:00:1738757345.887680   12814 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13756 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:06.0, compute capability: 7.5
I0000 00:00:1738757345.889413   12814 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 13756 MB memory:  -> device: 2, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5
I0000 00:00:1738757345.891329   12814 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 13756 MB memory:  -> device: 3, name: Tesla T4, pci bus id: 0000:00:08.0, compute capability: 7.5

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model(input_shape=None, num_classes=10):
  """Build the (uncompiled) convolutional classifier used in this tutorial.

  Args:
    input_shape: Shape of a single input image, without the batch
      dimension. Defaults to `x_train.shape[1:]`, read from the
      module-level `x_train` at call time.
    num_classes: Number of output classes, i.e. the width of the final
      softmax layer.

  Returns:
    A `tf.keras.models.Sequential` model ending in a softmax activation.
  """
  if input_shape is None:
    input_shape = x_train.shape[1:]
  return tf.keras.models.Sequential([
    # Declare the input with an Input object instead of passing
    # `input_shape=` to the first layer; Keras warns against the latter.
    tf.keras.Input(shape=input_shape),
    tf.keras.layers.Conv2D(32, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes),
    tf.keras.layers.Activation('softmax'),
  ])

# Build the baseline model (XLA JIT was disabled above); its input shape is
# taken from the global x_train.
model = generate_model()
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)

We train the model using the RMSprop optimizer:

def compile_model(model):
  """Configure `model` for training and hand it back.

  Uses categorical cross-entropy loss, an RMSprop optimizer with a
  learning rate of 1e-4, and tracks accuracy. The model is compiled in
  place; returning it allows `compile_model(generate_model())` chaining.
  """
  model.compile(
      optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-4),
      loss='categorical_crossentropy',
      metrics=['accuracy'],
  )
  return model

# Compile in place; compile_model returns the same model object.
model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  """Fit `model` on the training data, validating on the test data each epoch.

  Training uses mini-batches of 256 examples with shuffling enabled.
  Nothing is returned; the history object produced by `fit` is discarded.
  """
  fit_options = dict(
      batch_size=256,
      epochs=epochs,
      validation_data=(x_test, y_test),
      shuffle=True,
  )
  model.fit(x_train, y_train, **fit_options)

def warmup(model, x_train, y_train, x_test, y_test):
  """Run one untimed training epoch so JIT compilation is not measured later.

  The model's weights are snapshotted first and put back afterwards, so the
  timed training run starts from the same weights.
  """
  # NOTE(review): only the model's weights are restored here; optimizer
  # state built up during the warm-up epoch is presumably not part of
  # get_weights() and so carries over -- confirm if exact resets matter.
  snapshot = model.get_weights()
  data = (x_train, y_train, x_test, y_test)
  train_model(model, *data, epochs=1)
  model.set_weights(snapshot)

# Warm up first so compilation time is excluded, then time a full training
# run (train_model defaults to 25 epochs) with IPython's %time magic.
warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

# scores[0] is the test loss; scores[1] is the 'accuracy' metric requested
# in compile_model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1738757354.182056   12980 service.cc:148] XLA service 0x7f1000003440 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1738757354.182091   12980 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1738757354.182096   12980 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1738757354.182098   12980 service.cc:156]   StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1738757354.182101   12980 service.cc:156]   StreamExecutor device (3): Tesla T4, Compute Capability 7.5
I0000 00:00:1738757354.390398   12980 cuda_dnn.cc:529] Loaded cuDNN version 90300
7/196 ━━━━━━━━━━━━━━━━━━━━ 4s 22ms/step - accuracy: 0.1056 - loss: 2.3034
I0000 00:00:1738757358.396682   12980 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
196/196 ━━━━━━━━━━━━━━━━━━━━ 13s 42ms/step - accuracy: 0.1505 - loss: 2.2420 - val_accuracy: 0.3089 - val_loss: 1.9346
Epoch 1/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.1385 - loss: 2.2607 - val_accuracy: 0.2789 - val_loss: 2.0079
Epoch 2/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.3024 - loss: 1.9314 - val_accuracy: 0.3875 - val_loss: 1.7382
Epoch 3/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.3708 - loss: 1.7491 - val_accuracy: 0.4160 - val_loss: 1.6325
Epoch 4/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.3983 - loss: 1.6634 - val_accuracy: 0.4265 - val_loss: 1.6075
Epoch 5/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4288 - loss: 1.5921 - val_accuracy: 0.4710 - val_loss: 1.4838
Epoch 6/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4390 - loss: 1.5512 - val_accuracy: 0.4799 - val_loss: 1.4414
Epoch 7/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4580 - loss: 1.5086 - val_accuracy: 0.4608 - val_loss: 1.4838
Epoch 8/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4756 - loss: 1.4536 - val_accuracy: 0.5008 - val_loss: 1.3865
Epoch 9/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4940 - loss: 1.4113 - val_accuracy: 0.5166 - val_loss: 1.3489
Epoch 10/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5060 - loss: 1.3833 - val_accuracy: 0.5104 - val_loss: 1.3868
Epoch 11/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5192 - loss: 1.3459 - val_accuracy: 0.5248 - val_loss: 1.3238
Epoch 12/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5283 - loss: 1.3231 - val_accuracy: 0.5371 - val_loss: 1.3070
Epoch 13/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5381 - loss: 1.2904 - val_accuracy: 0.5717 - val_loss: 1.2051
Epoch 14/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5502 - loss: 1.2741 - val_accuracy: 0.5833 - val_loss: 1.1783
Epoch 15/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5560 - loss: 1.2512 - val_accuracy: 0.5868 - val_loss: 1.1719
Epoch 16/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5626 - loss: 1.2174 - val_accuracy: 0.5900 - val_loss: 1.1521
Epoch 17/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5737 - loss: 1.2039 - val_accuracy: 0.5817 - val_loss: 1.1801
Epoch 18/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5842 - loss: 1.1768 - val_accuracy: 0.5845 - val_loss: 1.1640
Epoch 19/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5824 - loss: 1.1674 - val_accuracy: 0.6156 - val_loss: 1.0896
Epoch 20/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5966 - loss: 1.1431 - val_accuracy: 0.6231 - val_loss: 1.0713
Epoch 21/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.6014 - loss: 1.1251 - val_accuracy: 0.6181 - val_loss: 1.0757
Epoch 22/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.6088 - loss: 1.1064 - val_accuracy: 0.6248 - val_loss: 1.0740
Epoch 23/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.6181 - loss: 1.0886 - val_accuracy: 0.6251 - val_loss: 1.0754
Epoch 24/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.6236 - loss: 1.0690 - val_accuracy: 0.6439 - val_loss: 1.0223
Epoch 25/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.6266 - loss: 1.0614 - val_accuracy: 0.6452 - val_loss: 1.0219
CPU times: user 35.9 s, sys: 11.5 s, total: 47.4 s
Wall time: 1min 27s
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.6534 - loss: 1.0154
Test loss: 1.021875023841858
Test accuracy: 0.6452000141143799

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
# Build and compile a new model so the XLA run starts from freshly
# initialized weights rather than the model trained above.
model = compile_model(generate_model())
# NOTE(review): this reloads the same dataset already held in these
# globals; the earlier arrays look reusable -- confirm whether the reload
# is needed.
(x_train, y_train), (x_test, y_test) = load_data()

# Same protocol as the baseline: untimed warm-up epoch, then the timed run.
warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 ━━━━━━━━━━━━━━━━━━━━ 10s 33ms/step - accuracy: 0.1656 - loss: 2.1969 - val_accuracy: 0.3165 - val_loss: 1.9108
Epoch 1/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.1503 - loss: 2.2347 - val_accuracy: 0.2917 - val_loss: 1.9783
Epoch 2/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.2997 - loss: 1.9190 - val_accuracy: 0.3922 - val_loss: 1.7206
Epoch 3/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.3645 - loss: 1.7533 - val_accuracy: 0.4076 - val_loss: 1.6393
Epoch 4/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.3938 - loss: 1.6698 - val_accuracy: 0.4432 - val_loss: 1.5509
Epoch 5/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.4147 - loss: 1.6079 - val_accuracy: 0.4534 - val_loss: 1.5110
Epoch 6/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - accuracy: 0.4324 - loss: 1.5632 - val_accuracy: 0.4591 - val_loss: 1.5296
Epoch 7/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - accuracy: 0.4557 - loss: 1.5141 - val_accuracy: 0.4902 - val_loss: 1.4239
Epoch 8/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - accuracy: 0.4678 - loss: 1.4740 - val_accuracy: 0.5097 - val_loss: 1.3736
Epoch 9/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.4872 - loss: 1.4255 - val_accuracy: 0.5197 - val_loss: 1.3444
Epoch 10/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4998 - loss: 1.3958 - val_accuracy: 0.5301 - val_loss: 1.3210
Epoch 11/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5108 - loss: 1.3722 - val_accuracy: 0.5438 - val_loss: 1.2873
Epoch 12/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5228 - loss: 1.3352 - val_accuracy: 0.5313 - val_loss: 1.3096
Epoch 13/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5316 - loss: 1.3116 - val_accuracy: 0.5286 - val_loss: 1.3313
Epoch 14/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 18ms/step - accuracy: 0.5401 - loss: 1.2948 - val_accuracy: 0.5789 - val_loss: 1.2079
Epoch 15/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5501 - loss: 1.2700 - val_accuracy: 0.5778 - val_loss: 1.1873
Epoch 16/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5586 - loss: 1.2402 - val_accuracy: 0.5738 - val_loss: 1.2303
Epoch 17/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5654 - loss: 1.2248 - val_accuracy: 0.5901 - val_loss: 1.1678
Epoch 18/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5728 - loss: 1.2015 - val_accuracy: 0.5812 - val_loss: 1.1794
Epoch 19/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5766 - loss: 1.1901 - val_accuracy: 0.5914 - val_loss: 1.1562
Epoch 20/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5902 - loss: 1.1643 - val_accuracy: 0.6170 - val_loss: 1.1055
Epoch 21/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.5943 - loss: 1.1474 - val_accuracy: 0.6229 - val_loss: 1.0880
Epoch 22/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.6032 - loss: 1.1305 - val_accuracy: 0.5989 - val_loss: 1.1316
Epoch 23/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.6136 - loss: 1.1134 - val_accuracy: 0.6292 - val_loss: 1.0637
Epoch 24/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.6186 - loss: 1.0913 - val_accuracy: 0.6158 - val_loss: 1.0874
Epoch 25/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.6219 - loss: 1.0754 - val_accuracy: 0.6393 - val_loss: 1.0306
CPU times: user 36.3 s, sys: 10.2 s, total: 46.5 s
Wall time: 1min 28s

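On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x. The speedup depends on the hardware: on the Tesla T4 GPUs used to produce the output above, the XLA run (1min 28s wall time) was no faster than the baseline (1min 27s).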