このガイドでは、XLA の機能を効果的にナビゲートして活用できるように、主要な XLA フラグを厳選して紹介します。以降のセクションでは、ランタイム パフォーマンスとメモリ使用率に大きな影響を与える可能性のあるフラグについて詳しく説明します。フラグを有効にした後にクラッシュなどの問題が発生した場合は、デフォルト設定に戻して GitHub の問題を作成することをおすすめします。
パフォーマンス フラグ
次のフラグは、ランタイム パフォーマンスの向上に役立ちます。これらの設定をテストすると、パフォーマンスが大幅に向上する可能性があります。
フラグ | 説明 | デフォルト値 | 推奨値 | 候補者の価値 |
---|---|---|---|---|
パイプライン処理 1. xla_should_allow_loop_variant_parameter_in_chain 2. xla_should_add_loop_invariant_op_in_chain 3. xla_tpu_enable_ici_ag_pipelining |
これらの 3 つのフラグを組み合わせて使用すると、ICI(Interchip-Interconnect)の all-gather オペレーションの集団パイプライン処理を有効にできます。これにより、実行の重複の機会が増えます。 | 1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled 2. xla_should_add_loop_invariant_op_in_chain=kDisabled 3. xla_tpu_enable_ici_ag_pipelining=false |
1. xla_should_allow_loop_variant_parameter_in_chain=kEnabled 2. xla_should_add_loop_invariant_op_in_chain=kEnabled 3. xla_tpu_enable_ici_ag_pipelining=true |
1. xla_should_allow_loop_variant_parameter_in_chain=kDisabled/kEnabled/kAuto 2. xla_should_add_loop_invariant_op_in_chain=kDisabled/kEnabled/kAuto 3. xla_tpu_enable_ici_ag_pipelining=true/false |
v5e/Async xla_enable_async_all_gather xla_tpu_enable_async_collective_fusion xla_tpu_enable_async_collective_fusion_fuse_all_gather |
これらの 3 つのフラグを組み合わせて使用すると、v5e で非同期の all-gather オペレーションを有効にできます。 | xla_enable_async_all_gather=kAuto xla_tpu_enable_async_collective_fusion=true xla_tpu_enable_async_collective_fusion_fuse_all_gather=true |
xla_enable_async_all_gather=kAuto xla_tpu_enable_async_collective_fusion=true xla_tpu_enable_async_collective_fusion_fuse_all_gather=true |
xla_enable_async_all_gather=kDisabled/kEnabled/kAuto xla_tpu_enable_async_collective_fusion=true/false xla_tpu_enable_async_collective_fusion_fuse_all_gather=true/false |
v5e/Async xla_tpu_enable_async_collective_fusion xla_tpu_enable_async_collective_fusion_fuse_all_reduce |
v5e で非同期の all-reduce オペレーションを有効にするには、この 2 つのフラグを組み合わせて使用する必要があります。 | xla_tpu_enable_async_collective_fusion=true xla_tpu_enable_async_collective_fusion_fuse_all_reduce=false |
xla_tpu_enable_async_collective_fusion=true xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true |
xla_tpu_enable_async_collective_fusion=true/false xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true/false |
Async xla_tpu_enable_async_all_to_all |
このフラグにより、非同期の全対全通信が有効になります。 | xla_tpu_enable_async_all_to_all=false |
xla_tpu_enable_async_all_to_all=true |
xla_tpu_enable_async_all_to_all=true/false |
レイテンシ バウンド xla_all_gather_latency_bound_threshold_in_bytes |
このフラグは、レイテンシ制限のある(つまり、小規模な)all-gather オペレーションを対象としています。これを有効にすると、レイテンシ制限のある all-gather の実行時間を短縮できる特定の最適化がトリガーされます。通常は推論ワークロードで使用されます。 | xla_all_gather_latency_bound_threshold_in_bytes=-1 (有効になっていません) |
4~16Mb(i.e. 4~16 * 1024 * 1024) |
[0, 9223372036854775807] |
レイテンシ バウンド xla_all_reduce_latency_bound_threshold_in_bytes |
このフラグは、レイテンシ制限のある(つまり、小規模な)all-gather オペレーションを対象としています。これを有効にすると、レイテンシ制限のある all-reduce の実行時間を短縮できる特定の最適化がトリガーされます。通常は推論ワークロードで使用されます。 | xla_all_reduce_latency_bound_threshold_in_bytes=-1 (有効になっていません) |
4~16Mb(i.e. 4~16 * 1024 * 1024) |
[0, 9223372036854775807] |
レイテンシ バウンド xla_collective_permute_latency_bound_threshold_in_bytes |
このフラグは、レイテンシ制限のある(つまり、小規模な)all-gather オペレーションを対象としています。これを有効にすると、レイテンシ制限のあるグループ置換の実行時間を短縮できる特定の最適化がトリガーされます。通常は推論ワークロードで使用されます。 | xla_collective_permute_latency_bound_threshold_in_bytes=-1 (有効になっていません) |
4~16Mb(i.e. 4~16 * 1024 * 1024) |
[0, 9223372036854775807] |
レイテンシ バウンド xla_all_to_all_latency_bound_threshold_in_bytes |
このフラグは、レイテンシ制限のある(つまり、小規模な)all-gather オペレーションを対象としています。これを有効にすると、レイテンシが制約された all-to-all の実行時間を短縮できる特定の最適化がトリガーされます。通常は推論ワークロードで使用されます。 | xla_all_to_all_latency_bound_threshold_in_bytes=-1 (有効になっていません) |
4~16Mb(i.e. 4~16 * 1024 * 1024) |
[0, 9223372036854775807] |
xla_enable_async_collective_permute |
すべての collective-permute オペレーションを非同期バリアントに書き換えます。auto に設定すると、XLA は他の構成や条件に基づいて非同期コレクティブを自動的にオンにできます。 |
xla_enable_async_collective_permute=kAuto |
xla_enable_async_collective_permute=kAuto |
xla_enable_async_collective_permute=kAuto/kEnabled/kDisabled |
メモリフラグ
以下のフラグは、HBM 関連の問題に対処するために提供されています。これらは、モデルのコンパイル中に HBM のメモリ不足エラーが発生した場合にのみ調整する必要があります。他のすべてのシナリオでは、デフォルト値を使用することをおすすめします。デフォルト値を変更すると、パフォーマンスに悪影響を及ぼす可能性があります。
フラグ | 説明 | デフォルト値 | 推奨値 | 候補者の価値 |
---|---|---|---|---|
スケジューラ xla_latency_hiding_scheduler_rerun |
この設定は、レイテンシを隠すスケジューラの動作を調整します。この機能は、プロセスの「再実行」ごとにスケジューリングに割り当てられるメモリ上限を段階的に減らすことで機能します。 | xla_latency_hiding_scheduler_rerun=1 |
xla_latency_hiding_scheduler_rerun=5 |
0~10(it doesn’t make much sense beyond 10 reruns) |
Fusion xla_tpu_rwb_fusion |
このフラグは reduce+broadcast タイプのフュージョンを有効にし、メモリ使用量を削減する可能性があります。 | xla_tpu_rwb_fusion=true |
xla_tpu_rwb_fusion=false |
xla_tpu_rwb_fusion=true/false |
スケジューラ xla_memory_scheduler |
このフラグは、メモリ スケジューラがメモリ消費量を最小限に抑えるために使用するアルゴリズムを指定します。より高度なアルゴリズムを使用すると、コンパイル時間が長くなる代わりに、メモリ消費量の少ないスケジュールを取得できる可能性があります。 | xla_memory_scheduler=kDefault |
xla_memory_scheduler=kBrkga |
xla_memory_scheduler=kDefault/kList/kDfs/kPostOrder/kBrkga |
スケジューラ xla_tpu_enable_latency_hiding_scheduler |
このフラグにより、レイテンシを隠蔽するスケジューラが有効になります。これにより、同期ではなく非同期の集合を実行できます。無効にすると、これらの非同期オペレーションによるパフォーマンスの向上は失われますが、メモリ使用量は削減されます。 | xla_tpu_enable_latency_hiding_scheduler=true |
xla_tpu_enable_latency_hiding_scheduler=false |
xla_tpu_enable_latency_hiding_scheduler=true/false |
SPMD xla_jf_spmd_threshold_for_windowed_einsum_mib |
このフラグは、集合 matmul をトリガーするドットの最小サイズのしきい値を設定します。値を大きくすると、メモリは節約できますが、集合 matmul を実行する機会が失われます。 | xla_jf_spmd_threshold_for_windowed_einsum_mib=-1 |
10Mb~1Gb (i.e. 10*1024*1024 ~ 1024*1024*1024) |
[0, 9223372036854775807] |
その他のよく使用されるフラグ
フラグ | タイプ | メモ |
---|---|---|
xla_dump_to |
String(filepath) | 最適化前の HLO ファイルとその他のアーティファクトが配置されるフォルダ(XLA ツールを参照)。 |
TPU XLA フラグ
フラグ | タイプ | メモ |
---|---|---|
xla_tpu_enable_data_parallel_all_reduce_opt |
ブール値(true / false) | データ並列シャーディングに使用される DCN(データセンター ネットワーキング)の all-reduce のオーバーラップの機会を増やすための最適化。 |
xla_tpu_data_parallel_opt_different_sized_ops |
ブール値(true / false) | 出力サイズがスタックされた変数に保存できるものと一致しない場合でも、複数のイテレーションにわたってデータ並列処理のパイプライン処理を有効にします。メモリ負荷が増加する可能性があります。 |
xla_tpu_spmd_rng_bit_generator_unsafe |
ブール値(true / false) | RngBitGenerator HLO をパーティション分割された方法で実行するかどうか。計算の異なる部分で異なるシャーディングを使用して決定論的な結果が期待される場合は安全ではありません。 |
xla_tpu_megacore_fusion_allow_ags |
ブール値(true / false) | all-gather を畳み込み/all-reduce と融合できます。 |
xla_tpu_enable_ag_backward_pipelining |
ブール値(true / false) | パイプラインは、スキャンループを逆方向にすべて収集します(現在はメガスケールですべて収集します)。 |
GPU XLA フラグ
フラグ | タイプ | メモ |
---|---|---|
xla_gpu_enable_latency_hiding_scheduler |
ブール値(true / false) | このフラグにより、レイテンシ隠蔽スケジューラが非同期通信と計算を効率的にオーバーラップできるようになります。デフォルト値は False です。 |
xla_gpu_enable_triton_gemm |
ブール値(true / false) | Triton ベースの行列乗算を使用します。 |
xla_gpu_graph_level |
フラグ(0 ~ 3) | GPU グラフのレベルを設定するための以前のフラグ。新しいユースケースでは xla_gpu_enable_command_buffer を使用します。0 = オフ、1 = フュージョンと memcpys をキャプチャ、2 = gemm をキャプチャ、3 = 畳み込みをキャプチャ。 |
xla_gpu_all_reduce_combine_threshold_bytes |
整数(バイト) | これらのフラグは、複数の小さな AllGather / ReduceScatter / AllReduce を 1 つの大きな AllGather / ReduceScatter / AllReduce に結合して、デバイス間の通信に費やす時間を短縮するタイミングを調整します。たとえば、Transformer ベースのワークロードの AllGather / ReduceScatter しきい値については、少なくとも Transformer レイヤの重み AllGather / ReduceScatter を組み合わせるのに十分な高さに調整することを検討してください。デフォルトでは、combine_threshold_bytes は 256 に設定されています。 |
xla_gpu_all_gather_combine_threshold_bytes |
整数(バイト) | 上記の xla_gpu_all_reduce_combine_threshold_bytes をご覧ください。 |
xla_gpu_reduce_scatter_combine_threshold_bytes |
整数(バイト) | 上記の xla_gpu_all_reduce_combine_threshold_bytes をご覧ください。 |
xla_gpu_enable_pipelined_all_gather |
ブール値(true / false) | all-gather 命令のパイプライン処理を有効にします。 |
xla_gpu_enable_pipelined_reduce_scatter |
ブール値(true / false) | reduce-scatter 命令のパイプライン処理を有効にします。 |
xla_gpu_enable_pipelined_all_reduce |
ブール値(true / false) | all-reduce 命令のパイプライン処理を有効にします。 |
xla_gpu_enable_while_loop_double_buffering |
ブール値(true / false) | while ループのダブル バッファリングを有効にします。 |
xla_gpu_enable_all_gather_combine_by_dim |
ブール値(true / false) | 同じ収集ディメンションを持つ all-gather オペレーションを組み合わせるか、ディメンションに関係なく組み合わせます。 |
xla_gpu_enable_reduce_scatter_combine_by_dim |
ブール値(true / false) | 同じディメンションを持つ reduce-scatter オペレーションを結合します。ディメンションに関係なく結合することもできます。 |