LHS 費用モデル


要約

このページでは、レイテンシ隠蔽スケジューラで使用される費用モデルの内部構造について説明します。モデルのチューニングに関心がある場合は、チューニング セクションに直接進んでください。

レイテンシ隠蔽スケジューラ(LHS)は、経過時間を最小限に抑えるように HLO DAG をスケジュールするコンパイラパスです。

この決定は、パフォーマンス テーブルと分析モデルを組み合わせた統合費用モデルに基づいて行われます。特に、XLA は GEMM と高速インターコネクト コレクティブのパフォーマンス テーブルを埋め込み、他のケースでは分析ネットワーキングと融合コストモデルを使用します。このドキュメントの残りの部分では、これらの内部動作の概要について説明します。


パフォーマンス テーブル - ICI コレクティブ

パフォーマンス テーブルは、コレクタとインターポレータの 2 つの主要コンポーネントで構成されています。

コレクタ

コレクタは、集団オペレーションのパフォーマンス テーブルを生成する C++ ツールです。個々の HLO オペレーション(all-gatherall-reduce)を静的に定義されたパラメータ空間で実行します。

仕組み

このツールは、特定のクラスタのさまざまな集団演算、転送サイズ、転送スキームをスイープします。既存のマルチホスト HLO ランナー インフラストラクチャと ExecutionProfile データを使用して、生成された HLO を実行し、パフォーマンス指標を収集します。

データ収集パラメータ

レイテンシ テーブルは、次のパラメータのクロス積に対して収集されます。

  • Collective Type:
    • all-reduce
    • all-gather
    • reduce-scatter
  • 転送サイズ:
    • 1,024 バイトから 2 GiB までの対数スケール(例: 1024B、2048B、4096B など)
  • 転送スキーム:
    • rail-aligned
    • non-rail-aligned

このスイープは、2、4、8 台のデバイスを備えたノード内クラスタに対して実行されます。

出力

収集の実行結果は、.pbtxt 形式のレイテンシ テーブルです(プラットフォームごとに約 116 KB)。

Interpolator

インターポレータは、生成されたパフォーマンス テーブルを使用してコンパイル中にランタイムの見積もりを提供するコンパイラ コンポーネントです。

内部データ構造

初期化時に、Interpolator はパフォーマンス テーブルをマップに変換します。このマップは、キーとして (collective_type, transfer_scheme) のタプルを使用します。

各キーに関連付けられたは 2 次元ユークリッド平面です。このプレーンは、次の 2 つの軸に基づいて、ネットワーク スループット(コレクタによって測定)をインデックスに登録します。

  1. 転送サイズ。
  2. 対象デバイスの数。

ルックアップと補間

コンパイラが集団オペレーションを検出すると、インターポレータは次の手順を実行します。

  1. オペレーションの (collective_type, transfer_scheme) をマップキーとして使用して、正しい 2D スループット プレーンを識別します。
  2. 次に、その 2D 平面内で、オペレーションの (transfer_size, num_devices) をクエリポイントとして使用して、加重平均検索(ユークリッド距離に基づく)を使用します。
  3. このルックアップの結果は、一意の 1 つのネットワーク スループット値です。

理由: スループットと推定

このシステムは、レイテンシの生データではなく、ネットワーク スループットを保存するように設計されています。この設計により、テーブルに明示的に示されていない転送サイズのパフォーマンスを外挿する作業が大幅に簡素化されます。

レイテンシ テーブルが集合サイズ S でネットワーク帯域幅の飽和をキャプチャした場合、その時点のスループット T が最大と見なされます。サイズ S' > S の新しいコレクティブの場合、ランタイムは次のように推定できます。

\[\text{EstimatedTime}(S') = \frac{S'}{T_{\text{saturated} } }\]

これにより、コレクタで測定された最大値の 2 GiB を超える場合でも、任意のサイズのコレクティブのパフォーマンスをモデルで推定できます。

  • 最大スループットを過小評価します。
  • そのため、大規模な転送の実行時間を過大に見積もることになります。

