Eseguire il debug degli errori OOM con XProf

Gli errori di esaurimento della memoria (OOM) si verificano quando la capacità della memoria ad alta larghezza di banda (HBM) dell'acceleratore (GPU o TPU) è esaurita. Alcune cause comuni dei problemi di esaurimento della memoria e le tecniche di debug sono descritte in dettaglio nella documentazione E1000 - Compile Time HBM OOM e nella documentazione JAX sull'allocazione della memoria della GPU.

Questa pagina descrive come utilizzare lo strumento Memory Viewer di XProf per visualizzare l'utilizzo della memoria del tuo programma JAX, identificare le istanze di picco di utilizzo ed eseguire il debug degli errori di esaurimento della memoria. che include i seguenti passaggi:

  1. Esegui il programma con jax.profiler.trace per acquisire il profilo.
  2. Avvia XProf in background e utilizza lo strumento Visualizzatore memoria per visualizzare i dettagli sull'utilizzo della memoria.

Programma di esempio

Il seguente programma JAX genera un errore OOM:

import jax
from jax import random
import jax.numpy as jnp


@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
    a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
    return a @ a


if __name__ == "__main__":
    oom()

Su una macchina TPU, questo programma non viene eseguito e viene visualizzato il seguente messaggio:

XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>

Su una macchina GPU, l'errore ha il seguente aspetto: XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 214748364800 bytes.

Esegui XProf

Installa xprof (pip install xprof) e avvia un'istanza XProf specificando la directory in cui è memorizzato il profilo:

xprof --logdir=/tmp/xprof/ --port=6006

Vai all'istanza (su una macchina locale, in http://localhost:6006). Nel menu a discesa Strumenti, seleziona Visualizzatore memoria e, nella finestra dello strumento Visualizzatore memoria, seleziona HBM nel menu a discesa Tipi di memoria (in genere selezionato per impostazione predefinita).

Pagina XProf Memory Viewer per il programma di esempio precedente

La documentazione dello strumento XProf: Memory Viewer descrive i componenti dello strumento e le informazioni presentate.

Concentrati sulla sezione Operazioni HLO con allocazione massima della memoria, che mostra tre grafici dei buffer nel punto di utilizzo massimo della memoria. Il buffer include: * Input e output del programma: batch di addestramento, stati dell'ottimizzatore e così via. * Variabili temporanee TensorCore e SparseCore: memoria dinamica necessaria per i calcoli intermedi (come attivazioni, gradienti e così via).

Puoi passare il mouse sopra i grafici buffer per visualizzare ulteriori dettagli sull'operazione, ad esempio dimensioni, forma, tipo di allocazione e altro ancora. Questo può aiutarti a identificare e valutare le operazioni che potrebbero avere temporanei elevati o di lunga durata, eventuali tensori di input/intermedi/output con padding inefficiente e così via, che contribuiscono al picco di memoria e devono essere modificati o ottimizzati.

Scopri tecniche di debug specifiche in E1000: Debugging.