本指南精选了一些关键的 XLA 标志,可帮助用户有效地了解和利用 XLA 的功能。以下部分详细介绍了可能会对运行时性能和内存利用率产生重大影响的标志。如果在启用某个标志后出现任何问题(例如崩溃),建议恢复为默认设置并创建 GitHub 问题。
性能标志
以下标志有助于提升运行时性能。 尝试调整这些设置可能会带来显著的效果提升。
标志 | 说明 | 默认值 | 建议值 | 候选值 |
---|---|---|---|---|
流水线 1. xla_should_allow_loop_variant_parameter_in_chain 2. xla_should_add_loop_invariant_op_in_chain 3. xla_tpu_enable_ici_ag_pipelining |
这 3 个标志应结合使用,以实现 ICI(芯片间互连)全收集操作的集体流水线处理,从而为重叠执行创造更多机会。 | 1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled 2. xla_should_add_loop_invariant_op_in_chain=kDisabled 3. xla_tpu_enable_ici_ag_pipelining=false |
1. xla_should_allow_loop_variant_parameter_in_chain=kEnabled 2. xla_should_add_loop_invariant_op_in_chain=kEnabled 3. xla_tpu_enable_ici_ag_pipelining=true |
1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled/kEnabled/kAuto 2. xla_should_add_loop_invariant_op_in_chain=kDisabled/kEnabled/kAuto 3. xla_tpu_enable_ici_ag_pipelining=true/false |
v5e/Async xla_enable_async_all_gather xla_tpu_enable_async_collective_fusion xla_tpu_enable_async_collective_fusion_fuse_all_gather |
这 3 个标志应结合使用,以在 v5e 上激活异步 all-gather 操作。 | xla_enable_async_all_gather=kAuto xla_tpu_enable_async_collective_fusion=true xla_tpu_enable_async_collective_fusion_fuse_all_gather=true |
xla_enable_async_all_gather=kAuto xla_tpu_enable_async_collective_fusion=true xla_tpu_enable_async_collective_fusion_fuse_all_gather=true |
xla_enable_async_all_gather=kDisabled/kEnabled/kAuto xla_tpu_enable_async_collective_fusion=true/false xla_tpu_enable_async_collective_fusion_fuse_all_gather=true/false |
v5e/Async xla_tpu_enable_async_collective_fusion xla_tpu_enable_async_collective_fusion_fuse_all_reduce |
这两个标志应结合使用,以在 v5e 上激活异步 all-reduce 操作。 | xla_tpu_enable_async_collective_fusion=true xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false |
xla_tpu_enable_async_collective_fusion=true xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true |
xla_tpu_enable_async_collective_fusion=true/false xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true/false |
Async xla_tpu_enable_async_all_to_all |
此标志用于启用异步全到全通信。 | xla_tpu_enable_async_all_to_all=false |
xla_tpu_enable_async_all_to_all=true |
xla_tpu_enable_async_all_to_all=true/false |
受延迟时间限制 xla_all_gather_latency_bound_threshold_in_bytes |
此标志适用于受延迟限制(即规模较小)的 all-gather 操作。启用此设置会触发特定优化,从而缩短延迟时间受限的全收集的执行时间。通常用于推理工作负载。 | xla_all_gather_latency_bound_threshold_in_bytes=-1 (未启用) |
4~16Mb(i.e. 4~16 * 1024 * 1024) |
[0, 9223372036854775807] |
受延迟时间限制 xla_all_reduce_latency_bound_threshold_in_bytes |
此标志适用于受延迟限制(即规模较小)的 all-gather 操作。启用此功能会触发特定的优化,从而缩短延迟受限的 all-reduce 的执行时间。通常用于推理工作负载。 | xla_all_reduce_latency_bound_threshold_in_bytes=-1 (未启用) |
4~16Mb(i.e. 4~16 * 1024 * 1024) |
[0, 9223372036854775807] |
受延迟时间限制 xla_collective_permute_latency_bound_threshold_in_bytes |
此标志适用于受延迟限制(即规模较小)的 all-gather 操作。启用此功能会触发特定的优化,从而缩短延迟受限的集合排列的执行时间。通常用于推理工作负载。 | xla_collective_permute_latency_bound_threshold_in_bytes=-1 (未启用) |
4~16Mb(i.e. 4~16 * 1024 * 1024) |
[0, 9223372036854775807] |
受延迟时间限制 xla_all_to_all_latency_bound_threshold_in_bytes |
此标志适用于受延迟限制(即规模较小)的 all-gather 操作。启用此功能会触发特定优化,从而缩短延迟受限的全到全执行时间。通常用于推理工作负载。 | xla_all_to_all_latency_bound_threshold_in_bytes=-1 (未启用) |
4~16Mb(i.e. 4~16 * 1024 * 1024) |
[0, 9223372036854775807] |
xla_enable_async_collective_permute |
将所有集体排列操作重写为其异步变体。如果设置为 auto ,XLA 可以根据其他配置或条件自动开启异步集合。 |
xla_enable_async_collective_permute=kAuto |
xla_enable_async_collective_permute=kAuto |
xla_enable_async_collective_permute=kAuto/kEnabled/kDisabled |
内存标志
以下标志用于解决与 HBM 相关的问题。只有在模型编译期间遇到 HBM“内存不足”错误时,才应调整这些参数。在所有其他情况下,建议使用默认值,因为更改这些值可能会对效果产生不利影响。
标志 | 说明 | 默认值 | 建议值 | 候选值 |
---|---|---|---|---|
调度器 xla_latency_hiding_scheduler_rerun |
此设置用于调整延迟隐藏调度器的行为。它通过在每次进程“重新运行”时逐步减少为调度分配的内存限制来发挥作用。 | xla_latency_hiding_scheduler_rerun=1 |
xla_latency_hiding_scheduler_rerun=5 |
0~10(it doesn’t make much sense beyond 10 reruns) |
Fusion xla_tpu_rwb_fusion |
此标志可启用 reduce+broadcast 类型的融合,并可能会减少内存用量。 | xla_tpu_rwb_fusion=true |
xla_tpu_rwb_fusion=false |
xla_tpu_rwb_fusion=true/false |
调度器 xla_memory_scheduler |
此标志指定内存调度程序将用于最大限度减少内存消耗的算法。使用更高级的算法可能会获得内存消耗更少的调度,但代价是编译时间更长。 | xla_memory_scheduler=kDefault |
xla_memory_scheduler=kBrkga |
xla_memory_scheduler=kDefault/kList/kDfs/kPostOrder/kBrkga |
调度器 xla_tpu_enable_latency_hiding_scheduler |
此标志用于启用延迟隐藏调度程序,从而允许我们执行异步集合,而不是同步集合。停用此功能可减少内存使用量,但会失去这些异步操作带来的性能提升。 | xla_tpu_enable_latency_hiding_scheduler=true |
xla_tpu_enable_latency_hiding_scheduler=false |
xla_tpu_enable_latency_hiding_scheduler=true/false |
SPMD xla_jf_spmd_threshold_for_windowed_einsum_mib |
此标志用于设置触发集体 matmul 的点积大小下限。将其设置为较高的值可节省内存,但会失去执行集体 matmul 的机会。 | xla_jf_spmd_threshold_for_windowed_einsum_mib=-1 |
10Mb~1Gb (i.e. 10*1024*1024 ~ 1024*1024*1024) |
[0, 9223372036854775807] |
其他常用标志
标志 | 类型 | 备注 |
---|---|---|
xla_dump_to |
字符串(文件路径) | 放置预优化 HLO 文件和其他制品(请参阅 XLA 工具)的文件夹。 |
TPU XLA 标志
标志 | 类型 | 备注 |
---|---|---|
xla_tpu_enable_data_parallel_all_reduce_opt |
布尔值 (true/false) | 优化了 DCN(数据中心网络)用于数据并行分片的所有归约,以增加重叠机会。 |
xla_tpu_data_parallel_opt_different_sized_ops |
布尔值 (true/false) | 即使数据并行操作的输出大小与堆叠变量中可就地保存的大小不匹配,也能在多次迭代中实现数据并行操作的流水线处理。可能会增加内存压力。 |
xla_tpu_spmd_rng_bit_generator_unsafe |
布尔值 (true/false) | 是否以分区方式运行 RngBitGenerator HLO,如果希望在计算的不同部分使用不同的分片获得确定性结果,则此方式不安全。 |
xla_tpu_megacore_fusion_allow_ags |
布尔值 (true/false) | 允许将所有收集与卷积/所有归约融合。 |
xla_tpu_enable_ag_backward_pipelining |
布尔值 (true/false) | 流水线通过扫描循环向后进行全收集(目前为巨型全收集)。 |
GPU XLA flag
标志 | 类型 | 备注 |
---|---|---|
xla_gpu_enable_latency_hiding_scheduler |
布尔值 (true/false) | 此标志可让延迟隐藏调度程序高效地将异步通信与计算重叠。默认值为 False。 |
xla_gpu_enable_triton_gemm |
布尔值 (true/false) | 使用基于 Triton 的矩阵乘法。 |
xla_gpu_graph_level |
标志 (0-3) | 用于设置 GPU 图形级别的旧版标志。在新使用情形中,使用 xla_gpu_enable_command_buffer。0 = 关闭;1 = 捕获融合和内存复制;2 = 捕获 GEMM;3 = 捕获卷积。 |
xla_gpu_all_reduce_combine_threshold_bytes |
整数(字节) | 这些标志用于调整何时将多个小的 AllGather / ReduceScatter / AllReduce 合并为一个大的 AllGather / ReduceScatter / AllReduce,以减少在跨设备通信上花费的时间。例如,对于基于 Transformer 的工作负载上的 AllGather / ReduceScatter 阈值,请考虑将它们调整得足够高,以便至少组合一个 Transformer 层的权重 AllGather / ReduceScatter。默认情况下,combine_threshold_bytes 设置为 256。 |
xla_gpu_all_gather_combine_threshold_bytes |
整数(字节) | 请参阅上文中的 xla_gpu_all_reduce_combine_threshold_bytes。 |
xla_gpu_reduce_scatter_combine_threshold_bytes |
整数(字节) | 请参阅上文中的 xla_gpu_all_reduce_combine_threshold_bytes。 |
xla_gpu_enable_pipelined_all_gather |
布尔值 (true/false) | 启用 all-gather 指令的流水线处理。 |
xla_gpu_enable_pipelined_reduce_scatter |
布尔值 (true/false) | 启用 reduce-scatter 指令流水线。 |
xla_gpu_enable_pipelined_all_reduce |
布尔值 (true/false) | 启用 all-reduce 指令流水线。 |
xla_gpu_enable_while_loop_double_buffering |
布尔值 (true/false) | 为 while 循环启用双缓冲。 |
xla_gpu_enable_all_gather_combine_by_dim |
布尔值 (true/false) | 将具有相同收集维度或不考虑维度的 all-gather 操作相结合。 |
xla_gpu_enable_reduce_scatter_combine_by_dim |
布尔值 (true/false) | 合并具有相同维度或不考虑维度的 reduce-scatter 操作。 |