カテゴリ: コンパイル時間: HBM OOM
このエラーは、プログラムに必要な高帯域幅メモリ(HBM)が、TPU デバイスで物理的に使用可能な量を超えていることを示します。
エラー メッセージの例:
RESOURCE_EXHAUSTED: TPU TensorCore Hbm usage: 34.82G, SparseCore Hbm usage 174.10G, exceeding available bytes: 95.74G
RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 49.34G of 32.00G hbm. Exceeded hbm capacity by 17.34G.
XLA バックエンド: TPU
概要
XLA は、必要な静的割り当ての合計サイズがデバイスの HBM に収まることを確認するチェックを実行します。
コンパイラは、次の複数の割り当てタイプに対して TPU の固定 HBM 容量を管理します。
- プログラムの入力と出力: トレーニング バッチ、オプティマイザーの状態など。
- TPU 一時変数: 中間計算(アクティベーション、グラデーションなど)に必要な動的メモリ。
- コンパイル済みバイナリ: TensorCore(TC)と SparseCore(SC)の両方のマシンコード。
- システム オーバーヘッド: XLA ランタイム用に予約されたスペース(以前の TPU 世代のインフィード バッファなど)。
- 定数: HLO IR に埋め込まれた定数値は HBM に割り当てられます。
- コンパイラの内部: プログラム レベルと HLO ごとの割り当て(メッシュ内のノードのルーティング情報など)。
このエラーは、XLA コンパイラが上記の割り当てをすべてデバイスの HBM に収めることができない場合に発生します。
デバッグ
エラー メッセージとログを慎重に分析して、次の HBM OOM のカテゴリのうち、エラーに最も適したものを特定します。
- 「TC Hbm usage: X, SC Hbm usage Y」: エラーで HBM 使用率が明示的に内訳表示されている場合、TensorCore(TC)と SparseCore(SC)の合計使用率が HBM の上限を超えています。→ シナリオ 1TC と SC の HBM 使用量のバランスを取る。
- 「Ran out of memory in memory space HBM」: ログで、HBM の最大割り当ての列挙を確認します。
- 予期せず大きなテンソル(HBM 上限の 50% 超など)が 1 つ以上存在する場合 → シナリオ 2 に進みます。予期しない大規模な割り当てによるメモリ不足。
- ログに予期しない大きなテンソルがない場合 → シナリオ 3 に進みます。集計割り当てによるメモリ不足。
シナリオ 1. TC と SC の HBM 使用量のバランスを取る
エラーで使用量が明示的に分類されている場合(例: 「TC Hbm usage: X, SC Hbm usage Y」)、これは TensorCore(TC)と SparseCore(SC)の合計使用量が HBM 上限を超えていることを意味します。2 つの値を比較して、ボトルネックを特定します。
- SparseCore の使用率が高い
- HBM スタックの使用量を最適化する: HBM スタックのメモリ使用量は、
feature_width、max_unique_nz_per_row、logical_replica_countに応じてスケーリングされます。テーブルの処理をシリアル化する--xla_sc_num_serialized_tables_to_optimize_hbmフラグを調整することで、スタック使用率のピークを減らすことができます。ただし、並列処理の削減という代償を伴います。 - パディングのオーバーヘッドを確認する: SparseCore は、エンベディング テーブルを 32B(8 個の浮動小数点数)に調整します。特徴の幅が小さい(8 個未満の浮動小数点数など)テーブルでは、パディングのオーバーヘッドが大きくなり、HBM が無駄になります。
- ヒープ使用量を減らす:
maximum_parallel_iterationsの値が大きいほど、HBM ヒープにプリフェッチされる入力データの量が増えます。この値を小さくすると、かなりのメモリを解放できます。 - シャーディングを確認する: エンベディング テーブルがすべてのチップで正しく mod-sharded されていることを確認します。上限がテーブルにどのように変換されるかをご覧ください。
- 他のアイデアについては、SC: パフォーマンスとメモリのボトルネックをご覧ください。
- HBM スタックの使用量を最適化する: HBM スタックのメモリ使用量は、
- TensorCore の使用率が高い
- シナリオ 2 に進みます。
- バランス重視
- どちらも個別に過剰ではないが、合計が大きすぎる場合は、チップの容量に達しています。両方のコンポーネントの使用量を減らす必要があります。3 つのセクションの推奨事項をすべて実施します。
シナリオ 2. 予期しない大きな割り当てによるメモリ不足
「Ran out of memory in memory space HBM」というエラー メッセージが表示され、ログに予期しない大きな割り当てが 1 つ以上存在する場合(HBM 制限の 50% 超)、ハードウェア容量の問題であることはほとんどありません。通常は構成エラーです。大規模な割り当ての XLA ラベル(存在する場合)で、JAX ソースコードのヒントを確認します。
- デバッグ アーティファクトを削除
- 大規模な実行で jax.debug.print() を使用すると、コンパイラが HBM でテンソル全体を具体化して CPU に転送し、フュージョンを中断してピーク時のメモリ使用量を増やす可能性があります。残っている
jax.debug.print()を削除します。
- 大規模な実行で jax.debug.print() を使用すると、コンパイラが HBM でテンソル全体を具体化して CPU に転送し、フュージョンを中断してピーク時のメモリ使用量を増やす可能性があります。残っている
- 非効率的なメッシュ形状またはシャーディングを修正する
- メッシュ形状が正しくない場合や、シャーディング アノテーションがない場合、コンパイラはデフォルトでレプリケーションを使用するため、コンパイラは非常に大きなテンソルを単一のチップに適合させようとします。
- 大きな割り当ての形状を確認し、シャーディングが XLA によって正しく指定され、伝播されていることを確認します。
シナリオ 3. 集計割り当てによるメモリ不足
「Ran out of memory in memory space HBM」というエラー メッセージが表示され、ログに予期しない大きなテンソルがない場合、割り当ての合計が HBM の上限を超えているため、プログラムの容量が不足しています。このような場合は、メモリ プロファイルを可視化して、ピーク使用量に寄与している特定のバッファを特定すると効果的です。ピーク時のメモリ使用量の原因を特定する手順については、XProf で OOM エラーをデバッグするをご覧ください。
上位の貢献者を特定したら、次の手順でメモリ使用量を最適化します。
シナリオ 3.A 構成を調整する
多くの場合、次の構成調整で OOM を解決できます。
- バッチサイズを減らす: 中間アクティベーションとグラデーションに必要なメモリは、バッチサイズに正比例します。バッチサイズを減らすと、メモリ使用量を削減できることがよくあります。
- 入力バッファを寄付する:
jax.jitを使用する場合は、モデル パラメータの donate_argnums を指定します。これにより、XLA は入力メモリを出力で上書きできます。 - 混合精度(bfloat16)を有効にする: モデル アーキテクチャと品質要件で許容される場合は、プログラム内の最大のテンソルに bfloat16 または量子化(int8 など)を使用します。この変更はモデルの動作に影響する可能性があるため、慎重に検討する必要があります。
シナリオ 3.B アーキテクチャとシャーディングを最適化する
構成の変更が不十分な場合、モデル トポロジが現在のハードウェア設定に対して大きすぎる可能性があります。
- 新しい TPU 世代を使用する: 一般的に、新しい TPU ほどチップあたりの HBM が多くなります。利用可能な場合は、新しい TPU 世代に切り替えます。
- より大きなチップ トポロジで実行する: モデルの重みが既存のトポロジに対して大きすぎる場合は、より多くのチップにシャーディングしてみてください。
- 高度なシャーディング手法を実装する:
- より高度なデータ、テンソル、パイプラインの並列処理アプローチを検討します。
- 中間値と出力のシャーディング ヒントを指定します。
- JAX ホスト オフロードを使用する: ホスト オフロード手法を使用すると、ユーザーは大きなテンソルをホスト CPU メモリにオフロードできます(アクティベーション オフロードやオプティマイザー状態オフロードなど)。
シナリオ 3.C テンソル パディングとアライメントを確認する
非効率的なテンソル形状は、TPU で OOM が発生する一般的な原因ですが、気づきにくいものです。TPU でピーク パフォーマンスを得るために、XLA はテンソル ディメンションをパディングします。通常、最下位のディメンションは 128 の倍数、2 番目に下位のディメンションは 8 の倍数にパディングされます。このパディングは、入力配列と中間テンソル(HLO テンポラリ)の両方に影響し、特にディメンション サイズが小さい場合に、メモリ使用量が大幅に増加する可能性があります。配列レイアウトをご覧ください。
- 大きなバッファの形状を監査:(デフォルト レイアウトの TPU v5 の場合)
- Xprof Memory Viewer でバッファにカーソルを合わせると、パディング情報などのバッファの詳細を含むバッファの詳細カードが表示されます。
- 例:
(129, 1024)の形状が(256, 1024)にパディングされると、メモリの約 50% が無駄になります。 - 訂正:
(128, 1024)の形状にはパディングは不要で、メモリの無駄は 0% になります。
- ディメンションを調整する: すべての大きなテンソル ディメンション(バッチサイズ、エンベディング ディメンション、隠れ層のサイズ)が 128 の倍数であることを確認します。この変更はモデルの動作に影響する可能性があるため、慎重に検討する必要があります。
シナリオ 3.D XLA フラグに影響するキーメモリを調整する
主要なメモリフラグを調整して、パフォーマンスとメモリ使用量のバランスを取ることができます。ただし、この戦略はパフォーマンスに悪影響を及ぼす可能性があるため、最後の手段として使用する必要があります。
シナリオ 3.E XLA 再実体化パス/手動チェックポイント設定を調整する
モデルがメモリに収まる直前の場合、jax.grad で jax.checkpoint デコレータを使用すると、フォワード パスで保存される中間値とバックワード パスで再計算される中間値を手動で制御し、コンピューティング サイクルと HBM を交換できます。
または、XLA::Rematerialization パスを強制的に実行して、メモリの節約を優先することもできます。この場合、コンパイルが遅くなる可能性があります。
| フラグ | 説明 | 影響 / トレードオフ |
|---|---|---|
--xla_tpu_max_hbm_size_mib |
再マテリアライズ パスで使用される HBM サイズの上限を手動で設定します。 | コンパイラに、実際の物理 HBM よりも小さい制限内にプログラムを収めるよう強制します。 |
--xla_tpu_rematerialization_algo=PEAK_PRIORITY |
ピーク時のメモリ使用量に焦点を当てて作業を進めます。 | デフォルトのアルゴリズムよりも、メモリを積極的に削減する場合に効率的です。 |
--xla_tpu_rematerialization_max_block_size_limit=32 |
一度に再実体化できるブロック内の命令の最大数を制御します。 | この値を大きくすると、コンパイル時間が大幅に増加しますが、メモリを節約できます。 |
--xla_tpu_rematerialization_block_effort_factor=10.0 |
再実体化するブロックの検索に費やす労力(コンパイル時間)の量を定義します。 | 値を大きくすると、メモリ節約のための徹底的な検索が可能になりますが、コンパイル時間が長くなります。 |
--xla_tpu_pre_fusion_remat=true |
融合パスの前に追加の再実体化パスを有効にします。 | メモリ使用量をさらに削減できますが、コンパイル時間が長くなり、数値の安定性に影響する可能性があります。 |
XLA フラグの変更は、パフォーマンスに悪影響を及ぼす可能性があるため、最後の手段として使用してください。
シナリオ 3.F 高度なプロファイリング ツールを使用する
XProf を使用して OOM エラーをデバッグするでは、XProf メモリビューアを使用して、コンパイラの HBM 使用状況のビューを可視化するチュートリアルを提供しています。
このツールを使用すると、ピーク時のメモリ割り当てとバッファのライフタイムを確認できます。これは、ピーク時の使用率で HBM を消費しているものを正確に把握するために不可欠です。一般的なプロファイリングの設定については、Xprof を使ってみると TensorBoard プロファイリングをご覧ください。