XLA 开发工作流通常以 HLO IR 为中心,它表示提供给编译器的隔离功能计算。XLA 随附了多个命令行工具(如下所述),这些工具会使用 HLO 并运行它,或者提供中间编译阶段。使用此类工具对于实现快速 compile->modify->run 迭代周期非常宝贵,因为 HLO 既可直观呈现,也可进行破解,并且以迭代方式更改和运行 HLO 通常是了解和修复 XLA 性能或行为的最快方式。
获取使用 XLA 编译的程序的 HLO 的最简单方法通常是使用 XLA_FLAGS 环境变量:
$ XLA_FLAGS=--xla_dump_to=/tmp/myfolder ./myprogram-entry-point
该命令会将所有优化前的 HLO 文件存储在指定文件夹中,以及许多其他有用的制品。
[run_hlo_module] 运行 HLO 模块
bazel run //xla/tools:run_hlo_module -- [flags] <filename>
工具 run_hlo_module 可对优化前的 HLO 进行操作,默认情况下,它会将编译、运行和与参考解释器实现进行比较捆绑在一起。例如,在 NVIDIA GPU 上运行输入文件 computation.hlo 并检查其正确性的常规调用如下:
run_hlo_module --platform=CUDA --reference_platform=Interpreter computation.hlo
运行多个 HLO 模块
run_hlo_module 支持使用多个 HLO 模块进行调用。如需运行目录中的所有 HLO 模块,请执行以下操作:
bazel run //xla/tools:run_hlo_module -- [flags] /dump/*before_optimizations*
[multihost_hlo_runner] 运行支持 SPMD 的 HLO 模块
# Note: Binary name is `hlo_runner_main`.
bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- [flags] <filename>
多主机 HLO Runner 是一种非常类似的工具,但它支持 SPMD,包括跨主机通信。如需了解详情,请参阅多主机 HLO Runner。
运行多个支持 SPMD 的 HLO 模块
与 run_hlo_module 类似,multihost_hlo_runner 也支持使用多个模块进行调用。
bazel run //xla/tools/multihost_hlo_runner:hlo_runner_main -- [flags] /dump/*before_optimizations*
[hlo-opt] 编译 HLO 模块
bazel run //xla/tools:hlo-opt -- --platform=[gpu|cpu|...] [more flags] <filename>
在调试或了解编译器的工作原理时,通常需要获取流水线中特定时间点(无论是 HLO、优化后的 HLO、TritonIR 还是 LLVM)针对特定硬件的扩展,以用于给定的 HLO 或 StableHLO 输入。
hlo-opt 支持多个输出阶段:PTX、优化后的 HLO、优化前的 LLVM IR 或 TritonIR。支持的确切阶段集取决于平台(例如,PTX 是 NVIDIA 特有的),可以使用 --list-stages 命令查看:
hlo-opt --platform=CUDA --list-stages
buffer-assignment
hlo
hlo-backend
html
llvm
llvm-after-optimizations
llvm-before-optimizations
ptx
选择阶段后,用户可以将给定平台的转化结果写入给定数据流:
hlo-opt --platform=cpu --stage=hlo input.hlo
这会将 dump 输出到标准输出(如果指定了 -o,则输出到指定的文件)。
GPU 的无设备编译
无设备编译不需要访问 GPU。无设备编译提供了一种在命令行 (--xla_gpu_target_config_filename) 上指定 GPU 规范的方法,用于需要访问 GPU 的阶段,从而无需使用 GPU 设备。
示例:没有 GPU 设备访问权限的 PTX 输出:
hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=/xla/tools/hlo_opt/gpu_specs/a100_pcie_80.txtpb input.hlo
热门 GPU 的规范随编译器一起提供,所提供的文件是 device_description.proto 的字符串序列化:
gpu_device_info {
cuda_compute_capability {
major: 8
minor: 0
}
threads_per_block_limit: 1024
threads_per_warp: 32
shared_memory_per_block: 127152
shared_memory_per_core: 65536
threads_per_core_limit: 2048
core_count: 6192
fpus_per_core: 64
block_dim_limit_x: 2147483647
block_dim_limit_y: 65535
block_dim_limit_z: 65535
memory_bandwidth: 2039000000000
l2_cache_size: 4194304
clock_rate_ghz: 1.1105
device_memory_size: 79050250240
}
platform_name: "CUDA"
如需了解更多 GPU 规格,请访问 /xla/tools/hlo_opt/gpu_specs
自动调节
有时,合集可能涉及基于合集 --stage 的自动调音。
为了使无设备编译正常运行,用户需要使用 --xla_gpu_autotune_level=0
停用自动调优,或者使用 --xla_gpu_load_autotune_results_from=<filename>(通过 --xla_gpu_dump_autotune_results_to=<filename> 获得)
加载预先存在的自动调优结果。
hlo-opt --platform=CUDA --stage=llvm --xla_gpu_target_config_filename=gpu_specs/a100_pcie_80.txtpb --xla_gpu_load_autotune_results_from=results.textpb input.hlo
自动调谐文件是 autotune_results.proto 的文本序列化,示例如下所示:
version: 3
results {
device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB"
hlo: "{\n tmp_0 = f16[1,16,17,3]{3,2,1,0} parameter(0)\n tmp_1 = f16[16,51]{1,0} bitcast(f16[1,16,17,3]{3,2,1,0} tmp_0)\n tmp_2 = s8[16,17,3]{2,1,0} parameter(1)\n tmp_3 = s8[51,16]{0,1} bitcast(s8[16,17,3]{2,1,0} tmp_2)\n tmp_4 = f16[51,16]{0,1} convert(s8[51,16]{0,1} tmp_3)\n tmp_5 = f16[16,16]{1,0} dot(f16[16,51]{1,0} tmp_1, f16[51,16]{0,1} tmp_4), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n ROOT tmp_6 = f16[1,16,16]{2,1,0} bitcast(f16[16,16]{1,0} tmp_5)\n}"
result {
run_time {
nanos: 31744
}
triton {
block_m: 32
block_n: 32
block_k: 32
split_k: 1
num_stages: 1
num_warps: 4
}
}
}
可以使用 XLA_FLAGS=--xla_gpu_dump_autotune_results_to=<myfile.pbtxt> 对自动调优数据库进行序列化
[hlo-opt] HLO Pass 开发和调试
# If you are working with hardware independent passes from the
# `xla/hlo/transforms/` directory, prefer light-weight version
# of the `hlo-opt` tool with fewer dependencies:
bazel run //xla/hlo/tools:hlo-opt -- [flags] <filename>
# Otherwise, for hardware independent and CPU, GPU passes use
# the same binary from "Compile HLO Modules" section above:
bazel run //xla/tools:hlo-opt -<- [flags>] filename
借助 hlo-opt 工具,您可以独立于给定的平台编译阶段来执行各个 pass。这种隔离有助于快速对输入 HLO 模块运行 pass,并找出故障的根本原因。
hlo-opt --passes=schedule-aware-collective-cse input.hlo
hlo-opt 工具还支持 DebugOptions XLA_FLAGS。
hlo-opt --passes=schedule-aware-collective-cse
--xla_gpu_experimental_collective_cse_distance_threshold=20 input.hlo
使用 --list-passes 选项可获取通行证名称字符串。
hlo-opt --list-passes
用户可以通过为 --passes 选项指定多个 pass 来创建自己的自定义流水线。
hlo-opt --passes=pass1,pass2,pass3 input.hlo
协助开发新的 HLO 通行证
- 首先,撰写您的卡券。
将新通行证注册到
hlo-opt工具通行证注册表中。RegisterPass<FooPass>(FooPassInputOptions)根据通行证类型,选择以下任一注册位置:
opt_lib.cc与硬件无关的通行证。
cpu_opt.cc特定于 CPU 的传递。
gpu_opt.ccGPU 特定传递。
compiled_opt.ccCPU、GPU、XPU 通用的测试。
别忘了添加 build 依赖项。将卡券注册纳入您的 PR(示例)中,以便所有
hlo-opt用户都可以使用该卡券。重新构建
hlo-opt工具,使用--list-passes选项验证通行证注册是否成功,然后使用--passes选项运行通行证。$ hlo-opt --passes=foo-pass input.hlo为 pass 编写单元测试?如需了解详情,请参阅 https://openxla.org/xla/test_hlo_passes。
Pass 运行时测量
对于大型模型,完整编译运行可能需要几分钟时间,因此很难检测到细微的性能退化。相比之下,使用 hlo-opt 的单个 pass 运行可实现精确的性能衡量,并轻松检测到因新代码更改而导致的执行时间(即使是很小的增幅)。
time hlo-opt --passes=reduce-window-rewriter,scatter_simplifier
--xla_reduce_window_rewrite_base_length=128 input.hlo
[hlo-opt] 转换 HLO 模块格式
# Use the light weight version of the `hlo-opt` tool.
bazel run //xla/hlo/tools:hlo-opt -- [flags] <filename>
转换 HLO Text -> HLO Proto
hlo-opt --emit-proto input.hlo
将 HLO Proto 或 HLO Proto Binary 转换为 HLO Text
hlo-opt input.pbtxt or input.pb
[ptx-opt] 将编译器 LLVM 模块转换为 PTX
该工具将运行 LLVMIR 优化流水线,然后调用 CompileToPtx。
bazel run //xla/hlo/tools/ptx-opt -- --arch=9.0 <filename>
该工具还可以在每个路径之后转储 LLVMIR。
bazel run //xla/hlo/tools/ptx-opt -- --arch=9.0 --xla_dump_to=<path> --xla_gpu_dump_llvmir <filename>