通常、XLA:GPU チームがパフォーマンス テーブルを維持しますが、ユーザーが独自のテーブルを提供する場合、テーブルを生成するユーザーは、テーブルが代表的であり、ターゲット ハードウェアの帯域幅飽和領域の測定値が含まれていることを確認する必要があります。


パフォーマンス テーブル - GEMM

コレクティブのシステムと同様に、GEMM レイテンシ テーブルは、コレクタインターポレータの 2 つのコンポーネントでサポートされています。

コレクタ

コレクタは、一般行列乗算(GEMM)のパフォーマンス テーブルを計算する C++ ツールです。HLO dot オペレーション レベルで行列乗算のパフォーマンスを測定します。

仕組み

このツールは、GEMM ディメンション(バッチ、2 つの非収縮ディメンション、1 つの収縮ディメンション)とデータ型の静的空間をスイープします。

  • デフォルトのデータ型: LHS = bf16,f32RHS = bf16,f32OUT = bf16,f32
  • インフラストラクチャ: HLO op プロファイラを再利用します。

コレクション パラメータ

レイテンシ テーブルは、次のディメンションのクロス プロダクトについて収集されます。

  • batch: {1, 2, 4}
  • m(契約なし): {256, 512, ..., 4096}
  • n(契約なし): {256, 512, ..., 4096}
  • k(収縮): {256, 512, ..., 4096}

出力とストレージ

フルスイープでは、インターポレータで使用できる .pbtxt レイテンシ テーブルが生成されます。

Interpolator

インターポレータは、生成されたテーブルを使用して GEMM パフォーマンスを推定するコンパイラ コンポーネントです。

理由: FLOPS の飽和

収集されたレイテンシ テーブルにより、補間器は各エントリの FLOPS を再構築できます。

\[\text{FLOPS} = \frac{2 \times b \times m \times n \times k}{\text{runtime} }\]

重要なのは、FLOPS がある時点で飽和することです。つまり、ハードウェアは特定の行列の形状を超えるとピーク FLOPS に達します。この飽和により、集合体で使用されるのと同じ外挿法を使用できます。

ルックアップと補間

補間器は、テーブルデータから 4D ユークリッド空間を構築します。パフォーマンスの見積もりを提供するために、この 4 次元空間内で加重平均補間を行います。特定のデータ型のテーブルがない場合、各ディメンションはヒューリスティックとしてバイト数に正規化されます。


分析費用モデル - DCN

S カーブの集合費用モデル

S 字曲線モデルは、完全に分析的なネットワーキングのルーフライン モデルです。

概要

このモデルは、一連の固定ネットワーク プロパティに基づいて、集団オペレーションのパフォーマンスを推定するように設計されています。

モデル入力

このモデルには、次の 2 つのカテゴリの入力が必要です。

  1. 固定ネットワーク プロパティ(ユーザー定義):

    • 一括起動のオーバーヘッド
    • NIC の速度
    • RTT(ラウンドトリップ時間)

    デフォルトでは、XLA はプラットフォームを自動検出し、最も一般的なアーキテクチャの値を使用します。これらのプロパティはユーザーが構成できます。詳細については、チューニングのセクションをご覧ください。

  2. Per-Collective Inputs:

    • 集合型(例: AllGatherReduceScatter
    • 転送サイズ
    • 通信に関与するノードの数

統合

S カーブ モデルは XLA:GPU に統合され、Hopper と Blackwell で使用されています。


分析費用モデル - フュージョン

他のカーネルについては、GPU パフォーマンス費用モデルを使用して適切なランタイムを推定します。詳しくは、お支払い基準額をご参照ください。


チューニング

S 字曲線モデルは、適切な XLA フラグを発行することで調整できます。ほとんどの場合、デフォルトの構成で十分ですが、他のケースではモデル制御が公開されます。

export NIC_SPEED_GBPS=... # NIC speed per GPU in Gigabytes
export GPUS_PER_NODE=... # Num of GPUs per cluster interconnected with fast network (e.g. NVLINK)
export XLA_FLAGS=--xla_gpu_analytical_latency_estimator_options="nic_speed_gbps=$NIC_SPEED_GBPS,gpus_per_node=$GPUS_PER_NODE"