XProf로 OOM 오류 디버그

메모리 부족 (OOM) 오류는 액셀러레이터 (GPU 또는 TPU)의 고대역폭 메모리 (HBM) 용량이 소진될 때 발생합니다. OOM 문제의 일반적인 원인과 디버깅 기법은 E1000 - 컴파일 시간 HBM OOM 문서GPU 메모리 할당에 관한 JAX 문서에 자세히 설명되어 있습니다.

이 페이지에서는 XProf의 메모리 뷰어 도구를 사용하여 JAX 프로그램의 메모리 사용량을 시각화하고, 최고 사용량 인스턴스를 식별하고, OOM 오류를 디버깅하는 방법을 설명합니다. 이 작업은 다음과 같은 단계로 진행됩니다.

  1. jax.profiler.trace로 프로그램을 실행하여 프로필을 캡처합니다.
  2. 백그라운드에서 XProf를 시작하고 메모리 뷰어 도구를 사용하여 메모리 사용량 세부정보를 확인합니다.

예제 프로그램

다음 JAX 프로그램은 OOM 오류를 발생시킵니다.

import jax
from jax import random
import jax.numpy as jnp


@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
    a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
    return a @ a


if __name__ == "__main__":
    oom()

TPU 머신에서 이 프로그램은 다음과 같이 실패합니다.

XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>

(GPU 머신에서는 오류가 XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 214748364800 bytes.와 같이 표시됩니다.)

XProf 실행

xprof (pip install xprof)를 설치하고 프로필이 저장된 디렉터리를 지정하여 XProf 인스턴스를 시작합니다.

xprof --logdir=/tmp/xprof/ --port=6006

인스턴스 (로컬 머신, http://localhost:6006)로 이동합니다. 도구 드롭다운에서 메모리 뷰어를 선택하고 메모리 뷰어 도구 창의 메모리 유형 드롭다운에서 HBM을 선택합니다 (일반적으로 기본적으로 선택됨).

위 예시 프로그램의 XProf 메모리 뷰어 페이지

XProf: 메모리 뷰어 도구 문서에서는 도구의 구성요소와 표시되는 정보를 설명합니다.

최고 메모리 사용량 지점에서 세 개의 버퍼 차트를 보여주는 최고 메모리 할당 시 HLO 작업 섹션에 집중합니다. 버퍼에는 다음이 포함됩니다. * 프로그램 입력 및 출력: 학습 배치, 옵티마이저 상태 등 * TensorCore 및 SparseCore 임시: 중간 계산 (예: 활성화, 기울기 등)에 필요한 동적 메모리

버퍼 차트 위로 마우스를 가져가면 크기, 모양, 할당 유형 등 작업에 관한 자세한 내용을 확인할 수 있습니다. 이를 통해 일시적인 요소가 많거나 오래 지속되는 작업, 비효율적인 패딩이 있는 대규모 입력/중간/출력 텐서 등 최대 메모리에 영향을 미치고 조정하거나 최적화해야 하는 요소를 식별하고 평가할 수 있습니다.

E1000: 디버깅에서 구체적인 디버깅 기법을 알아보세요.