XLA: ottimizzare il compilatore per il machine learning

OpenXLA è un compilatore specifico del dominio per l'algebra lineare in grado di accelerare i modelli TensorFlow senza apportare modifiche al codice sorgente.

Introduzione

Quando viene eseguito un programma TensorFlow, tutte le operazioni vengono eseguite singolarmente dall'esecutore di TensorFlow. Ogni operazione TensorFlow ha un'implementazione del kernel GPU precompilata che viene inviata dall'esecutore.

XLA fornisce una modalità di esecuzione alternativa: compila il grafico TensorFlow in una sequenza di kernel di calcolo generati specificamente per il modello in questione. Poiché questi kernel sono univoci per il modello, possono sfruttare le informazioni specifiche del modello per l'ottimizzazione. Ad esempio, diamo un'occhiata a un'ottimizzazione XLA nel contesto di un semplice calcolo con TensorFlow:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

Eseguito senza XLA, il grafico lancia tre kernel: uno per la moltiplicazione, uno per l'addizione e uno per la riduzione. Tuttavia, XLA può ottimizzare il grafico in modo da calcolare il risultato in un lancio di un singolo kernel. Lo fa "fondendo" l'aggiunta, la moltiplicazione e la riduzione in un unico kernel GPU. Inoltre, questa operazione combinata non scrive nella memoria i valori intermedi prodotti da y*z e x+y*z, ma trasmette i risultati di questi calcoli intermedi direttamente agli utenti, mantenendoli interamente nei registri GPU. La fusione è l'ottimizzazione più importante di XLA. La larghezza di banda della memoria è in genere la risorsa più scarsa per gli acceleratori hardware, quindi la rimozione delle operazioni di memoria è uno dei modi migliori per migliorare le prestazioni.

Abilita XLA per i modelli TensorFlow

Compilation esplicita con tf.function(jit_compile=True)

L'API di compilazione esplicita offre un controllo granulare per la scelta delle funzioni da compilare. Ad esempio, la seguente funzione TensorFlow che esegue l'addestramento MNIST è compilata con XLA:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

L'API jit_compile ha la semantica need-compile: l'intera funzione viene compilata con XLA oppure viene generata un'eccezione errors.InvalidArgumentError. Al momento, XLA non può compilare funzioni in cui le dimensioni non sono inferrabili, ovvero se non è possibile dedurre le dimensioni di tutti i tensori senza eseguire l'intero calcolo. Ad esempio, la seguente funzione non compilerà:

@tf.function
def not_compilable(x):
  return tf.unique(x)

Tuttavia, le forme possono variare nelle esecuzioni:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

Guarda il tutorial di Colab per un esempio di utilizzo più dettagliato e un video tutorial sull'utilizzo di jit_compile=True.

Utilizzo con Keras

Per i modelli Keras, jit_compile=True può essere impostato come argomento a model.compile:

model.compile(optimizer="adam", jit_compile=True)

Utilizzo con strategia distribuita

XLA:GPU può essere utilizzato con una strategia distribuita TF (MirroredStrategy o MultiWorkerMirroredStrategy) annotando la funzione di passaggio con jit_compile=True:

@tf.function(jit_compile=True)
def step_fn():
  t = tf.ones(shape=[100], dtype=tf.float32)
  ctx = tf.distribute.get_replica_context()
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, t)

@tf.function
def run_fn():
  return strategy.run(step_fn)

Clustering automatico

Un modo semplice per iniziare a utilizzare XLA nei modelli TensorFlow senza apportare modifiche consiste nell'abilitare il clustering automatico, che trova automaticamente i cluster (sottografi collegati) all'interno delle funzioni TensorFlow che possono essere compilate ed eseguite utilizzando XLA. Il clustering automatico sulla GPU può essere attivato impostando la variabile di ambiente TF_XLA_FLAGS:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

Il clustering automatico è attualmente ottimizzato per i carichi di lavoro delle GPU, ma può essere abilitato anche sulla CPU utilizzando il flag --tf_xla_cpu_global_jit:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

Per un esempio di utilizzo dettagliato, consulta il tutorial sul clustering automatico di clustering.

Compilazione AOT (Ahead-of-time) per CPU con tfcompile

Puoi anche utilizzare uno strumento tfcompile autonomo, che converte il grafico TensorFlow in codice eseguibile (solo per CPU x86-64).

Ispeziona i programmi compilati

XLA fornisce servizi di introspezione che consentono di esaminare i programmi generati. Per eseguire il dump dei programmi generati, utilizza la variabile di ambiente XLA_FLAGS:

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

Dopo l'esecuzione del dumping, puoi trovare i seguenti file in /tmp/generated:

  • module_XXXX.*_optimizations.txt Programmi XLA generati, uno per ogni cluster compilato. Allegare queste foto quando si inviano segnalazioni di bug XLA è estremamente utile.

  • module_XXXX.ir-*.ll File generati nella rappresentazione intermedia LLVM, con le intrinseche NVPTX.

  • module_XXXX.ptx File PTX generati.

Puoi anche eseguire il dump del grafico che visualizza l'incorporamento dei cluster XLA all'interno del grafico TensorFlow con:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

Segnalazioni di bug riproducibili

Una segnalazione di bug è molto più facile da riprodurre se include dump per i programmi XLA generati e l'incorporamento di clustering automatico utilizzato. Per generarli per un programma TensorFlow in esecuzione con il clustering automatico, avvia:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

Quando invii bug, allega i contenuti della directory /tmp/generated (a cui si fa riferimento sopra).

Se possibile, prova a isolare un bug a un singolo programma XLA utilizzando run_hlo_module ed eseguendolo in modo iterativo sui programmi generati.

Per approfondire

Frontend XLA

A parte TensorFlow, i programmi XLA possono essere generati:

  • JAX: trasformazioni componibili di programmi Python + NumPy
  • Julia: la lingua Julia per il calcolo scientifico
  • PyTorch: framework PyTorch
  • Nx: libreria di calcolo numerico per il linguaggio di programmazione Elixir

Discorsi

Con XLA di TF usando jit_compile=True

Panoramica XLA