XLA Flags Guidance

This guide offers a curated selection of key XLA flags to assist users in effectively navigating and utilizing XLA's capabilities. The following sections detail flags that can significantly impact runtime performance and memory utilization. Should any issues, such as crashes, arise after enabling a flag, it is recommended to revert to the default setting and create a GitHub issue.
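
As a point of reference, the sketch below shows one common way to pass these flags: through the XLA_FLAGS environment variable, set before the XLA backend is initialized. The use of JAX here is only an illustration; the same mechanism is commonly used with other XLA frontends.

```python
import os

# Sketch: XLA flags are commonly passed via the XLA_FLAGS environment
# variable. It must be set before the XLA backend is initialized, i.e.
# before importing jax (or running the first compiled computation).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax  # imported only after XLA_FLAGS is set

# If a flag causes a crash or a regression, remove it from XLA_FLAGS to
# fall back to the default and report the problem on GitHub.
```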

Performance Flags

The following flags are instrumental in enhancing runtime performance. Experimenting with these settings may lead to considerable performance gains.

| Flag | Description | Default Values | Suggested Values | Candidate Values |
| --- | --- | --- | --- | --- |
| Pipelining<br>1. xla_should_allow_loop_variant_parameter_in_chain<br>2. xla_should_add_loop_invariant_op_in_chain<br>3. xla_tpu_enable_ici_ag_pipelining | These 3 flags should be used in conjunction to enable collective pipelining of ICI (Inter-Chip Interconnect) all-gather operations, which creates more opportunities for overlapping execution. | 1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled<br>2. xla_should_add_loop_invariant_op_in_chain=kDisabled<br>3. xla_tpu_enable_ici_ag_pipelining=false | 1. xla_should_allow_loop_variant_parameter_in_chain=kEnabled<br>2. xla_should_add_loop_invariant_op_in_chain=kEnabled<br>3. xla_tpu_enable_ici_ag_pipelining=true | 1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled/kEnabled/kAuto<br>2. xla_should_add_loop_invariant_op_in_chain=kDisabled/kEnabled/kAuto<br>3. xla_tpu_enable_ici_ag_pipelining=true/false |
| v5e / Async<br>xla_enable_async_all_gather<br>xla_tpu_enable_async_collective_fusion<br>xla_tpu_enable_async_collective_fusion_fuse_all_gather | These 3 flags should be used in conjunction to activate asynchronous all-gather operations on v5e. | xla_enable_async_all_gather=kAuto<br>xla_tpu_enable_async_collective_fusion=true<br>xla_tpu_enable_async_collective_fusion_fuse_all_gather=true | xla_enable_async_all_gather=kAuto<br>xla_tpu_enable_async_collective_fusion=true<br>xla_tpu_enable_async_collective_fusion_fuse_all_gather=true | xla_enable_async_all_gather=kDisabled/kEnabled/kAuto<br>xla_tpu_enable_async_collective_fusion=true/false<br>xla_tpu_enable_async_collective_fusion_fuse_all_gather=true/false |
| v5e / Async<br>xla_tpu_enable_async_collective_fusion<br>xla_tpu_enable_async_collective_fusion_fuse_all_reduce | These 2 flags should be used in conjunction to activate asynchronous all-reduce operations on v5e. | xla_tpu_enable_async_collective_fusion=true<br>xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false | xla_tpu_enable_async_collective_fusion=true<br>xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true | xla_tpu_enable_async_collective_fusion=true/false<br>xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true/false |
| Async<br>xla_tpu_enable_async_all_to_all | This flag enables asynchronous all-to-all communication. | xla_tpu_enable_async_all_to_all=false | xla_tpu_enable_async_all_to_all=true | xla_tpu_enable_async_all_to_all=true/false |
| Latency-bound<br>xla_all_gather_latency_bound_threshold_in_bytes | This flag is intended for latency-bound (i.e., small-sized) all-gather operations. Enabling it triggers optimizations that can reduce execution time for latency-bound all-gathers. It is typically used in inference workloads. | xla_all_gather_latency_bound_threshold_in_bytes=-1 (i.e., not enabled) | 4-16 MB (i.e., 4-16 * 1024 * 1024 bytes) | [0, 9223372036854775807] |
| Latency-bound<br>xla_all_reduce_latency_bound_threshold_in_bytes | This flag is intended for latency-bound (i.e., small-sized) all-reduce operations. Enabling it triggers optimizations that can reduce execution time for latency-bound all-reduces. It is typically used in inference workloads. | xla_all_reduce_latency_bound_threshold_in_bytes=-1 (i.e., not enabled) | 4-16 MB (i.e., 4-16 * 1024 * 1024 bytes) | [0, 9223372036854775807] |
| Latency-bound<br>xla_collective_permute_latency_bound_threshold_in_bytes | This flag is intended for latency-bound (i.e., small-sized) collective-permute operations. Enabling it triggers optimizations that can reduce execution time for latency-bound collective-permutes. It is typically used in inference workloads. | xla_collective_permute_latency_bound_threshold_in_bytes=-1 (i.e., not enabled) | 4-16 MB (i.e., 4-16 * 1024 * 1024 bytes) | [0, 9223372036854775807] |
| Latency-bound<br>xla_all_to_all_latency_bound_threshold_in_bytes | This flag is intended for latency-bound (i.e., small-sized) all-to-all operations. Enabling it triggers optimizations that can reduce execution time for latency-bound all-to-alls. It is typically used in inference workloads. | xla_all_to_all_latency_bound_threshold_in_bytes=-1 (i.e., not enabled) | 4-16 MB (i.e., 4-16 * 1024 * 1024 bytes) | [0, 9223372036854775807] |
| xla_enable_async_collective_permute | Rewrites all collective-permute operations to their asynchronous variants. When set to kAuto, XLA can turn on async collective-permute automatically based on other configurations or conditions. | xla_enable_async_collective_permute=kAuto | xla_enable_async_collective_permute=kAuto | xla_enable_async_collective_permute=kAuto/kEnabled/kDisabled |
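
As a minimal sketch of the Pipelining row above, assuming a JAX-on-TPU setup where TPU compiler flags are forwarded through the LIBTPU_INIT_ARGS environment variable (verify the flag-passing mechanism for your runtime):

```python
import os

# Sketch (assumed JAX-on-TPU setup): enable ICI all-gather pipelining using
# the suggested values of the three "Pipelining" flags from the table above.
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "")
    + " --xla_should_allow_loop_variant_parameter_in_chain=kEnabled"
    + " --xla_should_add_loop_invariant_op_in_chain=kEnabled"
    + " --xla_tpu_enable_ici_ag_pipelining=true"
)

import jax  # import only after the flags are in the environment
```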

Memory Flags

The flags listed below are provided to address HBM-related issues. These should only be adjusted if you encounter HBM "out of memory" errors during model compilation. In all other scenarios, the default values are recommended, as altering them could adversely affect performance.

| Flag | Description | Default Values | Suggested Values | Candidate Values |
| --- | --- | --- | --- | --- |
| Scheduler<br>xla_latency_hiding_scheduler_rerun | This setting adjusts the behavior of the latency-hiding scheduler: it incrementally reduces the memory limit allocated for scheduling with each "rerun" of the process. | xla_latency_hiding_scheduler_rerun=1 | xla_latency_hiding_scheduler_rerun=5 | 0-10 (more than 10 reruns rarely makes sense) |
| Fusion<br>xla_tpu_rwb_fusion | This flag enables reduce+broadcast fusions and may decrease memory usage. | xla_tpu_rwb_fusion=true | xla_tpu_rwb_fusion=false | xla_tpu_rwb_fusion=true/false |
| Scheduler<br>xla_memory_scheduler | This flag specifies the algorithm the memory scheduler uses to minimize memory consumption. A more advanced algorithm may produce a schedule that consumes less memory, at the cost of longer compilation time. | xla_memory_scheduler=kDefault | xla_memory_scheduler=kBrkga | xla_memory_scheduler=kDefault/kList/kDfs/kPostOrder/kBrkga |
| Scheduler<br>xla_tpu_enable_latency_hiding_scheduler | This flag enables the latency-hiding scheduler, which allows asynchronous collectives to be performed instead of synchronous ones. Disabling it reduces memory usage at the cost of losing the performance gains from these asynchronous operations. | xla_tpu_enable_latency_hiding_scheduler=true | xla_tpu_enable_latency_hiding_scheduler=false | xla_tpu_enable_latency_hiding_scheduler=true/false |
| SPMD<br>xla_jf_spmd_threshold_for_windowed_einsum_mib | This flag sets the minimum size a dot operation must reach to trigger collective matmul. Setting it to a higher value saves memory at the cost of losing opportunities to perform collective matmul. | xla_jf_spmd_threshold_for_windowed_einsum_mib=-1 | 10 MB to 1 GB (i.e., 10 * 1024 * 1024 to 1024 * 1024 * 1024) | [0, 9223372036854775807] |
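
If compilation fails with an HBM out-of-memory error, a sketch like the following (again assuming flags are forwarded through LIBTPU_INIT_ARGS on a JAX-on-TPU setup) applies two of the suggested values from the table above:

```python
import os

# Sketch (assumed JAX-on-TPU setup): trade some performance and compile time
# for lower HBM usage, using suggested values from the table above.
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "")
    + " --xla_latency_hiding_scheduler_rerun=5"
    + " --xla_memory_scheduler=kBrkga"
)
```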

Other commonly used flags

| Flag | Type | Notes |
| --- | --- | --- |
| xla_dump_to | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see XLA Tools). |
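
For example, a short sketch of dumping HLO from a JAX program (the dump path and the toy function are arbitrary):

```python
import os

# Sketch: write pre-optimization HLO and other compiler artifacts to a
# directory that can later be inspected with the XLA tools.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/xla_dump"

import jax
import jax.numpy as jnp

@jax.jit
def scale(x):
    return jnp.sin(x) * 2.0

scale(jnp.ones((8, 8)))  # compiling this writes HLO files under /tmp/xla_dump
```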

TPU XLA flags

| Flag | Type | Notes |
| --- | --- | --- |
| xla_tpu_enable_data_parallel_all_reduce_opt | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding. |
| xla_tpu_data_parallel_opt_different_sized_ops | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. |
| xla_tpu_spmd_rng_bit_generator_unsafe | Boolean (true/false) | Whether to run the RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation. |
| xla_tpu_megacore_fusion_allow_ags | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces. |
| xla_tpu_enable_ag_backward_pipelining | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops. |
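
As with the earlier TPU examples, these flags can be combined; the sketch below is an illustration rather than a tuned recipe, and again assumes LIBTPU_INIT_ARGS is the flag-passing mechanism on your TPU runtime:

```python
import os

# Sketch (assumed JAX-on-TPU setup): increase overlap for data-parallel
# DCN all-reduces. Both flags come from the table above.
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "")
    + " --xla_tpu_enable_data_parallel_all_reduce_opt=true"
    + " --xla_tpu_data_parallel_opt_different_sized_ops=true"
)
```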

GPU XLA flags

| Flag | Type | Notes |
| --- | --- | --- |
| xla_gpu_enable_latency_hiding_scheduler | Boolean (true/false) | Enables the latency-hiding scheduler, which efficiently overlaps asynchronous communication with computation. The default value is false. |
| xla_gpu_enable_triton_gemm | Boolean (true/false) | Use Triton-based matrix multiplication. |
| xla_gpu_graph_level | Integer (0-3) | The legacy flag for setting the GPU graph level. Use xla_gpu_enable_command_buffer in new use cases. 0 = off; 1 = capture fusions and memcpys; 2 = capture gemms; 3 = capture convolutions. |
| xla_gpu_all_reduce_combine_threshold_bytes | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce operations into one big AllGather / ReduceScatter / AllReduce to reduce the time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough to combine at least a Transformer layer's weight AllGather / ReduceScatter. By default, combine_threshold_bytes is set to 256. |
| xla_gpu_all_gather_combine_threshold_bytes | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. |
| xla_gpu_reduce_scatter_combine_threshold_bytes | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. |
| xla_gpu_enable_pipelined_all_gather | Boolean (true/false) | Enable pipelining of all-gather instructions. |
| xla_gpu_enable_pipelined_reduce_scatter | Boolean (true/false) | Enable pipelining of reduce-scatter instructions. |
| xla_gpu_enable_pipelined_all_reduce | Boolean (true/false) | Enable pipelining of all-reduce instructions. |
| xla_gpu_enable_while_loop_double_buffering | Boolean (true/false) | Enable double buffering for while loops. |
| xla_gpu_enable_all_gather_combine_by_dim | Boolean (true/false) | Combine all-gather ops with the same gather dimension, or irrespective of their dimension. |
| xla_gpu_enable_reduce_scatter_combine_by_dim | Boolean (true/false) | Combine reduce-scatter ops with the same dimension, or irrespective of their dimension. |
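
A sketch of a GPU configuration combining several of the flags above; the 64 MiB combine thresholds are placeholder values to be tuned per workload, not recommendations from this guide:

```python
import os

# Sketch: enable the latency-hiding scheduler and pipelined collectives and
# raise the combine thresholds so small collectives are merged. The byte
# values below are placeholders; tune them for your model.
gpu_flags = [
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_pipelined_reduce_scatter=true",
    "--xla_gpu_all_gather_combine_threshold_bytes=67108864",      # 64 MiB
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864",  # 64 MiB
]
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " " + " ".join(gpu_flags)

import jax  # import after XLA_FLAGS is set so the flags take effect
```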