오류 코드: E1000

카테고리: 컴파일 시간: HBM OOM

이 오류는 프로그램에 TPU 기기에서 실제로 사용할 수 있는 것보다 더 많은 고대역폭 메모리 (HBM)가 필요함을 나타냅니다.

샘플 오류 메시지:

RESOURCE_EXHAUSTED: TPU TensorCore Hbm usage: 34.82G, SparseCore Hbm usage 174.10G, exceeding available bytes: 95.74G
RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 49.34G of 32.00G hbm. Exceeded hbm capacity by 17.34G.

XLA 백엔드: TPU

개요

XLA는 필요한 모든 정적 할당의 집계 크기가 기기의 HBM에 맞는지 확인합니다.

컴파일러는 여러 유형의 할당에 대해 TPU의 고정 HBM 용량을 관리합니다.

  • 프로그램 입력 및 출력: 학습 배치, 옵티마이저 상태 등
  • TPU 임시 메모리: 중간 계산(예: 활성화, 기울기 등)에 필요한 동적 메모리입니다.
  • 컴파일된 바이너리: TensorCore (TC)와 SparseCore (SC)의 기계어 코드입니다.
  • 시스템 오버헤드: XLA 런타임용으로 예약된 공간입니다 (예: 이전 TPU 세대의 인피드 버퍼).
  • 상수: HLO IR에 삽입된 상수 값은 HBM에 할당됩니다.
  • 컴파일러 내부: 프로그램 수준 및 HLO별 할당 (예: 메시의 노드 라우팅 정보)

이 오류는 XLA 컴파일러가 위의 할당을 모두 기기 HBM에 맞출 수 없는 경우에 발생합니다.

디버깅

오류 메시지와 로그를 주의 깊게 분석하여 아래 HBM OOM 카테고리 중 오류를 가장 잘 설명하는 카테고리를 확인합니다.


시나리오 1. TC 및 SC HBM 사용량 균형 조정

오류에서 사용량을 명시적으로 분류하는 경우(예: 'TC Hbm 사용량: X, SC Hbm 사용량 Y') 집계된 TensorCore(TC) + SparseCore(SC) 사용량이 HBM 한도를 초과한 것입니다. 두 값을 비교하여 병목 현상을 파악합니다.

  • SparseCore 사용량이 많음
    • HBM 스택 사용량 최적화: HBM 스택 메모리 소비는 feature_width, max_unique_nz_per_row, logical_replica_count에 따라 확장됩니다. 테이블 처리를 직렬화하는 --xla_sc_num_serialized_tables_to_optimize_hbm 플래그를 조정하여 최대 스택 사용량을 줄일 수 있습니다. 이로 인해 병렬 처리가 줄어듭니다.
    • 패딩 오버헤드 확인: SparseCore는 삽입 테이블을 32B (8개의 부동 소수점)에 정렬합니다. 특성 너비가 작은 테이블 (예: 부동 소수점 8개 미만)은 패딩 오버헤드가 상당하여 HBM이 낭비됩니다.
    • 힙 사용량 줄이기: maximum_parallel_iterations 값이 높으면 HBM 힙으로 미리 가져오는 입력 데이터의 양이 늘어납니다. 이 값을 낮추면 상당한 메모리를 확보할 수 있습니다.
    • 샤딩 확인: 삽입 테이블이 모든 칩에서 올바르게 mod-샤딩되었는지 확인합니다. 한도가 테이블에 어떻게 적용되는지를 참고하세요.
    • 자세한 내용은 SC: 성능 및 메모리 병목 현상을 참고하세요.
  • 높은 TensorCore 사용량
  • 균형
    • 개별적으로 과도하지는 않지만 합계가 너무 높으면 칩의 용량에 도달한 것입니다. 두 구성요소의 사용량을 모두 줄여야 합니다. 세 섹션의 권장사항을 모두 따릅니다.

상황 2. 예상치 못한 대규모 할당으로 인한 메모리 부족

