カテゴリ: コンパイル時間: HBM OOM
このエラーは、プログラムに必要な高帯域幅メモリ(HBM)が、TPU デバイスで物理的に使用可能な量を超えていることを示します。
エラー メッセージの例:
RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 49.34G of 32.00G hbm. Exceeded hbm capacity by 17.34G.
RESOURCE_EXHAUSTED: TPU TensorCore Hbm usage: 34.82G, SparseCore Hbm usage 174.10G, exceeding available bytes: 95.74G
XLA バックエンド: TPU
概要
XLA は、必要なすべての静的割り当ての合計サイズがデバイスの HBM に収まることを確認するチェックを実行します。
コンパイラは、いくつかのタイプのアロケーションに対して TPU の固定 HBM 容量を管理します。
- プログラムの入力と出力: トレーニング バッチ、オプティマイザーの状態など。
- TensorCore + SparseCore の一時変数: 中間計算(アクティベーション、グラデーションなど)に必要な動的メモリ。
- コンパイル済みバイナリ: TensorCore(TC)と SparseCore(SC)の両方のマシンコード。
- システム オーバーヘッド: XLA ランタイム用に予約されたスペース(以前の TPU 世代のインフィード バッファなど)。
- 定数: HLO IR に埋め込まれた定数値は HBM に割り当てられます。
- コンパイラの内部: プログラム レベルと HLO ごとの割り当て(メッシュ内のノードのルーティング情報など)
このエラーは、XLA コンパイラが上記の割り当てをすべてデバイスの HBM に収めることができない場合に発生します。
デバッグ
エラー メッセージとログを慎重に分析して、次の HBM OOM のカテゴリのうち、エラーに最も適したものを特定します。
- TensorCore(TC)+ SparseCore(SC)HBM の使用量が上限を超えています: エラーで使用量が明示的に分類されている場合(例: 「TC Hbm 使用量: X、SC Hbm 使用量 Y」。→ セクション 1 に移動します。TC と SC の HBM 使用量のバランスを取る。
- 予期しない大きな割り当て: エラーに「メモリ空間 HBM でメモリ不足が発生しました」と表示された場合は、ログで HBM の最大の割り当ての列挙を確認します。予期せず大きなテンソル(HBM 上限の 50% 超など)が 1 つ以上存在する場合 → セクション 2 に進みます。予期しない大規模な割り当て。
- Aggregate Allocations Exceed HBM Limit: エラーに「Ran out of memory in memory space HBM」と表示されているが、ログに予期しない大きなテンソルがない場合 → セクション 3 に進みます。Aggregate Allocations Exceed HBM Limit。
セクション 1. TC と SC の HBM 使用量のバランスを取る
エラーで使用状況が明示的に分類されている場合(例: 「TC Hbm usage: X, SC Hbm usage Y」: 2 つの値を比較してボトルネックを特定します。
- SparseCore の使用率が高い:
- HBM スタックの使用量を最適化する: HBM スタックのメモリ消費量は
feature_width、max_unique_nz_per_row、logical_replica_countに応じてスケーリングされます。テーブルの処理をシリアル化する--xla_sc_num_serialized_tables_to_optimize_hbmフラグを調整することで、ピーク時のスタック使用量を削減できます。ただし、並列処理の削減という代償を伴います。 - パディングのオーバーヘッドを確認する: SparseCore は、エンベディング テーブルを 32B(8 個の浮動小数点数)に調整します。特徴量の幅が小さいテーブル(< 8 個の浮動小数点数)は、パディングのオーバーヘッドが大きくなり、HBM が無駄になります。
- ヒープ使用量を減らす:
maximum_parallel_iterationsの値が大きいと、HBM ヒープにプリフェッチされる入力データの量が増加します。この値を小さくすると、かなりのメモリを解放できます。 - シャーディングを確認する: エンベディング テーブルがすべてのチップで適切に mod-sharded されていることを確認します。上限がテーブルにどのように変換されるかをご覧ください。
- その他のアイデアについては、SC: パフォーマンスとメモリのボトルネックをご覧ください。
- HBM スタックの使用量を最適化する: HBM スタックのメモリ消費量は
- TensorCore の使用率が高い:
- セクション 2 に進みます。
- バランス重視
- どちらも個別に過剰ではないが、合計が大きすぎる場合は、チップの容量に達しています。両方のコンポーネントの使用量を減らす必要があります。3 つのセクションすべてで推奨事項に従います。
セクション 2. 予期しない大きな割り当て
ログに予期しない大きな割り当てが 1 つ以上ある場合(HBM 上限の 50% 超)、ハードウェア容量の問題であることはほとんどありません。通常は構成エラーです。大きな割り当ての XLA ラベル(存在する場合)を調べて、JAX ソースコードのヒントを確認します。
- デバッグ アーティファクトを削除:
- 大規模な実行で jax.debug.print() を使用すると、コンパイラが HBM でテンソル全体を具体化して CPU に転送し、フュージョンを中断してピーク時のメモリ使用量を増やす可能性があります。残っている
jax.debug.print()を削除します。
- 大規模な実行で jax.debug.print() を使用すると、コンパイラが HBM でテンソル全体を具体化して CPU に転送し、フュージョンを中断してピーク時のメモリ使用量を増やす可能性があります。残っている
- 非効率的なメッシュ形状またはシャーディングを修正する:
- メッシュ形状が正しくない場合や、シャーディング アノテーションが欠落している場合、コンパイラはデフォルトで Replication に設定され、コンパイラは非常に大きなテンソルを単一のチップに適合させようとします。
- 大きな割り当ての形状を確認し、シャーディングが XLA によって正しく指定され、伝播されていることを確認します。
セクション 3. 合計割り当てが HBM の上限を超えている
割り当ての合計が HBM の上限を超えたためにプログラムの容量が不足した場合は、メモリ プロファイルを可視化して、ピーク時の使用量に寄与している特定のバッファを特定すると効果的です。ピーク時のメモリ使用量の原因を特定する手順については、XProf で OOM エラーをデバッグするをご覧ください。
上位の貢献者を特定したら、次の手順でメモリ フットプリントを最適化します。
A. テンソルのパディングと配置を確認する
非効率的なテンソル形状は、TPU で OOM が発生する一般的な原因ですが、気づきにくいものです。TPU でピーク パフォーマンスを得るために、XLA はテンソル ディメンションをパディングします。通常、最下位のディメンションは 128 の倍数、2 番目の下位ディメンションは 8 の倍数になります。このパディングは、入力配列と中間テンソル(HLO テンポラリ)の両方に影響し、特にディメンション サイズが小さい場合に、メモリ使用量が大幅に増加する可能性があります。配列レイアウトをご覧ください。
- 大きなバッファの形状を監査:(デフォルト レイアウトの TPU v5 の場合)
- Xprof Memory Viewer でバッファにカーソルを合わせると、パディング情報などのバッファの詳細を含むバッファの詳細カードが表示されます。
- 例:
(129, 1024)の形状が(256, 1024)にパディングされると、メモリの約 50% が無駄になります。 - 修正:
(128, 1024)の形状にはパディングは不要で、メモリの無駄は 0% になります。
- ディメンションを調整する: すべての大きなテンソル ディメンション(バッチサイズ、エンベディング ディメンション、隠れ層のサイズ)が 128 の倍数であることを確認します。
B. 構成を調整する
多くの場合、次の構成調整で OOM を解決できます。
- バッチサイズを減らす: 中間アクティベーションとグラデーションに必要なメモリは、バッチサイズに正比例します。バッチサイズを減らすと、メモリ使用量を削減できることがよくあります。
- 入力バッファを寄付:
jax.jitを使用する場合は、モデル パラメータに donate_argnums を指定します。これにより、XLA は入力メモリを出力で上書きできます。 - 混合精度(bfloat16)を有効にする: モデル アーキテクチャと品質要件で許容される場合は、プログラム内の最大のテンソルに bfloat16 または量子化(int8 など)を使用します。
C. アーキテクチャとシャーディングを最適化する
構成の変更が不十分な場合、モデル トポロジが現在のハードウェア設定に対して大きすぎる可能性があります。
- 新しい TPU 世代を使用する: 一般的に、新しい TPU はチップあたりの HBM が多くなります。利用可能な場合は、新しい TPU 世代に切り替えます。
- より大きなチップ トポロジで実行する: モデルの重みが既存のトポロジに対して大きすぎる場合は、より多くのチップにシャーディングしてみてください。
- 高度なシャーディング手法を実装する:
- より高度なデータ、テンソル、パイプラインの並列処理アプローチを検討します。
- 中間値と出力のシャーディング ヒントを指定します。
- JAX ホスト オフロードを使用する: 大きなテンソルをホスト CPU メモリにオフロードします。たとえば、アクティベーション オフロードやオプティマイザー状態オフロードなどです。
D. メモリに影響する主な XLA フラグを調整します。
主要なメモリフラグを調整して、パフォーマンスとメモリ使用量のバランスを取ることができます。ただし、パフォーマンスに悪影響を及ぼす可能性があるため、これらの方法は最後の手段として使用する必要があります。
E. XLA 再マテリアライズ パス / 手動チェックポイント設定を調整する
モデルがメモリにほぼ収まる場合は、XLA::Rematerialization パスを強制してメモリの節約を優先できます。ただし、コンパイルが遅くなる可能性があります。
| フラグ | 説明 | 影響 / トレードオフ |
|---|---|---|
--xla_tpu_max_hbm_size_mib |
再マテリアライズ パスで使用される HBM サイズの上限を手動で設定します。 | コンパイラに、実際の物理 HBM よりも小さい制限内にプログラムを収めるよう強制します。 |
--xla_tpu_rematerialization_algo=PEAK_PRIORITY |
ピーク時のメモリ使用量に焦点を当てて作業を進めます。 | デフォルトのアルゴリズムよりも積極的なメモリ削減に効率的です。 |
--xla_tpu_rematerialization_max_block_size_limit=32 |
一度に再実体化できるブロック内の命令の最大数を制御します。 | この値を大きくすると、コンパイル時間が大幅に増加しますが、メモリを節約できます。 |
--xla_tpu_rematerialization_block_effort_factor=10.0 |
再実体化するブロックの検索に費やす労力(コンパイル時間)の量を定義します。 | 値を大きくすると、メモリ節約のための徹底的な検索が可能になりますが、コンパイル時間が長くなります。 |
--xla_tpu_pre_fusion_remat=true |
融合パスの前に追加の再実体化パスを有効にします。 | メモリ使用量をさらに削減できますが、コンパイル時間が長くなり、数値の安定性に影響する可能性があります。 |
または、jax.grad で jax.checkpoint デコレータを使用して、フォワード パスで保存される中間値とバックワード パスで再計算される中間値を手動で制御し、コンピューティング サイクルと HBM をトレードオフします。
F. 高度なプロファイリング ツールを使用する
XProf を使用して OOM エラーをデバッグするでは、XProf メモリビューアを使用してコンパイラの HBM 使用状況を可視化するチュートリアルを提供しています。
このツールを使用すると、ピーク時のメモリ割り当てとバッファのライフタイムを確認できます。これは、ピーク時の使用率で HBM を消費しているものを正確に把握するために不可欠です。一般的なプロファイリングの設定については、Xprof を使ってみると TensorBoard プロファイリングをご覧ください。