XLA 개발 워크플로는 일반적으로 컴파일러에 제공되는 격리된 함수형
계산을 나타내는
HLO IR을 중심으로 진행됩니다. XLA에는 HLO를 사용하고 이를 실행하거나 중간 컴파일 단계를 제공하는 여러 명령줄 도구(아래 설명 참고)가 함께 제공됩니다. 이러한 도구를 사용하면 빠른
compile->modify->run 반복 주기에 매우 유용합니다. HLO는 시각화 및
해킹이 가능하며, 이를 반복적으로 변경하고 실행하는 것이 XLA 성능 또는 동작을 이해하고 수정하는 가장 빠른 방법인 경우가 많기 때문입니다.
XLA로 컴파일되는 프로그램의 HLO를 얻는 가장 쉬운 방법은 일반적으로 XLA_FLAGS 환경 변수를 사용하는 것입니다.
$ XLA_FLAGS=--xla_dump_to=/tmp/myfolder ./myprogram-entry-point
이 변수는 다른 유용한 아티팩트와 함께 지정된 폴더에 최적화 전의 모든 HLO 파일을 저장합니다.
[run_hlo_module] HLO 모듈 실행
bazel run //xla/tools:run_hlo_module -- [flags] <filename>
run_hlo_module 도구는 최적화 전 HLO에서 작동하며 기본적으로 컴파일, 실행, 참조 인터프리터 구현과의 비교를 번들로 제공합니다. 예를 들어 NVIDIA GPU에서 입력 파일 computation.hlo를 실행하고 올바른지 확인하는 일반적인 호출은 다음과 같습니다.
run_hlo_module --platform=CUDA --reference_platform=Interpreter computation.hlo
여러 HLO 모듈 실행
run_hlo_module에서는 여러 HLO 모듈을 사용한 호출이 지원됩니다. 디렉터리의 모든 hlo 모듈을 실행하려면 다음을 실행합니다.
bazel run //xla/tools:run_hlo_module -- [flags] /dump/*before_optimizations*
[multihost_hlo_runner] SPMD 지원으로 HLO 모듈 실행
# Note: Binary name is `hlo_runner_main`.
bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- [flags] <filename>
멀티호스트 HLO 실행기는 호스트 간 통신을 포함하여 SPMD를 지원한다는 점을 제외하면 매우 유사한 도구입니다. 자세한 내용은 멀티호스트 HLO 실행기를 참고하세요.
SPMD 지원으로 여러 HLO 모듈 실행
run_hlo_module과 마찬가지로 multihost_hlo_runner도 여러 모듈을 사용한 호출을 지원합니다.
bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- [flags] /dump/*before_optimizations*
[hlo-opt] HLO 모듈 컴파일
bazel run //xla/tools:hlo-opt -- --platform=[gpu|cpu|...] [more flags] <filename>
컴파일러의 작동 방식을 디버깅하거나 이해할 때는 특정 HLO 또는 StableHLO 입력에 대해 파이프라인의 특정 시점 (HLO, 최적화된 HLO, TritonIR 또는 LLVM)에서 특정 하드웨어의 확장을 가져오는 것이 유용한 경우가 많습니다.
hlo-opt 는 최적화 후의 PTX, HLO, 최적화 전의 LLVM IR 또는 TritonIR 등 여러 출력 단계를 지원합니다. 지원되는 정확한 단계 집합은 플랫폼에 따라 다르며 (예: PTX는 NVIDIA 전용) --list-stages 명령어를 사용하여 확인할 수 있습니다.
hlo-opt --platform=CUDA --list-stages
buffer-assignment
hlo
hlo-backend
html
llvm
llvm-after-optimizations
llvm-before-optimizations
ptx
단계를 선택한 후 사용자는 특정 플랫폼의 변환 결과를 특정 스트림에 쓸 수 있습니다.
hlo-opt --platform=cpu --stage=hlo input.hlo
그러면 덤프가 stdout에 출력되거나 -o가 지정된 경우 지정된 파일에 출력됩니다.
GPU용 기기 없는 컴파일
기기 없는 컴파일에는 GPU에 대한 액세스 권한이 필요하지 않습니다. 기기 없는 컴파일은 GPU에 대한 액세스 권한이 필요한 단계에서 명령줄(--xla_gpu_target_config_filename)에 GPU 사양을 지정하는 방법을 제공하여 GPU 기기의 필요성을 없애줍니다.
예: GPU 기기에 액세스하지 않고 PTX 출력:
hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb input.hlo
일반적인 GPU 사양은 컴파일러와 함께 제공되며 제공된 파일은 device_description.proto의 문자열 직렬화입니다.
gpu_device_info {
cuda_compute_capability {
major: 8
minor: 0
}
threads_per_block_limit: 1024
threads_per_warp: 32
shared_memory_per_block: 127152
shared_memory_per_core: 65536
threads_per_core_limit: 2048
core_count: 6192
fpus_per_core: 64
block_dim_limit_x: 2147483647
block_dim_limit_y: 65535
block_dim_limit_z: 65535
memory_bandwidth: 2039000000000
l2_cache_size: 4194304
clock_rate_ghz: 1.1105
device_memory_size: 79050250240
}
platform_name: "CUDA"
더 많은 GPU 사양은 /xla/tools/hlo_opt/gpu_specs에 있습니다.
자동 조정
컴파일에는 컴파일 --stage를 기반으로 자동 조정이 포함될 수 있습니다.
기기 없는 컴파일이 작동하려면 사용자가
자동 조정을 사용 중지하거나 --xla_gpu_autotune_level=0
기존 자동 조정 결과를 로드해야 합니다.
--xla_gpu_load_autotune_results_from=<filename> (--xla_gpu_dump_autotune_results_to=<filename>으로 가져옴)
hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=gpu_specs/a100_pcie_80.txtpb --xla_gpu_load_autotune_results_from=results.textpb input.hlo
자동 조정 파일은 autotune_results.proto의 텍스트 직렬화이며 예는 다음과 같습니다.
version: 3
results {
device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}"
result {
run_time {
nanos: 31744
}
triton {
block_m: 32
block_n: 32
block_k: 32
split_k: 1
num_stages: 1
num_warps: 4
}
}
}
자동 조정 데이터베이스는
XLA_FLAGS=--xla_gpu_dump_autotune_results_to=<myfile.pbtxt>를 사용하여 직렬화할 수 있습니다.
[hlo-opt] HLO 패스 개발 및 디버깅
# If you are working with hardware independent passes from the
# `xla/hlo/transforms/` directory, prefer light-weight version
# of the `hlo-opt` tool with fewer dependencies:
bazel run //xla/hlo/tools:hlo-opt -- [flags] <filename>
# Otherwise, for hardware independent and CPU, GPU passes use
# the same binary from "Compile HLO Modules" section above:
bazel run //xla/tools:hlo-opt -- [flags] <filename>
hlo-opt 도구를 사용하면 지정된 플랫폼 컴파일 단계와 독립적으로 개별 패스를 실행할 수 있습니다. 이 격리를 통해 입력 hlo 모듈에서 패스를 빠르게 실행하고 실패의 근본 원인을 파악할 수 있습니다.
hlo-opt --passes=schedule-aware-collective-cse input.hlo
hlo-opt 도구는 DebugOptions XLA_FLAGS도 지원합니다.
hlo-opt --passes=schedule-aware-collective-cse
--xla_gpu_experimental_collective_cse_distance_threshold=20 input.hlo
--list-passes 옵션을 사용하여 패스 이름 문자열을 가져옵니다.
hlo-opt --list-passes
사용자는 --passes 옵션에 두 개 이상의 패스를 지정하여 자체 커스텀 파이프라인을 만들 수 있습니다.
hlo-opt --passes=pass1,pass2,pass3 input.hlo
새 HLO 패스 개발 지원
- 먼저 패스를 작성합니다.
hlo-opt도구 패스 레지스트리에 새 패스를 등록합니다.RegisterPass<FooPass>(FooPassInputOptions)패스 유형에 따라 등록을 위해 다음 위치 중 하나를 선택합니다.
opt_lib.cc하드웨어 독립 패스
cpu_opt.ccCPU 특정 패스
gpu_opt.ccGPU 특정 패스
compiled_opt.ccCPU, GPU, XPU에 공통적인 패스
빌드 종속 항목을 추가하는 것을 잊지 마세요.패스를 모든
hlo-opt사용자가 사용할 수 있도록 패스 등록을 PR(예)의 일부로 포함합니다.hlo-opt도구를 다시 빌드하고--list-passes옵션을 사용하여 패스 등록이 성공했는지 확인한 다음--passes옵션을 사용하여 패스를 실행합니다.$ hlo-opt --passes=foo-pass input.hlo패스에 대한 단위 테스트를 작성하시나요? 자세한 내용은 https://openxla.org/xla/test_hlo_passes를 참고하세요.
패스 런타임 측정
대규모 모델의 경우 전체 컴파일 실행에 몇 분이 걸릴 수 있으므로 미묘한 성능 회귀를 감지하기가 어렵습니다. 반면 hlo-opt를 사용하는 개별 패스 실행을 사용하면 정확한 성능 측정이 가능하며 새로운 코드 변경으로 인해 실행 시간이 약간 증가하는 경우에도 쉽게 감지할 수 있습니다.
time hlo-opt --passes=reduce-window-rewriter,scatter_simplifier
--xla_reduce_window_rewrite_base_length=128 input.hlo
[hlo-opt] HLO 모듈 형식 변환
# Use the light weight version of the `hlo-opt` tool.
bazel run //xla/hlo/tools:hlo-opt -- [flags] <filename>
HLO Text -> HLO Proto 변환
hlo-opt --emit-proto input.hlo
HLO Proto 또는 HLO Proto Binary -> HLO Text 변환
hlo-opt input.pbtxt or input.pb
[ptx-opt] 컴파일러 LLVM 모듈을 PTX로 다운컴파일
이 도구는 LLVMIR 최적화 파이프라인을 실행한 다음 CompileToPtx를 호출합니다.
bazel run //xla/hlo/tools/ptx-opt -- --arch=9.0 <filename>
이 도구는 모든 경로 후에 LLVMIR을 덤프할 수도 있습니다.
bazel run //xla/hlo/tools/ptx-opt -- --arch=9.0 --xla_dump_to=<path> --xla_gpu_dump_llvmir <filename>
[isolate_hlo] 문제 있는 HLO 명령어 격리
대규모 HLO 덤프가 있고 HLO 모듈 내의 특정 명령어 또는 섹션이 비정상 종료를 일으키는 것으로 의심되는 경우 isolate_hlo 도구를 사용할 수 있습니다.
이 도구는 단일 HLO 명령어 (및 필요한 컨텍스트)를 더 작은 새 HLO 모듈로 추출합니다. 이는 최소한의 컴파일러 수준 재현기를 만드는 데 매우 유용합니다.
- 문서 및 소스:
isolate_hlo도구는 OpenXLA 저장소에서 사용할 수 있습니다. XLA 소스 코드의xla/tools디렉터리를 참고하세요. 사용: XLA 소스 트리에서 도구를 빌드합니다. 일반적으로 입력 HLO 모듈 파일 (텍스트 또는 프로토), 추출할 명령어 이름, 출력 파일 경로를 사용합니다.
# Example usage after building XLA: # ./build/tools/isolate_hlo --input=module.hlo --instruction_name=fusion.123 \ # --output=isolated_fusion.123.hlo --input_format=txt --output_format=long_txt특정 플래그 및 형식 옵션은 도구의 도움말 메시지 (
--help)를 참고하세요.