XLA: Optimización del compilador para el aprendizaje automático

OpenXLA es un compilador específico del dominio para álgebra lineal que puede acelerar los modelos de TensorFlow casi sin cambios en el código fuente.

Introducción

Cuando se ejecuta un programa de TensorFlow, el ejecutor de TensorFlow ejecuta todas las operaciones de forma individual. Cada operación de TensorFlow tiene una implementación de kernel de GPU precompilada a la que el ejecutor envía datos.

XLA proporciona un modo alternativo de ejecutar modelos: compila el grafo de TensorFlow en una secuencia de kernels de procesamiento generados específicamente para el modelo dado. Debido a que estos kernels son exclusivos del modelo, pueden aprovechar la información específica del modelo para la optimización. Por ejemplo, veamos una optimización que XLA realiza en el contexto de un cálculo simple de TensorFlow:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

Si se ejecuta sin XLA, el gráfico inicia tres kernels: uno para la multiplicación, uno para la suma y otro para la reducción. Sin embargo, XLA puede optimizar el gráfico para que calcule el resultado en un solo lanzamiento de kernel. Para ello, "fusiona" la suma, la multiplicación y la reducción en un solo kernel de GPU. Además, esta operación fusionada no escribe los valores intermedios producidos por y*z y x+y*z en la memoria, sino que "transmite" los resultados de estos cálculos intermedios directamente a los usuarios y los mantiene por completo en los registros de la GPU. La fusión es la optimización más importante de XLA. Por lo general, el ancho de banda de la memoria es el recurso más escaso en los aceleradores de hardware, por lo que quitar las operaciones de memoria es una de las mejores formas de mejorar el rendimiento.

Habilita XLA para los modelos de TensorFlow

Compilación explícita con tf.function(jit_compile=True)

La API de compilación explícita ofrece un control detallado para elegir qué funciones se deben compilar. Por ejemplo, la siguiente función de TensorFlow, que realiza el entrenamiento de MNIST, se compila con XLA:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

La API de jit_compile tiene una semántica que debe compilar: se compila toda la función con XLA o se arroja una excepción errors.InvalidArgumentError. Actualmente, XLA no puede compilar funciones en las que las dimensiones no sean inferibles, es decir, cuando no es posible inferir las dimensiones de todos los tensores sin ejecutar todo el cálculo. Por ejemplo, no se compilará la siguiente función:

@tf.function
def not_compilable(x):
  return tf.unique(x)

Sin embargo, las formas pueden variar en las ejecuciones:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

Consulta el instructivo de Colab para ver un ejemplo de uso más detallado y un video instructivo sobre el uso de jit_compile=True.

Uso con Keras

Para los modelos de Keras, jit_compile=True se puede establecer como un argumento en model.compile:

model.compile(optimizer="adam", jit_compile=True)

Uso con estrategia distribuida

XLA:GPU se puede usar con la estrategia distribuida de TF (MirroredStrategy o MultiWorkerMirroredStrategy) mediante la anotación de la función del paso con jit_compile=True:

@tf.function(jit_compile=True)
def step_fn():
  t = tf.ones(shape=[100], dtype=tf.float32)
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t)

@tf.function
def run_fn():
  return strategy.run(step_fn)

Agrupamiento en clústeres automático

Una forma sencilla de comenzar a usar XLA en modelos de TensorFlow sin realizar cambios es habilitar el agrupamiento en clústeres automático, que encuentra automáticamente clústeres (subgrafos conectados) dentro de las funciones de TensorFlow que se pueden compilar y ejecutar con XLA. El agrupamiento en clústeres automático en la GPU se puede habilitar mediante la configuración de la variable de entorno TF_XLA_FLAGS:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

Actualmente, el agrupamiento en clústeres automático está optimizado para cargas de trabajo de GPU, pero también se puede habilitar en la CPU si se usa la marca --tf_xla_cpu_global_jit:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

Para obtener un ejemplo de uso detallado, consulta el instructivo de agrupamiento en clústeres automático de Colab.

Compilación AOT (anticipada) para CPU con tfcompile

También puedes usar una herramienta tfcompile independiente, que convierte el grafo de TensorFlow en código ejecutable (solo para CPU x86-64).

Cómo inspeccionar programas compilados

XLA proporciona recursos de introspección que te permiten inspeccionar los programas generados. Para volcar los programas generados, usa la variable de entorno XLA_FLAGS:

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

Después de realizar el volcado, podrás encontrar los siguientes archivos en /tmp/generated:

  • module_XXXX.*_optimizations.txt Se generaron programas de XLA, uno por cada clúster compilado. Es muy útil adjuntarlos cuando se envían los informes de errores de XLA.

  • module_XXXX.ir-*.ll Se generaron archivos en la representación intermedia de LLVM, con funciones intrínsecas de NVPTX.

  • module_XXXX.ptx Se generaron archivos PTX.

También puedes volcar el grafo si visualizas la incorporación de clústeres de XLA dentro del grafo de TensorFlow con lo siguiente:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

Informes de errores reproducibles

Un informe de errores es mucho más fácil de reproducir si incluye volcados para los programas de XLA generados y la incorporación utilizada para el agrupamiento en clústeres automático. Con el objetivo de generarlos para un programa de TensorFlow que se ejecute con agrupamiento en clústeres automático, inicia lo siguiente:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

Cuando registres errores, adjunta el contenido del directorio /tmp/generated (al que se hace referencia más arriba).

Si es posible, intenta aislar un error en un solo programa de XLA con run_hlo_module y ejecútalo de forma iterativa en programas generados.

Lecturas adicionales

Frontends de XLA

Además de hacerlo con TensorFlow, los programas XLA se pueden generar con los siguientes elementos:

  • JAX: Transformaciones componibles de programas Python+NumPy
  • Julia: El lenguaje Julia para el procesamiento científico
  • PyTorch: framework de PyTorch
  • Nx: Biblioteca de procesamiento numérico para el lenguaje de programación Elixir

Charlas

Usa XLA desde TF mediante jit_compile=True

Descripción general de XLA