Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset and compiles it using XLA.

You will load and normalize the dataset using the TensorFlow Datasets (TFDS) API. First, install/upgrade TensorFlow and TFDS:

pip install -U -q tensorflow tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
# Check that a GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name()

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.
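
Note that set_jit toggles XLA auto-clustering globally. As an aside (not part of the original notebook), TensorFlow can also compile a single function with tf.function(jit_compile=True), independent of the global setting; a minimal sketch using a made-up toy function:

# Per-function JIT: XLA compiles only this function, regardless of the
# global tf.config.optimizer.set_jit setting.
@tf.function(jit_compile=True)
def scaled_sum(x, y):
  return tf.reduce_sum(x * 2.0 + y)

# The first call with a given input signature triggers compilation;
# later calls with the same shapes and dtypes reuse the compiled executable.
print(scaled_sum(tf.ones((4, 4)), tf.ones((4, 4))))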

def load_data():
  result = tfds.load('cifar10', batch_size=-1)
  (x_train, y_train) = result['train']['image'], result['train']['label']
  (x_test, y_test) = result['test']['image'], result['test']['label']

  # Scale pixel values to the [0, 1) range.
  x_train = x_train.numpy().astype('float32') / 256
  x_test = x_test.numpy().astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
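
As a quick sanity check (not in the original notebook), the loaded arrays should have the standard CIFAR-10 shapes: 50,000 training and 10,000 test examples of 32x32 RGB images, with one-hot labels over 10 classes:

assert x_train.shape == (50000, 32, 32, 3)
assert y_train.shape == (50000, 10)
assert x_test.shape == (10000, 32, 32, 3)
assert y_test.shape == (10000, 10)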

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
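
To sanity-check the architecture, you can print a summary; for this model it reports about 1.25 million parameters, most of them in the first Dense layer:

model.summary()  # Lists each layer's output shape and parameter count.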

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we do not want to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 10s 28ms/step - loss: 2.0984 - accuracy: 0.2159 - val_loss: 1.9261 - val_accuracy: 0.3098
Epoch 1/25
196/196 [==============================] - 4s 23ms/step - loss: 2.1367 - accuracy: 0.2002 - val_loss: 1.9638 - val_accuracy: 0.2888
Epoch 2/25
196/196 [==============================] - 4s 21ms/step - loss: 1.9036 - accuracy: 0.3084 - val_loss: 1.7742 - val_accuracy: 0.3659
Epoch 3/25
196/196 [==============================] - 4s 21ms/step - loss: 1.7328 - accuracy: 0.3719 - val_loss: 1.6131 - val_accuracy: 0.4219
Epoch 4/25
196/196 [==============================] - 4s 21ms/step - loss: 1.6326 - accuracy: 0.4134 - val_loss: 1.5368 - val_accuracy: 0.4508
Epoch 5/25
196/196 [==============================] - 4s 21ms/step - loss: 1.5631 - accuracy: 0.4344 - val_loss: 1.4908 - val_accuracy: 0.4645
Epoch 6/25
196/196 [==============================] - 4s 21ms/step - loss: 1.5135 - accuracy: 0.4563 - val_loss: 1.4733 - val_accuracy: 0.4758
Epoch 7/25
196/196 [==============================] - 4s 21ms/step - loss: 1.4699 - accuracy: 0.4704 - val_loss: 1.3938 - val_accuracy: 0.5037
Epoch 8/25
196/196 [==============================] - 4s 21ms/step - loss: 1.4294 - accuracy: 0.4886 - val_loss: 1.3318 - val_accuracy: 0.5265
Epoch 9/25
196/196 [==============================] - 4s 21ms/step - loss: 1.4076 - accuracy: 0.4955 - val_loss: 1.3215 - val_accuracy: 0.5316
Epoch 10/25
196/196 [==============================] - 4s 21ms/step - loss: 1.3755 - accuracy: 0.5082 - val_loss: 1.3157 - val_accuracy: 0.5308
Epoch 11/25
196/196 [==============================] - 4s 21ms/step - loss: 1.3478 - accuracy: 0.5184 - val_loss: 1.2730 - val_accuracy: 0.5489
Epoch 12/25
196/196 [==============================] - 4s 21ms/step - loss: 1.3248 - accuracy: 0.5293 - val_loss: 1.2447 - val_accuracy: 0.5639
Epoch 13/25
196/196 [==============================] - 4s 21ms/step - loss: 1.3011 - accuracy: 0.5386 - val_loss: 1.2492 - val_accuracy: 0.5544
Epoch 14/25
196/196 [==============================] - 4s 21ms/step - loss: 1.2784 - accuracy: 0.5467 - val_loss: 1.1970 - val_accuracy: 0.5772
Epoch 15/25
196/196 [==============================] - 4s 21ms/step - loss: 1.2592 - accuracy: 0.5542 - val_loss: 1.2136 - val_accuracy: 0.5766
Epoch 16/25
196/196 [==============================] - 4s 21ms/step - loss: 1.2344 - accuracy: 0.5646 - val_loss: 1.2655 - val_accuracy: 0.5532
Epoch 17/25
196/196 [==============================] - 4s 21ms/step - loss: 1.2131 - accuracy: 0.5711 - val_loss: 1.1786 - val_accuracy: 0.5904
Epoch 18/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1954 - accuracy: 0.5779 - val_loss: 1.1435 - val_accuracy: 0.5953
Epoch 19/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1753 - accuracy: 0.5852 - val_loss: 1.1379 - val_accuracy: 0.6042
Epoch 20/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1546 - accuracy: 0.5927 - val_loss: 1.1171 - val_accuracy: 0.6077
Epoch 21/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1373 - accuracy: 0.6001 - val_loss: 1.0871 - val_accuracy: 0.6160
Epoch 22/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1177 - accuracy: 0.6062 - val_loss: 1.0877 - val_accuracy: 0.6210
Epoch 23/25
196/196 [==============================] - 4s 21ms/step - loss: 1.1047 - accuracy: 0.6135 - val_loss: 1.0735 - val_accuracy: 0.6238
Epoch 24/25
196/196 [==============================] - 4s 21ms/step - loss: 1.0810 - accuracy: 0.6192 - val_loss: 1.0215 - val_accuracy: 0.6488
Epoch 25/25
196/196 [==============================] - 4s 21ms/step - loss: 1.0672 - accuracy: 0.6260 - val_loss: 1.0615 - val_accuracy: 0.6316
CPU times: user 1min 28s, sys: 6.97 s, total: 1min 35s
Wall time: 1min 44s
313/313 [==============================] - 1s 2ms/step - loss: 1.0615 - accuracy: 0.6316
Test loss: 1.0615112781524658
Test accuracy: 0.631600022315979
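
The %time magic above is IPython/Colab-specific. If you adapt this notebook into a plain Python script, you can measure wall time with the standard library instead; a minimal sketch:

import time

start = time.perf_counter()
train_model(model, x_train, y_train, x_test, y_test)
print(f'Wall time: {time.perf_counter() - start:.1f} s')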

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()
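
Alternatively, auto-clustering can be enabled without any source changes by setting the TF_XLA_FLAGS environment variable before TensorFlow initializes, typically in the shell that launches the program. A sketch of the Python equivalent, which must run before TensorFlow executes any ops:

import os

# Enable XLA auto-clustering process-wide; equivalent to launching with
#   TF_XLA_FLAGS=--tf_xla_auto_jit=2 python my_program.py
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2'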

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 11s 26ms/step - loss: 2.0583 - accuracy: 0.2398 - val_loss: 1.8238 - val_accuracy: 0.3510
Epoch 1/25
196/196 [==============================] - 8s 40ms/step - loss: 2.1202 - accuracy: 0.2086 - val_loss: 1.8862 - val_accuracy: 0.3217
Epoch 2/25
196/196 [==============================] - 4s 18ms/step - loss: 1.7954 - accuracy: 0.3482 - val_loss: 1.6763 - val_accuracy: 0.3991
Epoch 3/25
196/196 [==============================] - 4s 18ms/step - loss: 1.6860 - accuracy: 0.3886 - val_loss: 1.5956 - val_accuracy: 0.4273
Epoch 4/25
196/196 [==============================] - 4s 18ms/step - loss: 1.6211 - accuracy: 0.4137 - val_loss: 1.5276 - val_accuracy: 0.4477
Epoch 5/25
196/196 [==============================] - 4s 18ms/step - loss: 1.5687 - accuracy: 0.4335 - val_loss: 1.4805 - val_accuracy: 0.4716
Epoch 6/25
196/196 [==============================] - 4s 18ms/step - loss: 1.5212 - accuracy: 0.4489 - val_loss: 1.4675 - val_accuracy: 0.4703
Epoch 7/25
196/196 [==============================] - 4s 18ms/step - loss: 1.4795 - accuracy: 0.4668 - val_loss: 1.3747 - val_accuracy: 0.5062
Epoch 8/25
196/196 [==============================] - 4s 18ms/step - loss: 1.4378 - accuracy: 0.4843 - val_loss: 1.4392 - val_accuracy: 0.4915
Epoch 9/25
196/196 [==============================] - 4s 18ms/step - loss: 1.4022 - accuracy: 0.4948 - val_loss: 1.3317 - val_accuracy: 0.5238
Epoch 10/25
196/196 [==============================] - 4s 18ms/step - loss: 1.3754 - accuracy: 0.5075 - val_loss: 1.3411 - val_accuracy: 0.5188
Epoch 11/25
196/196 [==============================] - 4s 18ms/step - loss: 1.3454 - accuracy: 0.5214 - val_loss: 1.2909 - val_accuracy: 0.5325
Epoch 12/25
196/196 [==============================] - 4s 18ms/step - loss: 1.3204 - accuracy: 0.5301 - val_loss: 1.2476 - val_accuracy: 0.5561
Epoch 13/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2972 - accuracy: 0.5380 - val_loss: 1.2480 - val_accuracy: 0.5645
Epoch 14/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2710 - accuracy: 0.5496 - val_loss: 1.2333 - val_accuracy: 0.5609
Epoch 15/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2558 - accuracy: 0.5549 - val_loss: 1.2088 - val_accuracy: 0.5711
Epoch 16/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2344 - accuracy: 0.5649 - val_loss: 1.1727 - val_accuracy: 0.5877
Epoch 17/25
196/196 [==============================] - 4s 18ms/step - loss: 1.2138 - accuracy: 0.5720 - val_loss: 1.1674 - val_accuracy: 0.5917
Epoch 18/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1951 - accuracy: 0.5772 - val_loss: 1.1401 - val_accuracy: 0.6032
Epoch 19/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1730 - accuracy: 0.5866 - val_loss: 1.1111 - val_accuracy: 0.6112
Epoch 20/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1543 - accuracy: 0.5925 - val_loss: 1.1721 - val_accuracy: 0.5866
Epoch 21/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1347 - accuracy: 0.6006 - val_loss: 1.1129 - val_accuracy: 0.6091
Epoch 22/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1210 - accuracy: 0.6056 - val_loss: 1.1031 - val_accuracy: 0.6084
Epoch 23/25
196/196 [==============================] - 4s 18ms/step - loss: 1.1049 - accuracy: 0.6107 - val_loss: 1.0297 - val_accuracy: 0.6402
Epoch 24/25
196/196 [==============================] - 4s 18ms/step - loss: 1.0877 - accuracy: 0.6183 - val_loss: 1.0093 - val_accuracy: 0.6507
Epoch 25/25
196/196 [==============================] - 4s 18ms/step - loss: 1.0741 - accuracy: 0.6223 - val_loss: 1.0872 - val_accuracy: 0.6193
CPU times: user 42.9 s, sys: 8.32 s, total: 51.2 s
Wall time: 1min 35s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x, consistent with the per-step times above (roughly 21 ms/step without XLA versus 18 ms/step with it).
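
On recent TensorFlow releases you can also scope XLA to a single Keras model rather than the whole session, via the jit_compile argument to Model.compile; a minimal sketch reusing the helpers above:

model = generate_model()
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001),
              metrics=['accuracy'],
              jit_compile=True)  # Compile this model's train step with XLA.
train_model(model, x_train, y_train, x_test, y_test)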