Los errores de memoria insuficiente (OOM) ocurren cuando se agota la capacidad de memoria de alto ancho de banda (HBM) del acelerador (GPU o TPU). En la documentación de E1000 sobre el error de OOM de HBM en tiempo de compilación y la documentación de JAX sobre la asignación de memoria de GPU, se detallan algunas causas comunes de los problemas de OOM y las técnicas de depuración.
En esta página, se describe cómo usar la herramienta Memory Viewer de XProf para visualizar el uso de memoria de tu programa en JAX, identificar instancias de uso máximo y depurar errores de OOM. Esto implica los siguientes pasos:
- Ejecuta tu programa con
jax.profiler.tracepara capturar el perfil. - Inicia XProf en segundo plano y usa la herramienta Memory Viewer para ver los detalles del uso de la memoria.
Programa de ejemplo
El siguiente programa de JAX genera un error de OOM:
import jax
from jax import random
import jax.numpy as jnp
@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
return a @ a
if __name__ == "__main__":
oom()
En una máquina TPU, este programa falla con el siguiente mensaje:
XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>
(En una máquina con GPU, el error se ve así: XlaRuntimeError: RESOURCE_EXHAUSTED:
Out of memory while trying to allocate 214748364800 bytes.)
Ejecuta XProf
Instala xprof (pip install xprof) y, luego, inicia una instancia de XProf especificando el directorio en el que se almacena el perfil:
xprof --logdir=/tmp/xprof/ --port=6006
Ve a la instancia (en una máquina local, en http://localhost:6006). En el menú desplegable Herramientas, selecciona Visor de memoria y, en la ventana de herramientas Visor de memoria, selecciona HBM en el menú desplegable Tipos de memoria (por lo general, se selecciona de forma predeterminada).

En la documentación de la herramienta XProf: Memory Viewer, se describen los componentes de la herramienta y la información que se presenta.
Enfócate en la sección Operaciones de HLO en la asignación máxima de memoria, que muestra tres gráficos de búfer en el punto de uso máximo de memoria. El búfer incluye lo siguiente: * Entradas y salidas del programa: lotes de entrenamiento, estados del optimizador, etcétera * Variables temporales de TensorCore y SparseCore: Memoria dinámica necesaria para los cálculos intermedios (como activaciones, gradientes, etcétera)
Puedes colocar el cursor sobre los gráficos de búfer para obtener más detalles sobre la operación, como su tamaño, forma, tipo de asignación y mucho más. Esto puede ayudarte a identificar y evaluar las operaciones que pueden tener temporales altos o duraderos, cualquier tensor de entrada/intermedio/salida grande que tenga un padding ineficiente, etc., que contribuyan a la memoria máxima y deban ajustarse u optimizarse.
Aprende técnicas de depuración específicas en E1000: Debugging.