'메모리 공간 HBM의 메모리가 부족합니다'라는 오류 메시지가 표시되고 하나 이상의 예기치 않게 큰 할당이 로그에 표시되면(HBM 한도의 50% 초과) 하드웨어 용량 문제가 거의 아닙니다. 일반적으로 구성 오류입니다. JAX 소스 코드에 관한 힌트를 보려면 대규모 할당의 XLA 라벨 (있는 경우)을 확인하세요.

  • 디버깅 아티팩트 삭제
    • 대규모 실행에서 jax.debug.print()를 사용하면 컴파일러가 HBM에서 전체 텐서를 구체화하여 CPU로 전송하도록 강제하여 융합이 중단되고 최대 메모리 사용량이 증가할 수 있습니다. 남은 jax.debug.print()를 삭제합니다.
  • 비효율적인 메시 모양 또는 샤딩 수정하기
    • 잘못된 메시 모양이나 누락된 샤딩 주석으로 인해 컴파일러가 기본적으로 복제를 사용하게 되어 컴파일러가 단일 칩에 매우 큰 텐서를 맞추려고 시도할 수 있습니다.
    • 큰 할당의 모양을 확인하고 샤딩이 XLA에 의해 올바르게 지정되고 전파되는지 확인합니다.

상황 3. 집계된 할당으로 인해 메모리 부족

'메모리 공간 HBM에서 메모리가 부족함'이라는 오류 메시지가 표시되고 로그에 예기치 않게 큰 텐서가 없는 경우 할당의 합계가 HBM 한도를 초과하여 프로그램의 용량이 부족한 것입니다. 이 경우 메모리 프로필을 시각화하여 최대 사용량에 기여하는 특정 버퍼를 식별하는 것이 유용한 경우가 많습니다. 최고 메모리 기여자를 식별하는 단계별 가이드는 XProf로 OOM 오류 디버그를 참고하세요.

상위 기여자를 파악한 후 다음 단계에 따라 메모리 공간을 최적화합니다.

시나리오 3.A 구성 조정

다음과 같은 구성 조정으로 OOM을 해결할 수 있는 경우가 많습니다.

  • 배치 크기 줄이기: 중간 활성화 및 그라데이션에 필요한 메모리는 배치 크기에 직접 비례합니다. 배치 크기를 줄이면 메모리 사용량을 줄일 수 있습니다.
  • 입력 버퍼 기부: jax.jit를 사용할 때는 모델 매개변수에 donate_argnums를 지정하세요. 이를 통해 XLA는 입력 메모리를 출력으로 덮어쓸 수 있습니다.
  • 혼합 정밀도 (bfloat16) 사용 설정: 모델 아키텍처와 품질 요구사항이 허용하는 경우 프로그램에서 가장 큰 텐서에 bfloat16 또는 양자화 (int8 등)를 사용합니다. 이 변경사항은 모델 동작에 영향을 미칠 수 있으므로 신중하게 고려해야 합니다.

시나리오 3.B 아키텍처 및 샤딩 최적화

구성 변경이 충분하지 않으면 모델 토폴로지가 현재 하드웨어 설정에 비해 너무 클 수 있습니다.

  • 최신 TPU 세대 사용: 최신 TPU는 일반적으로 칩당 더 많은 HBM을 제공합니다. 사용 가능한 경우 최신 TPU 세대로 전환하세요.
  • 더 큰 칩 토폴로지에서 실행: 모델 가중치가 기존 토폴로지에 비해 너무 큰 경우 더 많은 칩에 걸쳐 샤딩해 볼 수 있습니다.
  • 고급 샤딩 기법 구현:
    • 고급 데이터, 텐서 또는 파이프라인 병렬 처리 접근 방식을 살펴봅니다.
    • 중간 값과 출력에 샤딩 힌트를 지정합니다.
  • JAX 호스트 오프로딩 사용: 호스트 오프로딩 기법을 사용하면 사용자가 대형 텐서를 호스트 CPU 메모리로 오프로드할 수 있습니다 (예: 활성화 오프로딩옵티마이저 상태 오프로딩).

시나리오 3.C 텐서 패딩 및 정렬 확인

