Os erros de falta de memória (OOM, na sigla em inglês) ocorrem quando a capacidade de memória de alta largura de banda (HBM) do acelerador (GPU ou TPU) se esgota. Algumas causas comuns de problemas de falta de memória (OOM) e técnicas de depuração são detalhadas em E1000: documentação sobre OOM de HBM no tempo de compilação e documentação do JAX sobre alocação de memória da GPU.
Esta página descreve como usar a ferramenta Memory Viewer do XProf para visualizar o uso de memória do seu programa JAX, identificar instâncias de uso máximo e depurar erros de falta de memória. Isso envolve as seguintes etapas:
- Execute o programa com
jax.profiler.tracepara capturar o perfil. - Inicie o XProf em segundo plano e use a ferramenta Memory Viewer para conferir detalhes de utilização da memória.
Programa de exemplo
O programa JAX a seguir causa um erro de falta de memória:
import jax
from jax import random
import jax.numpy as jnp
@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
return a @ a
if __name__ == "__main__":
oom()
Em uma máquina com TPU, esse programa falha com:
XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>
Em uma máquina com GPU, o erro aparece assim: XlaRuntimeError: RESOURCE_EXHAUSTED:
Out of memory while trying to allocate 214748364800 bytes.
Executar o XProf
Instale o xprof (pip install xprof) e inicie uma instância do XProf especificando
o diretório em que o perfil está armazenado:
xprof --logdir=/tmp/xprof/ --port=6006
Acesse a instância (em uma máquina local, em http://localhost:6006). No menu suspenso Ferramentas, selecione Visualizador de memória e, na janela de ferramentas do visualizador de memória, selecione HBM no menu suspenso Tipos de memória (geralmente selecionado por padrão).

A documentação da ferramenta XProf: Memory Viewer descreve os componentes da ferramenta e as informações apresentadas.
Concentre-se na seção Operações de HLO no pico de alocação de memória, que mostra três gráficos de buffer no ponto de pico de uso da memória. O buffer inclui: * Entradas e saídas do programa: lotes de treinamento, estados do otimizador etc. * Variáveis temporárias do TensorCore e do SparseCore: memória dinâmica necessária para cálculos intermediários (como ativações, gradientes etc.)
Passe o cursor sobre os gráficos de buffer para mais detalhes sobre a operação, como tamanho, forma, tipo de alocação e muito mais. Isso pode ajudar a identificar e avaliar operações que podem ter temporários altos ou de longa duração, tensores de entrada/intermediários/saída grandes com padding ineficiente etc., que estão contribuindo para o pico de memória e precisam ser ajustados ou otimizados.
Aprenda técnicas específicas de depuração em E1000: Debugging.