XLA 開發工作流程通常以HLO IR 為中心,代表提供給編譯器的獨立功能運算。XLA 隨附多種指令列工具 (如下所述),這些工具會使用 HLO,並執行 HLO 或提供中繼編譯階段。這類工具對於快速疊代週期來說非常寶貴,因為 HLO 可視化且可駭入,反覆變更及執行 HLO 通常是瞭解及修正 XLA 效能或行為的最快方式。compile->modify->run
如要取得透過 XLA 編譯的程式 HLO,最簡單的方法通常是使用 XLA_FLAGS 環境變數:
$ XLA_FLAGS=--xla_dump_to=/tmp/myfolder ./myprogram-entry-point
這個資料夾會儲存所有最佳化前的 HLO 檔案,以及許多其他實用構件。
[run_hlo_module] 執行 HLO 模組
bazel run //xla/tools:run_hlo_module -- [flags] <filename>
這項工具 run_hlo_module 會在最佳化前 HLO 上運作,並預設會將編譯、執行和比較與參考解譯器實作項目捆綁在一起。舉例來說,在 NVIDIA GPU 上執行輸入檔案 computation.hlo 並檢查正確性的通常叫用方式如下:
run_hlo_module --platform=CUDA --reference_platform=Interpreter computation.hlo
執行多個 HLO 模組
run_hlo_module 支援使用多個 HLO 模組進行呼叫。如要從目錄執行所有 HLO 模組,請使用下列指令:
bazel run //xla/tools:run_hlo_module -- [flags] /dump/*before_optimizations*
[multihost_hlo_runner] Run HLO Modules With SPMD Support
# Note: Binary name is `hlo_runner_main`.
bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- [flags] <filename>
多主機 HLO 執行器是類似的工具,但支援 SPMD,包括跨主機通訊。詳情請參閱「多主機 HLO 執行器」。
執行多個 HLO 模組 (支援 SPMD)
與 run_hlo_module 類似,multihost_hlo_runner 也支援使用多個模組進行呼叫。
bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- [flags] /dump/*before_optimizations*
[hlo-opt] 編譯 HLO 模組
bazel run //xla/tools:hlo-opt -- --platform=[gpu|cpu|...] [more flags] <filename>
在偵錯或瞭解編譯器運作方式時,通常會取得特定硬體在管線特定時間點的擴充功能 (無論是 HLO、最佳化 HLO、TritonIR 或 LLVM),適用於特定 HLO 或 StableHLO 輸入。
hlo-opt 支援多個輸出階段:PTX、最佳化後的 HLO、最佳化前的 LLVM IR 或 TritonIR。支援的確切階段組合取決於平台 (例如 PTX 是 NVIDIA 專屬),可使用 --list-stages 指令查看:
hlo-opt --platform=CUDA --list-stages
buffer-assignment
hlo
hlo-backend
html
llvm
llvm-after-optimizations
llvm-before-optimizations
ptx
選取階段後,使用者可以將特定平台的轉換結果寫入特定串流:
hlo-opt --platform=cpu --stage=hlo input.hlo
這會將傾印內容列印至標準輸出 (或指定檔案,如果已指定 -o)。
GPU 無裝置編譯
無裝置編譯不需要存取 GPU。無裝置編譯功能可讓您在命令列 (--xla_gpu_target_config_filename) 上指定 GPU 規格,適用於需要存取 GPU 的階段,因此不需要 GPU 裝置。
範例:無法存取 GPU 裝置的 PTX 輸出內容:
hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb input.hlo
熱門 GPU 的規格會隨編譯器一併出貨,而提供的檔案是 device_description.proto 的字串序列化:
gpu_device_info {
cuda_compute_capability {
major: 8
minor: 0
}
threads_per_block_limit: 1024
threads_per_warp: 32
shared_memory_per_block: 127152
shared_memory_per_core: 65536
threads_per_core_limit: 2048
core_count: 6192
fpus_per_core: 64
block_dim_limit_x: 2147483647
block_dim_limit_y: 65535
block_dim_limit_z: 65535
memory_bandwidth: 2039000000000
l2_cache_size: 4194304
clock_rate_ghz: 1.1105
device_memory_size: 79050250240
}
platform_name: "CUDA"
如需更多 GPU 規格,請前往 /xla/tools/hlo_opt/gpu_specs
自動調整
有時,彙整作業可能涉及根據彙整內容 --stage 自動調音。
如要讓無裝置編譯作業正常運作,使用者必須
停用自動調整功能 (使用 --xla_gpu_autotune_level=0
),或
載入預先存在的自動調整結果 (使用 --xla_gpu_load_autotune_results_from=<filename>,透過 --xla_gpu_dump_autotune_results_to=<filename> 取得)。
hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=gpu_specs/a100_pcie_80.txtpb --xla_gpu_load_autotune_results_from=results.textpb input.hlo
自動調整檔案是 autotune_results.proto 的文字序列化,範例如下:
version: 3
results {
device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}"
result {
run_time {
nanos: 31744
}
triton {
block_m: 32
block_n: 32
block_k: 32
split_k: 1
num_stages: 1
num_warps: 4
}
}
}
自動調整資料庫可使用 XLA_FLAGS=--xla_gpu_dump_autotune_results_to=<myfile.pbtxt> 序列化
[hlo-opt] HLO Pass 開發與偵錯
# If you are working with hardware independent passes from the
# `xla/hlo/transforms/` directory, prefer light-weight version
# of the `hlo-opt` tool with fewer dependencies:
bazel run //xla/hlo/tools:hlo-opt -- [flags] <filename>
# Otherwise, for hardware independent and CPU, GPU passes use
# the same binary from "Compile HLO Modules" section above:
bazel run //xla/tools:hlo-opt -<- [flags>] filename
hlo-opt 工具可獨立於指定平台編譯階段執行個別傳遞作業。這種隔離方式有助於快速對輸入 HLO 模組執行傳遞作業,並找出失敗的根本原因。
hlo-opt --passes=schedule-aware-collective-cse input.hlo
hlo-opt 工具也支援 DebugOptions XLA_FLAGS。
hlo-opt --passes=schedule-aware-collective-cse
--xla_gpu_experimental_collective_cse_distance_threshold=20 input.hlo
使用 --list-passes 選項取得憑證名稱字串。
hlo-opt --list-passes
使用者可以指定多個傳遞至 --passes 選項,建立自己的自訂管道。
hlo-opt --passes=pass1,pass2,pass3 input.hlo
協助開發新的 HLO Pass
- 首先,請撰寫通行證。
將新憑證註冊至
hlo-opt工具憑證登錄檔。RegisterPass<FooPass>(FooPassInputOptions)請根據憑證類型,選擇下列其中一個註冊地點:
opt_lib.cc與硬體無關的憑證。
cpu_opt.ccCPU 專屬傳遞。
gpu_opt.ccGPU 專用通道。
compiled_opt.ccCPU、GPU、XPU 通用的傳遞。
別忘了新增建構依附元件。在 PR 中加入憑證註冊程序(範例),讓所有
hlo-opt使用者都能使用憑證。重建
hlo-opt工具,使用--list-passes選項驗證是否已成功註冊票證,然後使用--passes選項執行票證。$ hlo-opt --passes=foo-pass input.hlo如要編寫傳遞的單元測試,請參閱 https://openxla.org/xla/test_hlo_passes,瞭解詳情。
通過執行階段評估
如果是大型模型,完整編譯作業可能需要幾分鐘,因此很難偵測到細微的效能回歸。相較之下,使用 hlo-opt 執行的個別傳遞作業可精確評估效能,並輕鬆偵測到新程式碼變更導致的執行時間增加 (即使增加幅度很小)。
time hlo-opt --passes=reduce-window-rewriter,scatter_simplifier
--xla_reduce_window_rewrite_base_length=128 input.hlo
[hlo-opt] Convert HLO Module Formats
# Use the light weight version of the `hlo-opt` tool.
bazel run //xla/hlo/tools:hlo-opt -- [flags] <filename>
轉換 HLO Text -> HLO Proto
hlo-opt --emit-proto input.hlo
轉換 HLO Proto 或 HLO Proto Binary -> HLO Text
hlo-opt input.pbtxt or input.pb
[ptx-opt] 將編譯器 LLVM 模組向下編譯至 PTX
這項工具會執行 LLVMIR 最佳化管道,然後呼叫 CompileToPtx。
bazel run //xla/hlo/tools/ptx-opt -- --arch=9.0 <filename>
這個工具也可以在每個路徑後傾印 LLVMIR。
bazel run //xla/hlo/tools/ptx-opt -- --arch=9.0 --xla_dump_to=<path> --xla_gpu_dump_llvmir <filename>