カテゴリ: ランタイム: プログラムの割り当てに失敗しました
このエラーは、TPU デバイスの XLA ランタイムが、コンパイルされた XLA プログラムの実行可能ファイルを TPU の HBM に読み込めなかったことを示します。
エラー メッセージの例:
XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program 'jit_embedding_pipeline_step_fn': Attempting to reserve 29.49G at the bottom of memory. That was not possible. There are 147.64M free, 0B reserved, and 147.64M reservable. Scope: unknown..: while running replica 0 and partition 34 of a replicated computation (other replicas may have failed as well).
XLA バックエンド: TPU
概要
通常、このエラーは次のいずれかの原因で発生します。
- プログラム サイズが使用可能な HBM を超えている: コンパイルされた XLA プログラム(命令、静的データ、埋め込み定数を含む)が、プログラムが読み込まれている特定の TPU コアで現在使用可能な HBM の合計量よりも大きい。
- HBM の断片化: デバイス上の HBM の合計空き容量は十分でも、プログラム全体を収容できるほど大きな連続したブロックとして利用できない。
TPU ランタイムがメモリに優先順位を付ける方法を理解することが重要です。バッファ割り当ては、読み込まれたプログラムよりも優先されます。バッファの割り当てが失敗した場合、ランタイムはすでに読み込まれているプログラムを HBM から強制排除して、空き容量を確保します。これにより、以前は正常に読み込まれていたプログラムが、HBM がより多くのデータバッファで占有されるため、OOM エラーで失敗する可能性があります。
デバッグ
- バッファ メモリのフットプリントを削減する: データバッファで使用されるメモリを解放すると、プログラム自体により多くの空き容量が残ります。
- バッチサイズを減らす: これは、アクティベーションに使用されるメモリ量を減らす最も効果的な方法の 1 つです。
- パラメータ シャーディング: 非常に大規模なモデルの場合は、モデルの並列処理またはシャーディング手法(FSDP や Megascale など)を使用して、モデルのパラメータと計算を複数の TPU コアまたはホストに分散します。
- シーケンス/コンテキストの長さを短縮する: 順次データを処理するモデル(NLP モデルなど)では、シーケンス長を短縮すると、メモリ使用量を大幅に削減できます。
- バッファ寄付: フレームワーク機能(
jax.jit(..., donate_argnums=...))を実装して、XLA が入力バッファのメモリを再利用して出力を保存できるようにし、ピーク時のメモリ使用量を削減します。
- 一時的なプログラムのメモリ要件を削減します。
tpu_shared_memory_percentフラグを使用して、一時的なプログラムのメモリ使用量を減らします。なお、この設定はパフォーマンスに悪影響を及ぼす可能性があります。
- 実行戦略を最適化する/サービング負荷を軽減する:
- プログラムの読み込みを管理する: 複数の関数を JIT コンパイルする場合、各関数でプログラムが読み込まれる可能性があることに注意してください。同時に読み込まれるプログラムの数を最小限に抑えるようにワークロードを構成します。
- メモリリークが発生しないことを確認します。
jax.Arrayオブジェクトへの参照が意図したよりも長く保持されていないことを確認します。jax.Arrayオブジェクトを保持すると、プログラムのコンパイルが完了した後でも自動割り当て解除が妨げられる可能性があります。
ツール
- アロケータが OOM 発生時に現在のすべてのアロケーションの詳細なレポートをダンプする
tpu_log_allocations_on_oomフラグを有効にします。これはデバッグに非常に役立ちます。 - プログラムのプロファイリング: JAX メモリ プロファイラまたは TensorFlow プロファイラを使用して、プログラムのメモリ使用量の経時変化の詳細なビューを取得します。これにより、メモリ使用量の予期しないピークを特定できます。