Classifying CIFAR-10 with XLA

This tutorial trains a TensorFlow model to classify images from the CIFAR-10 dataset, first without and then with XLA compilation, so you can compare the training times.

You will load and normalize the dataset using the TensorFlow Datasets (TFDS) API. First, install or upgrade TensorFlow and TFDS:

pip install -U -q tensorflow tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name()

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.
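
Note: tf.config.optimizer.set_jit(True) toggles XLA auto-clustering for the entire session. TensorFlow can also compile a single function explicitly, via the jit_compile argument to tf.function; a minimal sketch, not used in this tutorial:

@tf.function(jit_compile=True)
def dense_layer(x, w, b):
  # The whole function body is compiled as one XLA computation.
  return tf.nn.relu(tf.matmul(x, w) + b)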

def load_data():
  # batch_size=-1 loads each split as a single batch of tensors.
  result = tfds.load('cifar10', batch_size=-1)
  (x_train, y_train) = result['train']['image'], result['train']['label']
  (x_test, y_test) = result['test']['image'], result['test']['label']

  # Scale pixel values from [0, 255] to [0, 1).
  x_train = x_train.numpy().astype('float32') / 256
  x_test = x_test.numpy().astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1768305996.228973    9833 gpu_device.cc:2020] Created device /device:GPU:0 with 13680 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
I0000 00:00:1768305996.231191    9833 gpu_device.cc:2020] Created device /device:GPU:1 with 13756 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:06.0, compute capability: 7.5
I0000 00:00:1768305996.233351    9833 gpu_device.cc:2020] Created device /device:GPU:2 with 13756 MB memory:  -> device: 2, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5
I0000 00:00:1768305996.235499    9833 gpu_device.cc:2020] Created device /device:GPU:3 with 13756 MB memory:  -> device: 3, name: Tesla T4, pci bus id: 0000:00:08.0, compute capability: 7.5
I0000 00:00:1768305997.539283    9833 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13680 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
I0000 00:00:1768305997.541054    9833 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13756 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:06.0, compute capability: 7.5
I0000 00:00:1768305997.542867    9833 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 13756 MB memory:  -> device: 2, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5
I0000 00:00:1768305997.544727    9833 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 13756 MB memory:  -> device: 3, name: Tesla T4, pci bus id: 0000:00:08.0, compute capability: 7.5
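
As a quick sanity check (a hypothetical snippet, not part of the original notebook), the loaded arrays should have the standard CIFAR-10 shapes:

print(x_train.shape)  # (50000, 32, 32, 3): 50,000 training images, 32x32 RGB
print(y_train.shape)  # (50000, 10): one-hot labels over 10 classes
print(x_test.shape)   # (10000, 32, 32, 3): 10,000 test images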

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.Input(shape=x_train.shape[1:]),
    tf.keras.layers.Conv2D(32, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
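
Optionally, verify the architecture with a layer-by-layer summary:

model.summary()  # Prints output shapes and parameter counts per layer.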

We compile the model with the RMSprop optimizer and define the training and warmup routines:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)
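
Because load_data one-hot encodes the labels, categorical_crossentropy is the matching loss. A hypothetical variant that keeps the integer labels (skipping to_categorical) would use the sparse loss instead:

model.compile(loss='sparse_categorical_crossentropy',  # expects integer class labels
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
              metrics=['accuracy'])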

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we do not want to measure compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)
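
XLA compiles a computation the first time it executes, so the first training step pays a one-off compilation cost that would skew the benchmark; the warmup run absorbs it. A standalone sketch of the effect, using a hypothetical toy function:

import time

@tf.function(jit_compile=True)
def square_sum(x):
  return tf.reduce_sum(x * x)

inputs = tf.random.normal((1024, 1024))
start = time.perf_counter()
square_sum(inputs)  # First call traces, compiles, and runs.
print('first call :', time.perf_counter() - start)
start = time.perf_counter()
square_sum(inputs)  # Later calls reuse the compiled executable.
print('second call:', time.perf_counter() - start)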

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
7/196 ━━━━━━━━━━━━━━━━━━━━ 4s 23ms/step - accuracy: 0.0945 - loss: 2.3061
I0000 00:00:1768306010.012884    9994 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
196/196 ━━━━━━━━━━━━━━━━━━━━ 15s 46ms/step - accuracy: 0.1796 - loss: 2.1855 - val_accuracy: 0.3010 - val_loss: 1.9683
Epoch 1/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.1618 - loss: 2.2178 - val_accuracy: 0.3256 - val_loss: 1.9150
Epoch 2/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.3056 - loss: 1.9096 - val_accuracy: 0.3772 - val_loss: 1.7256
Epoch 3/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.3673 - loss: 1.7384 - val_accuracy: 0.4216 - val_loss: 1.5983
Epoch 4/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.4005 - loss: 1.6516 - val_accuracy: 0.4462 - val_loss: 1.5322
Epoch 5/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.4194 - loss: 1.5950 - val_accuracy: 0.4612 - val_loss: 1.4874
Epoch 6/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.4333 - loss: 1.5527 - val_accuracy: 0.4527 - val_loss: 1.4858
Epoch 7/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.4488 - loss: 1.5189 - val_accuracy: 0.4909 - val_loss: 1.4098
Epoch 8/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.4695 - loss: 1.4715 - val_accuracy: 0.5002 - val_loss: 1.3865
Epoch 9/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.4810 - loss: 1.4375 - val_accuracy: 0.5181 - val_loss: 1.3379
Epoch 10/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.4898 - loss: 1.4099 - val_accuracy: 0.5290 - val_loss: 1.3194
Epoch 11/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5033 - loss: 1.3807 - val_accuracy: 0.4878 - val_loss: 1.4089
Epoch 12/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5123 - loss: 1.3612 - val_accuracy: 0.5375 - val_loss: 1.3042
Epoch 13/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5252 - loss: 1.3278 - val_accuracy: 0.5487 - val_loss: 1.2698
Epoch 14/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5350 - loss: 1.3139 - val_accuracy: 0.5584 - val_loss: 1.2416
Epoch 15/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5447 - loss: 1.2795 - val_accuracy: 0.5646 - val_loss: 1.2166
Epoch 16/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5515 - loss: 1.2644 - val_accuracy: 0.5537 - val_loss: 1.2843
Epoch 17/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5646 - loss: 1.2408 - val_accuracy: 0.5653 - val_loss: 1.2134
Epoch 18/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5670 - loss: 1.2210 - val_accuracy: 0.5959 - val_loss: 1.1506
Epoch 19/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5720 - loss: 1.2067 - val_accuracy: 0.5944 - val_loss: 1.1457
Epoch 20/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5788 - loss: 1.1834 - val_accuracy: 0.5964 - val_loss: 1.1598
Epoch 21/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5854 - loss: 1.1749 - val_accuracy: 0.5956 - val_loss: 1.1349
Epoch 22/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5918 - loss: 1.1470 - val_accuracy: 0.6121 - val_loss: 1.0904
Epoch 23/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.6026 - loss: 1.1268 - val_accuracy: 0.6252 - val_loss: 1.0738
Epoch 24/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.6048 - loss: 1.1206 - val_accuracy: 0.6177 - val_loss: 1.0776
Epoch 25/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.6105 - loss: 1.1024 - val_accuracy: 0.6433 - val_loss: 1.0215
CPU times: user 39.8 s, sys: 11.9 s, total: 51.7 s
Wall time: 1min 37s
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.6487 - loss: 1.0139
Test loss: 1.0214592218399048
Test accuracy: 0.6432999968528748

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()
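
Recent TensorFlow releases also let you request XLA per model rather than globally, via the jit_compile argument to Model.compile. An alternative sketch, not what this tutorial does:

model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
              metrics=['accuracy'],
              jit_compile=True)  # Compile this model's train step with XLA.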

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 ━━━━━━━━━━━━━━━━━━━━ 12s 40ms/step - accuracy: 0.1708 - loss: 2.2064 - val_accuracy: 0.3319 - val_loss: 1.8737
Epoch 1/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 21ms/step - accuracy: 0.1540 - loss: 2.2399 - val_accuracy: 0.3022 - val_loss: 1.9470
Epoch 2/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.3144 - loss: 1.8921 - val_accuracy: 0.3909 - val_loss: 1.7211
Epoch 3/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.3744 - loss: 1.7343 - val_accuracy: 0.4143 - val_loss: 1.6210
Epoch 4/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.3952 - loss: 1.6527 - val_accuracy: 0.4351 - val_loss: 1.5757
Epoch 5/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.4156 - loss: 1.6021 - val_accuracy: 0.4590 - val_loss: 1.5038
Epoch 6/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.4352 - loss: 1.5476 - val_accuracy: 0.4815 - val_loss: 1.4379
Epoch 7/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.4609 - loss: 1.4864 - val_accuracy: 0.5016 - val_loss: 1.3768
Epoch 8/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.4811 - loss: 1.4435 - val_accuracy: 0.4958 - val_loss: 1.3862
Epoch 9/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.4937 - loss: 1.4059 - val_accuracy: 0.5320 - val_loss: 1.3177
Epoch 10/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5077 - loss: 1.3783 - val_accuracy: 0.5498 - val_loss: 1.2750
Epoch 11/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5153 - loss: 1.3527 - val_accuracy: 0.5246 - val_loss: 1.3406
Epoch 12/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5303 - loss: 1.3170 - val_accuracy: 0.5543 - val_loss: 1.2578
Epoch 13/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5339 - loss: 1.2964 - val_accuracy: 0.5692 - val_loss: 1.2180
Epoch 14/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5508 - loss: 1.2734 - val_accuracy: 0.5581 - val_loss: 1.2417
Epoch 15/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5535 - loss: 1.2552 - val_accuracy: 0.5869 - val_loss: 1.1689
Epoch 16/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5660 - loss: 1.2196 - val_accuracy: 0.5934 - val_loss: 1.1592
Epoch 17/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5773 - loss: 1.1996 - val_accuracy: 0.6070 - val_loss: 1.1230
Epoch 18/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - accuracy: 0.5818 - loss: 1.1813 - val_accuracy: 0.6079 - val_loss: 1.1179
Epoch 19/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5933 - loss: 1.1545 - val_accuracy: 0.6224 - val_loss: 1.0836
Epoch 20/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.5927 - loss: 1.1440 - val_accuracy: 0.6126 - val_loss: 1.1001
Epoch 21/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.6070 - loss: 1.1237 - val_accuracy: 0.6121 - val_loss: 1.1268
Epoch 22/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.6141 - loss: 1.1090 - val_accuracy: 0.6406 - val_loss: 1.0256
Epoch 23/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.6180 - loss: 1.0898 - val_accuracy: 0.6403 - val_loss: 1.0260
Epoch 24/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.6258 - loss: 1.0724 - val_accuracy: 0.6498 - val_loss: 1.0095
Epoch 25/25
196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - accuracy: 0.6280 - loss: 1.0604 - val_accuracy: 0.6606 - val_loss: 0.9856
CPU times: user 39.4 s, sys: 11.6 s, total: 51 s
Wall time: 1min 38s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
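
You can also enable auto-clustering without any source changes by setting the TF_XLA_FLAGS environment variable before the process starts; here path/to/your/tf_program.py is a placeholder:

TF_XLA_FLAGS=--tf_xla_auto_jit=2 python path/to/your/tf_program.py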