메모리 부족 (OOM) 오류는 액셀러레이터 (GPU 또는 TPU)의 고대역폭 메모리 (HBM) 용량이 소진될 때 발생합니다. OOM 문제의 일반적인 원인과 디버깅 기법은 E1000 - 컴파일 시간 HBM OOM 문서 및 GPU 메모리 할당에 관한 JAX 문서에 자세히 설명되어 있습니다.
이 페이지에서는 XProf의 메모리 뷰어 도구를 사용하여 JAX 프로그램의 메모리 사용량을 시각화하고, 최고 사용량 인스턴스를 식별하고, OOM 오류를 디버깅하는 방법을 설명합니다. 이 작업은 다음과 같은 단계로 진행됩니다.
jax.profiler.trace로 프로그램을 실행하여 프로필을 캡처합니다.- 백그라운드에서 XProf를 시작하고 메모리 뷰어 도구를 사용하여 메모리 사용량 세부정보를 확인합니다.
예제 프로그램
다음 JAX 프로그램은 OOM 오류를 발생시킵니다.
import jax
from jax import random
import jax.numpy as jnp
@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
return a @ a
if __name__ == "__main__":
oom()
TPU 머신에서 이 프로그램은 다음과 같이 실패합니다.
XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>
(GPU 머신에서는 오류가 XlaRuntimeError: RESOURCE_EXHAUSTED:
Out of memory while trying to allocate 214748364800 bytes.와 같이 표시됩니다.)
XProf 실행
xprof (pip install xprof)를 설치하고 프로필이 저장된 디렉터리를 지정하여 XProf 인스턴스를 시작합니다.
xprof --logdir=/tmp/xprof/ --port=6006
인스턴스 (로컬 머신, http://localhost:6006)로 이동합니다. 도구 드롭다운에서 메모리 뷰어를 선택하고 메모리 뷰어 도구 창의 메모리 유형 드롭다운에서 HBM을 선택합니다 (일반적으로 기본적으로 선택됨).

XProf: 메모리 뷰어 도구 문서에서는 도구의 구성요소와 표시되는 정보를 설명합니다.
최고 메모리 사용량 지점에서 세 개의 버퍼 차트를 보여주는 최고 메모리 할당 시 HLO 작업 섹션에 집중합니다. 버퍼에는 다음이 포함됩니다. * 프로그램 입력 및 출력: 학습 배치, 옵티마이저 상태 등 * TensorCore 및 SparseCore 임시: 중간 계산 (예: 활성화, 기울기 등)에 필요한 동적 메모리
버퍼 차트 위로 마우스를 가져가면 크기, 모양, 할당 유형 등 작업에 관한 자세한 내용을 확인할 수 있습니다. 이를 통해 일시적인 요소가 많거나 오래 지속되는 작업, 비효율적인 패딩이 있는 대규모 입력/중간/출력 텐서 등 최대 메모리에 영향을 미치고 조정하거나 최적화해야 하는 요소를 식별하고 평가할 수 있습니다.
E1000: 디버깅에서 구체적인 디버깅 기법을 알아보세요.