비효율적인 텐서 형태는 TPU에서 OOM이 발생하는 일반적인 원인이지만 눈에 띄지 않습니다. TPU에서 최고 성능을 얻기 위해 XLA는 텐서 차원을 패딩합니다. 일반적으로 가장 작은 차원의 경우 128의 배수, 두 번째로 작은 차원의 경우 8의 배수로 패딩합니다. 이 패딩은 입력 배열과 중간 텐서 (HLO 임시)에 모두 영향을 미치므로 특히 작은 차원 크기의 경우 메모리 사용량이 크게 증가할 수 있습니다. 배열 레이아웃을 참고하세요.

  • 대형 버퍼의 모양 감사: (기본 레이아웃이 적용된 TPU v5)
    • Xprof 메모리 뷰어에서 버퍼 위로 마우스를 가져가면 패딩 정보를 비롯한 버퍼 세부정보가 포함된 버퍼 세부정보 카드가 표시됩니다.
    • : (129, 1024) 모양이 (256, 1024)로 패딩되어 메모리 낭비가 거의 50% 에 달할 수 있습니다.
    • 수정: (128, 1024) 모양에는 패딩이 필요하지 않으며 메모리 낭비가 0% 발생합니다.
  • 차원 정렬: 모든 대형 텐서 차원 (배치 크기, 임베딩 차원, 숨겨진 크기)이 128의 배수인지 확인합니다. 이 변경사항은 모델 동작에 영향을 미칠 수 있으므로 신중하게 고려해야 합니다.

시나리오 3.D XLA에 영향을 미치는 주요 메모리 플래그 조정

주요 메모리 플래그를 조정하여 성능과 메모리 사용량 간의 균형을 맞출 수 있습니다. 하지만 이 전략은 성능에 부정적인 영향을 미칠 수 있으므로 최후의 수단으로 사용해야 합니다.

시나리오 3.E XLA 재구체화 패스/수동 체크포인트 조정

모델이 메모리에 거의 맞지 않는 경우 jax.grad와 함께 jax.checkpoint 데코레이터를 사용하여 순방향 패스에서 저장되는 중간값과 역방향 패스에서 다시 계산되는 중간값을 수동으로 제어하여 HBM에 대한 컴퓨팅 주기를 교환할 수 있습니다.

또는 XLA::Rematerialization 패스를 강제 실행하여 메모리 절약을 우선시할 수 있습니다. 이 경우 컴파일 속도가 느려질 수 있습니다.

플래그 설명 영향 / 절충
--xla_tpu_max_hbm_size_mib 리매터리얼라이제이션 패스에서 사용하는 HBM 크기의 한도를 수동으로 설정합니다. 컴파일러가 실제 물리적 HBM보다 작은 제한에 프로그램을 맞추기 위해 더 열심히 작업하도록 강제합니다.
--xla_tpu_rematerialization_algo=PEAK_PRIORITY 최고 메모리 사용량 지점에 집중합니다. 기본 알고리즘보다 적극적인 메모리 감소에 더 효율적일 수 있습니다.
--xla_tpu_rematerialization_max_block_size_limit=32 한 번에 다시 구체화할 수 있는 블록의 최대 명령어 수를 제어합니다. 이 값을 늘리면 컴파일 시간이 크게 증가하는 대신 메모리를 절약할 수 있습니다.
--xla_tpu_rematerialization_block_effort_factor=10.0 다시 구체화할 블록을 검색하는 데 소요되는 노력 (컴파일 시간)의 양을 정의합니다. 값이 높을수록 컴파일 시간이 늘어나는 대신 메모리 절약을 위한 더 철저한 검색이 가능합니다.
--xla_tpu_pre_fusion_remat=true 융합 패스 에 추가 리매터리얼라이제이션 패스를 사용 설정합니다. 메모리 절약 효과가 더 클 수 있지만 컴파일 시간이 늘어나고 수치 안정성에 영향을 줄 수 있습니다.

XLA 플래그를 변경하면 성능에 부정적인 영향을 미칠 수 있으므로 최후의 수단으로 사용해야 합니다.

시나리오 3.F 고급 프로파일링 도구 사용

XProf로 OOM 오류 디버그에서는 XProf 메모리 뷰어를 사용하여 HBM 사용에 대한 컴파일러의 뷰를 시각화하는 방법을 설명하는 튜토리얼을 제공합니다.

이 도구를 사용하면 최고 활용 시점에 HBM을 소비하는 항목을 정확히 파악하는 데 중요한 최고 메모리 할당 및 버퍼 수명을 확인할 수 있습니다. 일반 프로파일링 설정은 Xprof 시작하기TensorBoard 프로파일링을 참고하세요.