HLO 덤프는 계산의 다양한 단계에서 HLO 모듈을 텍스트로 표현한 것입니다. 디버깅에 유용하며 버그 신고에 포함해야 하는 경우가 많습니다. 일반적으로 HLO 명령어와 속성을 나열하는 사람이 읽을 수 있는 텍스트 파일입니다. HLO 모듈은 다음과 같이 덤프되기도 합니다.
- HloProto: 더 구조화된 기계가 읽을 수 있는 형식인 프로토콜 버퍼 파일입니다.
- HloSnapshot: HLO 모듈과 입력 HLO를 재생하려면 무작위 데이터가 아닌 특정 계산에 제공된 실제 입력이 필요한 경우가 있습니다.
XLA 플래그를 사용하여 덤프를 지정하고 가져올 수 있습니다. 대부분의 경우 환경 변수를 사용하여 설정할 수 있습니다. JAX는 HLO 덤프를 출력하는 프로그래매틱 방식도 제공합니다.
로컬 실행
환경 변수 사용
덤프를 가져오는 데 필요한 플래그를 사용하여 XLA_FLAGS 환경 변수를 설정할 수 있습니다. 이는 JAX, TensorFlow, PyTorch/XLA에 적용됩니다.
HLO 모듈과 기타 디버깅 정보를 특정 디렉터리에 덤프하려면 --xla_dump_to 플래그를 사용하여 프로그램을 실행하세요.
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"
예를 들어 /tmp 또는 /tmp/xladump을 경로로 사용할 수 있습니다.
기본적으로 이는 최적화 파이프라인의 시작과 끝에서 HLO 모듈을 텍스트로 덤프합니다.
형식을 명시적으로 지정할 수도 있습니다.
- 텍스트 덤프
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
- HLO 프로토콜
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
- HLO 스냅샷
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
- graphviz 서버를 사용한 그래프 렌더링 (작은 그래프에만 적합)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
- 그래프를 HTML 파일로 렌더링 (작은 그래프에만 적합)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"
더 큰 그래프의 경우 interactive_graphviz를 사용하여 그래프의 일부를 시각화할 수 있습니다.
특정 중간 패스 덤프
표준 사전 최적화 / 최종 최적화 HLO 외에도 특정 컴파일러 패스 후 HLO의 상태를 덤프할 수 있습니다.
XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"
이름이 정규 표현식 (regex)과 일치하는 패스의 HLO 모듈이 덤프됩니다. 예를 들어 SPMD 파티셔닝과 관련된 패스로 인해 발생하는 HLO는 다음을 사용하여 관찰할 수 있습니다.
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"
모든 XLA 패스 후에 결과를 덤프하려면 (파일이 많이 생성됨) 다음을 설정하면 됩니다.
XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"
JAX 관련 옵션
JAX에서 프로그래매틱 방식
플래그나 환경 변수를 전달하는 대신 JAX의 lower 및 compile API를 사용하여 HLO를 프로그래매틱 방식으로 덤프할 수도 있습니다.
다음 명령어를 사용하여 최적화되지 않은 원래 하위 HLO를 로컬로 가져옵니다.
jax.jit(f).lower(*args).as_text('hlo')
HLO 컴파일 전달 중에 파일로 덤프하려면 다음을 지정하세요.
compilation_args = {
'xla_dump_to': DIRECTORY_PATH,
'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
...
}
jax.jit(f).lower(*args).compile(compilation_args)
jaxpr 덤프
jaxprs는 프로그램 트레이스에 대한 JAX의 중간 표현입니다. 이를 덤프하려면 다음 환경 변수를 설정하세요.
JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr
스테이징된 계산 내보내기 및 직렬화: 디버깅에 관한 JAX 문서를 참고하세요.
Google Colab
환경 변수
노트북의 첫 번째 실행 셀에서 (환경 변수와 명령줄 플래그는 일반적으로 모듈 가져오기 시간이나 XLA 백엔드 초기화 시간 등 한 번만 처리되므로) 위에 설명된 XLA_FLAGS를 os.environ와 함께 추가합니다.예를 들면 다음과 같습니다.
import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"
이렇게 하면 계산이 DIRECTORY_PATH(예: /tmp)에 덤프됩니다. Colab에서는 왼쪽 사이드바의 '파일' 브라우저로 이동하여 이 디렉터리를 확인하고 액세스합니다.
로컬 실행 섹션에 언급된 모든 플래그를 사용할 수 있습니다.
JAX 관련 옵션
로컬 실행과 유사합니다. 실시간 대화형 인트로스펙션의 경우 최적화된 HLO를 직접 출력할 수 있습니다.
def f(x):
return jax.numpy.sin(jax.numpy.cos(x))
c = jax.jit(f).lower(3.).compiler_ir('hlo')
print(c.as_hlo_text())
계산의 최적화된 HLO를 직접 인쇄할 수도 있습니다.
def optimized_HLO(f, *args, platform=None):
print(jax.jit(f).lower(*args).compile().as_text())
def f(x):
return jax.numpy.sin(jax.numpy.cos(x))
optimized_HLO(f, 1.0)
모든/소규모 계산 덤프
모든 소규모 컴파일을 포함하여 덤프의 모든 항목을 보려면 JAX 환경 변수를 설정하세요.
JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0
모자이크
Mosaic은 Pallas TPU 백엔드와 실험용 Pallas GPU 백엔드용 컴파일러입니다. 모자이크 계산을 덤프하려면 다음 플래그를 설정하세요.
--xla_mosaic_dump_to=/tmp/mosaic_dumps
또는 TPU 초기화 인수를 환경 변수로 설정합니다.
export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"
자세한 내용은 Pallas 및 Mosaic에 관한 JAX 문서를 참고하세요.
HLO 덤프 더보기
올바른 계산 찾기
일반적으로 많은 계산이 덤프됩니다. 덤프된 파일은 로그에 표시된 JAX, TensorFlow 또는 PyTorch/XLA '계산 이름'으로 명시적으로 이름이 지정되므로 관련 HLO 파일을 쉽게 식별할 수 있습니다. 예를 들면 다음과 같습니다.
1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt
그렇지 않으면 ripgrep를 사용하여 특정 기호나 계산을 보유한 모듈을 빠르게 식별할 수 있습니다.
도움말: 버그 신고에 관심 있는 할당 전/후/버퍼의 덤프된 파일 3개를 포함하세요.
HLO 변환
HLOProto와 텍스트 형식 간에 변환할 수 있는 hlo-opt라는 도구
한 형식은 있지만 디버깅을 위해 다른 형식이 필요한 경우에 유용합니다.
사용 방법 알아보기: XLA 도구 문서: hlo-opt
다시 재생
덤프된 계산을 지정된 XLA 백엔드에서 가짜 데이터 또는 입력 스냅샷으로 실행 (재생)할 수 있습니다. 이는 XLA에서 문제를 재현하고, 반복하고, 디버그하는 편리한 방법입니다.
다음 명령어는 가짜 데이터를 사용합니다. HLO 스냅샷을 저장한 경우 이를 대신 전달할 수 있으며 스냅샷의 데이터가 사용됩니다. 스냅샷을 실행하는 동안에도 가짜 데이터를 사용하려면 --force_fake_data 플래그를 전달하세요.
CPU 백엔드:
bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
/tmp/xladump/module_4561.before_optimizations.txt
GPU 백엔드:
bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
/tmp/xladump/module_4561.before_optimizations.txt