Depurar erros de falta de memória com o XProf

Os erros de falta de memória (OOM, na sigla em inglês) ocorrem quando a capacidade de memória de alta largura de banda (HBM) do acelerador (GPU ou TPU) se esgota. Algumas causas comuns de problemas de falta de memória (OOM) e técnicas de depuração são detalhadas em E1000: documentação sobre OOM de HBM no tempo de compilação e documentação do JAX sobre alocação de memória da GPU.

Esta página descreve como usar a ferramenta Memory Viewer do XProf para visualizar o uso de memória do seu programa JAX, identificar instâncias de uso máximo e depurar erros de falta de memória. Isso envolve as seguintes etapas:

  1. Execute o programa com jax.profiler.trace para capturar o perfil.
  2. Inicie o XProf em segundo plano e use a ferramenta Memory Viewer para conferir detalhes de utilização da memória.

Programa de exemplo

O programa JAX a seguir causa um erro de falta de memória:

import jax
from jax import random
import jax.numpy as jnp


@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
    a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
    return a @ a


if __name__ == "__main__":
    oom()

Em uma máquina com TPU, esse programa falha com:

XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>

Em uma máquina com GPU, o erro aparece assim: XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 214748364800 bytes.

Executar o XProf

Instale o xprof (pip install xprof) e inicie uma instância do XProf especificando o diretório em que o perfil está armazenado:

xprof --logdir=/tmp/xprof/ --port=6006

Acesse a instância (em uma máquina local, em http://localhost:6006). No menu suspenso Ferramentas, selecione Visualizador de memória e, na janela de ferramentas do visualizador de memória, selecione HBM no menu suspenso Tipos de memória (geralmente selecionado por padrão).

Página do visualizador de memória do XProf para o programa de exemplo acima

A documentação da ferramenta XProf: Memory Viewer descreve os componentes da ferramenta e as informações apresentadas.

Concentre-se na seção Operações de HLO no pico de alocação de memória, que mostra três gráficos de buffer no ponto de pico de uso da memória. O buffer inclui: * Entradas e saídas do programa: lotes de treinamento, estados do otimizador etc. * Variáveis temporárias do TensorCore e do SparseCore: memória dinâmica necessária para cálculos intermediários (como ativações, gradientes etc.)

Passe o cursor sobre os gráficos de buffer para mais detalhes sobre a operação, como tamanho, forma, tipo de alocação e muito mais. Isso pode ajudar a identificar e avaliar operações que podem ter temporários altos ou de longa duração, tensores de entrada/intermediários/saída grandes com padding ineficiente etc., que estão contribuindo para o pico de memória e precisam ser ajustados ou otimizados.

Aprenda técnicas específicas de depuração em E1000: Debugging.