Categoria:tempo de compilação: HBM OOM
Esse erro indica que o programa exige mais memória de alta largura de banda (HBM) do que está fisicamente disponível no dispositivo TPU.
Exemplo de mensagens de erro:
RESOURCE_EXHAUSTED: TPU TensorCore Hbm usage: 34.82G, SparseCore Hbm usage 174.10G, exceeding available bytes: 95.74G
RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 49.34G of 32.00G hbm. Exceeded hbm capacity by 17.34G.
Back-ends do XLA:TPU
Visão geral
O XLA realiza verificações para garantir que o tamanho agregado de todas as alocações estáticas necessárias caiba na HBM do dispositivo.
O compilador gerencia a capacidade fixa da HBM da TPU para vários tipos de alocações:
- Entradas e saídas do programa:lotes de treinamento, estados do otimizador etc.
- Temporários da TPU:memória dinâmica necessária para cálculos intermediários (por exemplo, ativações, gradientes etc.).
- Binário compilado:o código de máquina para TensorCore (TC) e SparseCore (SC).
- Sobrecarga do sistema:espaço reservado para o ambiente de execução do XLA (por exemplo, buffers de entrada em gerações mais antigas de TPU).
- Constantes:valores constantes incorporados no IR do HLO são alocados na HBM.
- Internos do compilador:alocações de nível de programa e por HLO (por exemplo, informações de roteamento para nós na malha).
Esse erro ocorre quando o compilador XLA não consegue ajustar todas as alocações acima na HBM do dispositivo.
Depuração
Analise cuidadosamente a mensagem de erro e os registros para determinar qual categoria de HBM OOM abaixo descreve melhor seu erro:
- "TC Hbm usage: X, SC Hbm usage Y": se o erro detalha explicitamente o uso da HBM, o uso agregado de TensorCore (TC) + SparseCore (SC) excede o limite da HBM. → Acesse o cenário 1. Equilibre o uso da HBM de TC e SC.
- "Ran out of memory in memory space HBM": verifique os registros para uma
enumeração das maiores alocações na HBM.
- Caso um ou mais tensores inesperadamente grandes (por exemplo, > 50% do limite da HBM) estejam presentes → Acesse o cenário 2. Memória insuficiente devido a alocações inesperadamente grandes.
- Se nenhum tensor inesperadamente grande estiver presente nos registros → Acesse o cenário 3. Memória insuficiente devido a alocações agregadas.
Cenário 1. Equilibre o uso da HBM de TC e SC
Se o erro detalha explicitamente o uso, por exemplo, "TC Hbm usage: X, SC Hbm usage Y", isso significa que o uso agregado de TensorCore (TC) + SparseCore (SC) excede o limite da HBM. Compare os dois valores para identificar o gargalo:
- Alto uso de SparseCore
- Otimize o uso da pilha de HBM:o consumo de memória da pilha de HBM é dimensionado com
feature_width,max_unique_nz_per_rowelogical_replica_count. É possível reduzir o uso máximo da pilha ajustando a flag--xla_sc_num_serialized_tables_to_optimize_hbm, que serializa o processamento de tabelas. Isso tem o custo de paralelismo reduzido. - Verifique a sobrecarga de preenchimento:o SparseCore alinha as tabelas de incorporação a 32B (8 flutuantes). Tabelas com larguras de recursos pequenas (por exemplo, < 8 flutuantes) incorrem em uma sobrecarga de preenchimento significativa, desperdiçando HBM.
- Reduza o uso de heap:valores altos para
maximum_parallel_iterationsaumentam a quantidade de dados de entrada pré-buscados no heap da HBM. Diminuir esse valor pode liberar uma quantidade significativa de memória. - Verifique o particionamento:garanta que as tabelas de incorporação sejam particionadas corretamente em todos os chips. Consulte Como os limites são traduzidos em tabelas.
- Confira SC: gargalos de desempenho e memória para mais ideias.
- Otimize o uso da pilha de HBM:o consumo de memória da pilha de HBM é dimensionado com
- Alto uso de TensorCore
- Acesse o cenário 2.
- Equilibrado
- Se nenhum dos dois for individualmente excessivo, mas a soma for muito alta, você estará na capacidade do chip. Tente diminuir o uso dos dois componentes. Siga as recomendações nas três seções.
Cenário 2. Memória insuficiente devido a alocações inesperadamente grandes
Se você observar a mensagem de erro "Ran out of memory in memory space HBM" e uma ou mais alocações inesperadamente grandes estiverem presentes nos registros (> 50% do limite da HBM ), quase nunca será um problema de capacidade de hardware. Normalmente, é um erro de configuração. Verifique o rótulo XLA (se presente) das grandes alocações para dicas sobre o código-fonte JAX.
- Remover artefatos de depuração
- O uso de
jax.debug.print()
em execuções em grande escala pode forçar o compilador a materializar o
tensor completo na HBM para transferi-lo para a CPU, interrompendo a fusão e aumentando o
uso máximo da memória. Remova todos os
jax.debug.print()s restantes.
- O uso de
jax.debug.print()
em execuções em grande escala pode forçar o compilador a materializar o
tensor completo na HBM para transferi-lo para a CPU, interrompendo a fusão e aumentando o
uso máximo da memória. Remova todos os
- Corrigir formas de malha ou particionamento ineficientes
- Formas de malha incorretas ou anotações de particionamento ausentes podem fazer com que o compilador use a replicação como padrão, forçando o compilador a tentar ajustar tensores muito grandes em um único chip.
- Verifique as formas das grandes alocações e verifique se o particionamento está especificado e propagado corretamente pelo XLA.
Cenário 3. Memória insuficiente devido a alocações agregadas
Se você observar a mensagem de erro "Ran out of memory in memory space HBM" e nenhum tensor inesperadamente grande estiver presente nos registros, o programa ficará sem capacidade devido à soma agregada de alocações que excedem o limite da HBM. Nesse caso, geralmente é útil visualizar o perfil de memória para identificar os buffers específicos que contribuem para o uso máximo. Consulte Depurar erros de OOM com o XProf para um guia detalhado sobre como identificar os principais colaboradores de memória.
Depois de identificar alguns dos principais colaboradores, siga estas etapas para otimizar o consumo de memória.
Cenário 3.A Ajustar a configuração
Muitas vezes, é possível resolver OOMs com esses ajustes de configuração:
- Reduza o tamanho do lote:a memória necessária para ativações e gradientes intermediários é diretamente proporcional ao tamanho do lote. Reduzir o tamanho do lote pode ajudar a reduzir o uso da memória.
- Doe buffers de entrada:ao usar
jax.jit, especifique donate_argnums para os parâmetros do modelo. Isso permite que o XLA substitua a memória de entrada pela saída. - Ative a precisão mista (bfloat16) : use bfloat16 ou quantização (int8 etc.) para os maiores tensores do programa se a arquitetura do modelo e os requisitos de qualidade permitirem. Essa mudança pode afetar o comportamento do modelo e precisa ser considerada com cuidado.
Cenário 3.B Otimizar a arquitetura e o particionamento
Se as mudanças de configuração forem insuficientes, a topologia do modelo poderá ser muito grande para a configuração de hardware atual.
- Use gerações mais recentes de TPU:as TPUs mais recentes geralmente oferecem mais HBM por chip. Mude para gerações mais recentes de TPU, se disponíveis.
- Execute em uma topologia de chip maior:se os pesos do modelo forem muito grandes para a topologia atual, tente particioná-los em mais chips.
- Implemente técnicas avançadas de particionamento:
- Confira abordagens mais avançadas de paralelismo de dados, tensores ou pipelines.
- Especifique dicas de particionamento para valores e saídas intermediários.
- Use o descarregamento de host JAX: as técnicas de descarregamento de host permitem que o usuário descarregue tensores grandes para a memória da CPU do host (por exemplo, descarregamento de ativação e descarregamento de estado do otimizador).
Cenário 3.C Verificar o preenchimento e o alinhamento do tensor
Formas de tensor ineficientes são uma causa comum e silenciosa de OOMs em TPUs. Para ter o desempenho máximo nas TPUs, o XLA preenche as dimensões do tensor, normalmente para múltiplos de 128 para a dimensão menor e 8 para a segunda menor. Esse preenchimento afeta matrizes de entrada e tensores intermediários (temporários do HLO), podendo inflacionar significativamente o uso da memória, especialmente com tamanhos de dimensão pequenos. Consulte Layouts de matriz.
- Auditar formas de buffers grandes : (na TPU v5 com layouts padrão)
- Passar o cursor sobre um buffer no visualizador de memória do Xprof mostra a ficha de informações do buffer, que contém detalhes do buffer, incluindo informações de padding.
- Exemplo: uma forma de
(129, 1024)pode ser preenchida para(256, 1024), resultando em quase 50% de desperdício de memória. - Correção:uma forma de
(128, 1024)não requer preenchimento e incorre em 0% de desperdício de memória.
- Alinhar dimensões:garanta que todas as dimensões de tensor grandes (tamanho do lote, dimensão de incorporação, tamanho oculto) sejam múltiplos de 128. Essa mudança pode afetar o comportamento do modelo e precisa ser considerada com cuidado.
Cenário 3.D Ajustar flags de XLA que afetam a memória principal
As flags de memória principais podem ser ajustadas para compensar o desempenho por um uso da memória menor. No entanto, essa estratégia precisa ser usada como uma medida de último recurso, já que pode afetar negativamente o desempenho.
Cenário 3.E Ajustar a passagem de rematerialização do XLA/checkpoint manual
Se o modelo estiver quase cabendo na memória, você poderá usar o
jax.checkpoint
decorador com jax.grad para controlar manualmente quais intermediários são salvos na
passagem direta em comparação com a recompilação na passagem inversa, trocando ciclos de computação
por HBM.
Como alternativa, você pode forçar a passagem XLA::Rematerialization a priorizar a economia de memória, possivelmente ao custo de compilações mais lentas:
| Flag | Descrição | Impacto / compensação |
|---|---|---|
--xla_tpu_max_hbm_size_mib |
Define manualmente o limite de tamanho da HBM usado pela passagem de rematerialização. | Força o compilador a trabalhar mais para ajustar o programa a um limite menor que a HBM física real. |
--xla_tpu_rematerialization_algo=PEAK_PRIORITY |
Concentra os esforços nos pontos de uso da memória. | Pode ser mais eficiente para redução agressiva de memória do que o algoritmo padrão. |
--xla_tpu_rematerialization_max_block_size_limit=32 |
Controla o número máximo de instruções em um bloco que podem ser rematerializadas de uma só vez. | Aumentar isso permite economizar memória ao custo de aumentar significativamente o tempo de compilação. |
--xla_tpu_rematerialization_block_effort_factor=10.0 |
Define a quantidade de esforço (tempo de compilação) gasto na pesquisa de blocos a serem rematerializados. | Valores mais altos permitem uma pesquisa mais exaustiva de economia de memória ao custo de aumento do tempo de compilação. |
--xla_tpu_pre_fusion_remat=true |
Ativa uma passagem de rematerialização adicional antes da passagem de fusão. | Pode encontrar mais economia de memória, mas aumenta o tempo de compilação e pode afetar a estabilidade numérica. |
As mudanças nas flags do XLA precisam ser usadas como uma medida de último recurso, já que podem afetar negativamente o desempenho.
Cenário 3.F Usar ferramentas avançadas de criação de perfil
Depurar erros de OOM com o XProf oferece um tutorial sobre como usar o visualizador de memória do XProf para visualizar a visão do compilador sobre o uso da HBM.
Essa ferramenta permite que você veja a alocação máxima de memória e os ciclos de vida do buffer, o que é fundamental para entender exatamente o que consome HBM no ponto de utilização máxima. Para a configuração geral de criação de perfil, consulte Introdução à criação de perfil do Xprof e do TensorBoard.