This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset and compiles it using XLA.
You will load and normalize the dataset using the TensorFlow Datasets (TFDS) API. First, install/upgrade TensorFlow and TFDS:
pip install -U -q tensorflow tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name(), 'No GPU found; see the Colab GPU notebook linked above.'
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.
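`tf.config.optimizer.set_jit` toggles XLA auto-clustering for the whole process. TensorFlow also lets you request XLA compilation for a single function with `tf.function(jit_compile=True)`. Below is a minimal sketch of that per-function alternative. The function `dense_relu` is hypothetical and not part of this tutorial. XLA has a CPU backend, so the sketch should also run without a GPU.

```python
import tensorflow as tf

# Hypothetical helper: a dense layer followed by ReLU, compiled by XLA as one unit.
@tf.function(jit_compile=True)
def dense_relu(x, w, b):
    return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.ones((2, 3))
w = tf.ones((3, 4))
b = tf.zeros((4,))
out = dense_relu(x, w, b)  # every entry is 1*1 summed over 3 inputs, plus 0 bias = 3.0
print(out.numpy())
```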
def load_data():
    result = tfds.load('cifar10', batch_size=-1)
    (x_train, y_train) = result['train']['image'], result['train']['label']
    (x_test, y_test) = result['test']['image'], result['test']['label']
    x_train = x_train.numpy().astype('float32') / 256
    x_test = x_test.numpy().astype('float32') / 256
    # Convert class vectors to binary class matrices.
    y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
    return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
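The preprocessing in `load_data()` does two things. It scales the uint8 pixels to floats, and it one-hot encodes the integer labels. Below is a NumPy-only sketch of both steps on two fake images; the arrays are made up for illustration. Dividing by 256 maps the brightest pixel value, 255, to 255/256 ≈ 0.996 rather than exactly 1.0. Dividing by 255 is an equally common convention. The `np.eye(10)[labels]` trick produces the same 0/1 matrix as `to_categorical`, possibly with a different dtype.

```python
import numpy as np

images = np.zeros((2, 32, 32, 3), dtype=np.uint8)  # two fake 32x32 RGB images
images[1] = 255                                     # second image is all white
labels = np.array([3, 7])                           # fake CIFAR-10 class indices

# Same scaling as load_data(): uint8 values in [0, 255] -> float32 in [0, 255/256].
x = images.astype('float32') / 256

# One-hot encoding: row i has a 1.0 in column labels[i] and 0.0 elsewhere.
y = np.eye(10, dtype='float32')[labels]

print(x.max())  # 0.99609375, i.e. 255/256
print(y[0])     # 1.0 at index 3, 0.0 elsewhere
```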
We define the model, adapted from the Keras CIFAR-10 example:
def generate_model():
    return tf.keras.models.Sequential([
        # Declare the input shape with an Input object instead of passing input_shape to Conv2D.
        tf.keras.Input(shape=x_train.shape[1:]),
        tf.keras.layers.Conv2D(32, (3, 3), padding='same'),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(32, (3, 3)),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(64, (3, 3)),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10),
        tf.keras.layers.Activation('softmax')
    ])
model = generate_model()
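As a sanity check on the architecture, you can work out the spatial sizes and parameter counts by hand. Below is a pure-Python sketch under the layer settings used above: stride-1 convolutions, and pooling with its default `'valid'` padding, which floors odd sizes. The helper names are hypothetical. The total should agree with the `model.count_params()` value for this model. Activation, Dropout, MaxPooling2D and Flatten layers add no parameters.

```python
def conv_out(size, kernel=3, padding='valid'):
    # 'same' keeps the spatial size (stride 1); 'valid' shrinks it by kernel - 1.
    return size if padding == 'same' else size - kernel + 1

def conv_params(kernel, c_in, c_out):
    # kernel * kernel * c_in weights per filter, plus one bias per filter.
    return kernel * kernel * c_in * c_out + c_out

def dense_params(n_in, n_out):
    return n_in * n_out + n_out

size = 32                              # CIFAR-10 images are 32x32x3
size = conv_out(size, padding='same')  # 32
size = conv_out(size)                  # 30
size //= 2                             # 15 after MaxPooling2D(2, 2)
size = conv_out(size, padding='same')  # 15
size = conv_out(size)                  # 13
size //= 2                             # 6, since pooling floors 13 / 2
flat = size * size * 64                # inputs to Dense(512) after Flatten

total = (conv_params(3, 3, 32) + conv_params(3, 32, 32)
         + conv_params(3, 32, 64) + conv_params(3, 64, 64)
         + dense_params(flat, 512) + dense_params(512, 10))
print(flat, total)  # 2304 1250858
```

Almost all of the weights (1,180,160) sit in the `Dense(512)` layer that follows `Flatten`.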
We train the model using the RMSprop optimizer:
def compile_model(model):
    opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'])
    return model

model = compile_model(model)
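RMSprop divides each gradient by a running root-mean-square of recent gradients, so the step size depends mostly on the learning rate rather than on the raw gradient scale. Below is a minimal scalar sketch of the textbook update. `rho=0.9` and `eps=1e-7` follow the documented Keras defaults. Keras's exact placement of `eps` may differ from this sketch, and `rmsprop_step` is a hypothetical helper, not a Keras API.

```python
import math

def rmsprop_step(w, g, v, lr=0.0001, rho=0.9, eps=1e-7):
    """One textbook RMSprop step for a scalar weight w with gradient g.

    v is the running average of squared gradients.
    """
    v = rho * v + (1 - rho) * g * g          # update the squared-gradient average
    w = w - lr * g / (math.sqrt(v) + eps)    # scale the step by its running RMS
    return w, v

w1, v1 = rmsprop_step(1.0, g=2.0, v=0.0)
w2, v2 = rmsprop_step(1.0, g=200.0, v=0.0)  # 100x larger gradient
print(1.0 - w1, 1.0 - w2)  # nearly identical step sizes, about 3.16 * lr
```

On the first step, with `v` starting at zero, the update is `lr / sqrt(1 - rho)` for any nonzero gradient. This is why the two calls above move the weight by almost the same amount.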
def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
    model.fit(x_train, y_train, batch_size=256, epochs=epochs,
              validation_data=(x_test, y_test), shuffle=True)
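The `196/196` and `313/313` counters in the logs below are batch counts. Keras runs a final, smaller batch for any leftover examples, so the count is a ceiling division. Below is a short sketch using the CIFAR-10 split sizes (50,000 training and 10,000 test images), the `batch_size=256` passed to `fit` in `train_model`, and `evaluate`'s default `batch_size` of 32. The helper `steps_per_epoch` is hypothetical.

```python
import math

def steps_per_epoch(num_examples, batch_size):
    # The last batch holds the remainder, so round up.
    return math.ceil(num_examples / batch_size)

train_steps = steps_per_epoch(50_000, 256)  # training split, batch_size=256
eval_steps = steps_per_epoch(10_000, 32)    # test split, evaluate() default batch size
print(train_steps, eval_steps)  # 196 313
```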

def warmup(model, x_train, y_train, x_test, y_test):
    # Warm up the JIT: we do not want to measure compilation time.
    initial_weights = model.get_weights()
    train_model(model, x_train, y_train, x_test, y_test, epochs=1)
    # Restore the weights so the timed run starts from the same initial weights.
    model.set_weights(initial_weights)
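The warm-up relies on a `get_weights()`/`set_weights()` round trip to undo its one-epoch training run. Below is a small sketch that checks this round trip on a hypothetical toy model with random data (`toy_model` is not part of the tutorial). One caveat: `get_weights()` covers only the model's layer weights. The optimizer's state, including RMSprop's running averages and its iteration counter, is not rolled back, so the timed run does not start from a completely fresh optimizer.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model and random data, only to exercise the warm-up pattern.
toy_model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])
toy_model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01), loss='mse')
x = np.random.rand(64, 4).astype('float32')
y = np.random.rand(64, 2).astype('float32')

initial = toy_model.get_weights()          # snapshot of kernel and bias (NumPy copies)
toy_model.fit(x, y, epochs=1, verbose=0)   # traces/compiles the step and updates weights
changed = any(not np.array_equal(a, b)
              for a, b in zip(initial, toy_model.get_weights()))
toy_model.set_weights(initial)             # roll the layer weights back
restored = all(np.array_equal(a, b)
               for a, b in zip(initial, toy_model.get_weights()))
print(changed, restored)  # True True
```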
warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
I0000 00:00:1721387931.695570 7951 service.cc:146] XLA service 0x7fb9cc06ff00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1721387931.695624 7951 service.cc:154] StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1721387931.695628 7951 service.cc:154] StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1721387931.695631 7951 service.cc:154] StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1721387931.695633 7951 service.cc:154] StreamExecutor device (3): Tesla T4, Compute Capability 7.5
I0000 00:00:1721387936.369950 7951 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
196/196 ━━━━━━━━━━━━━━━━━━━━ 13s 38ms/step - accuracy: 0.1732 - loss: 2.1988 - val_accuracy: 0.3324 - val_loss: 1.9102
Epoch 1/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.1594 - loss: 2.2286 - val_accuracy: 0.2999 - val_loss: 1.9607
Epoch 2/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.3213 - loss: 1.8832 - val_accuracy: 0.3861 - val_loss: 1.7008
Epoch 3/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.3762 - loss: 1.7196 - val_accuracy: 0.4266 - val_loss: 1.5860
Epoch 4/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4071 - loss: 1.6343 - val_accuracy: 0.4460 - val_loss: 1.5507
Epoch 5/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4304 - loss: 1.5668 - val_accuracy: 0.4713 - val_loss: 1.4475
Epoch 6/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4489 - loss: 1.5251 - val_accuracy: 0.4886 - val_loss: 1.4013
Epoch 7/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4720 - loss: 1.4580 - val_accuracy: 0.5018 - val_loss: 1.3746
Epoch 8/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.4853 - loss: 1.4357 - val_accuracy: 0.5177 - val_loss: 1.3616
Epoch 9/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.5019 - loss: 1.3863 - val_accuracy: 0.5285 - val_loss: 1.3113
Epoch 10/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5097 - loss: 1.3686 - val_accuracy: 0.5460 - val_loss: 1.2699
Epoch 11/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5155 - loss: 1.3464 - val_accuracy: 0.5275 - val_loss: 1.3146
Epoch 12/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5352 - loss: 1.3097 - val_accuracy: 0.5575 - val_loss: 1.2461
Epoch 13/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5421 - loss: 1.2912 - val_accuracy: 0.5716 - val_loss: 1.2073
Epoch 14/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5480 - loss: 1.2704 - val_accuracy: 0.5780 - val_loss: 1.1883
Epoch 15/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5589 - loss: 1.2447 - val_accuracy: 0.5758 - val_loss: 1.1973
Epoch 16/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5619 - loss: 1.2290 - val_accuracy: 0.5835 - val_loss: 1.1736
Epoch 17/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5750 - loss: 1.2067 - val_accuracy: 0.6029 - val_loss: 1.1326
Epoch 18/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5792 - loss: 1.1889 - val_accuracy: 0.6093 - val_loss: 1.1105
Epoch 19/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5921 - loss: 1.1644 - val_accuracy: 0.6173 - val_loss: 1.0910
Epoch 20/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5945 - loss: 1.1488 - val_accuracy: 0.5960 - val_loss: 1.1313
Epoch 21/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6035 - loss: 1.1238 - val_accuracy: 0.6293 - val_loss: 1.0471
Epoch 22/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6051 - loss: 1.1195 - val_accuracy: 0.6311 - val_loss: 1.0628
Epoch 23/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6142 - loss: 1.0944 - val_accuracy: 0.6360 - val_loss: 1.0296
Epoch 24/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6196 - loss: 1.0812 - val_accuracy: 0.6365 - val_loss: 1.0423
Epoch 25/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6277 - loss: 1.0634 - val_accuracy: 0.6389 - val_loss: 1.0376
CPU times: user 1min 23s, sys: 7.44 s, total: 1min 30s
Wall time: 1min 20s
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - accuracy: 0.6393 - loss: 1.0278
Test loss: 1.0375560522079468
Test accuracy: 0.6388999819755554
Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.
# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()
warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 ━━━━━━━━━━━━━━━━━━━━ 9s 30ms/step - accuracy: 0.1636 - loss: 2.2106 - val_accuracy: 0.3386 - val_loss: 1.8749
Epoch 1/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - accuracy: 0.1503 - loss: 2.2391 - val_accuracy: 0.3486 - val_loss: 1.8660
Epoch 2/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.3232 - loss: 1.8592 - val_accuracy: 0.3834 - val_loss: 1.7150
Epoch 3/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.3755 - loss: 1.7189 - val_accuracy: 0.4240 - val_loss: 1.6264
Epoch 4/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4039 - loss: 1.6524 - val_accuracy: 0.4361 - val_loss: 1.5560
Epoch 5/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4268 - loss: 1.5859 - val_accuracy: 0.4632 - val_loss: 1.4941
Epoch 6/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 0.4416 - loss: 1.5386 - val_accuracy: 0.4796 - val_loss: 1.4442
Epoch 7/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4644 - loss: 1.4853 - val_accuracy: 0.4947 - val_loss: 1.4027
Epoch 8/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4756 - loss: 1.4581 - val_accuracy: 0.5121 - val_loss: 1.3702
Epoch 9/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.4876 - loss: 1.4262 - val_accuracy: 0.5319 - val_loss: 1.3259
Epoch 10/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5032 - loss: 1.3797 - val_accuracy: 0.5370 - val_loss: 1.3116
Epoch 11/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5168 - loss: 1.3472 - val_accuracy: 0.5349 - val_loss: 1.2888
Epoch 12/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5259 - loss: 1.3270 - val_accuracy: 0.5626 - val_loss: 1.2377
Epoch 13/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5348 - loss: 1.3059 - val_accuracy: 0.5766 - val_loss: 1.2052
Epoch 14/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5419 - loss: 1.2832 - val_accuracy: 0.5681 - val_loss: 1.2101
Epoch 15/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5537 - loss: 1.2545 - val_accuracy: 0.5750 - val_loss: 1.2009
Epoch 16/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5620 - loss: 1.2324 - val_accuracy: 0.5908 - val_loss: 1.1602
Epoch 17/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5738 - loss: 1.2019 - val_accuracy: 0.6025 - val_loss: 1.1346
Epoch 18/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5812 - loss: 1.1870 - val_accuracy: 0.6072 - val_loss: 1.1096
Epoch 19/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5833 - loss: 1.1698 - val_accuracy: 0.6100 - val_loss: 1.1047
Epoch 20/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5932 - loss: 1.1555 - val_accuracy: 0.6090 - val_loss: 1.1049
Epoch 21/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.5994 - loss: 1.1335 - val_accuracy: 0.6282 - val_loss: 1.0643
Epoch 22/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6069 - loss: 1.1118 - val_accuracy: 0.6225 - val_loss: 1.0833
Epoch 23/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6117 - loss: 1.1023 - val_accuracy: 0.6263 - val_loss: 1.0536
Epoch 24/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6220 - loss: 1.0804 - val_accuracy: 0.6427 - val_loss: 1.0194
Epoch 25/25 196/196 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.6276 - loss: 1.0654 - val_accuracy: 0.6470 - val_loss: 1.0076
CPU times: user 1min 26s, sys: 6.52 s, total: 1min 33s
Wall time: 1min 23s
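In Keras 3, `Model.compile` also accepts a `jit_compile` argument whose default is `'auto'`. With that default, Keras may XLA-compile the training step on its own, independently of `tf.config.optimizer.set_jit`. The first run's log contains `Compiled cluster using XLA!` even though `set_jit(False)` was in effect, which is consistent with this. It may also explain why the two wall times are so close. If you want a baseline that avoids XLA, one option is to pass the flag explicitly, as in the sketch below. `compile_model_explicit` and the toy model are hypothetical variants, not part of the tutorial.

```python
import tensorflow as tf

def compile_model_explicit(model, jit_compile=False):
    # Same optimizer, loss and metrics as compile_model(); only jit_compile is
    # made explicit so the baseline (False) and XLA (True) runs differ in one flag.
    opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'],
                  jit_compile=jit_compile)
    return model

# Hypothetical usage with a tiny stand-in model.
toy = tf.keras.Sequential([tf.keras.Input(shape=(8,)),
                           tf.keras.layers.Dense(10, activation='softmax')])
baseline = compile_model_explicit(toy, jit_compile=False)
```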
Speedups depend on the hardware: on a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is about 1.17x.
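The speedup is the ratio of baseline wall time to XLA wall time. Below is a small sketch applying it to the two `%time` results printed above: 1 min 20 s without the global JIT setting and 1 min 23 s with it. The helper `speedup` is hypothetical.

```python
def speedup(baseline_seconds, xla_seconds):
    # Above 1.0: the XLA run was faster. Below 1.0: it was slower.
    return baseline_seconds / xla_seconds

print(round(speedup(80, 83), 2))  # 0.96 for the two wall times above
# A 1.17x speedup means the XLA run takes about 1 / 1.17 of the baseline time.
print(round(1 / 1.17, 2))         # 0.85
```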