メモリ不足(OOM)エラーは、アクセラレータ(GPU または TPU)の高帯域幅メモリ(HBM)容量が不足すると発生します。OOM の問題の一般的な原因とデバッグ手法については、E1000 - コンパイル時の HBM OOM のドキュメントと GPU メモリ割り当てに関する JAX のドキュメントで詳しく説明しています。
このページでは、XProf のメモリビューア ツールを使用して、JAX プログラムのメモリ使用量を可視化し、ピーク時の使用量インスタンスを特定して、OOM エラーをデバッグする方法について説明します。これには、次の手順が含まれます。
jax.profiler.traceを使用してプログラムを実行し、プロファイルをキャプチャします。- XProf をバックグラウンドで起動し、メモリ ビューア ツールを使用してメモリ使用率の詳細を表示します。
サンプル プログラム
次の JAX プログラムは OOM エラーになります。
import jax
from jax import random
import jax.numpy as jnp
@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
return a @ a
if __name__ == "__main__":
oom()
TPU マシンでは、このプログラムは次のエラーで失敗します。
XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>
(GPU マシンでは、エラーは XlaRuntimeError: RESOURCE_EXHAUSTED:
Out of memory while trying to allocate 214748364800 bytes. のようになります)。
XProf を実行する
xprof(pip install xprof)をインストールし、プロファイルが保存されるディレクトリを指定して XProf インスタンスを起動します。
xprof --logdir=/tmp/xprof/ --port=6006
インスタンス(ローカル マシン上の http://localhost:6006)に移動します。[ツール] プルダウンで [メモリ ビューア] を選択し、メモリ ビューア ツール ウィンドウの [メモリタイプ] プルダウンで [HBM] を選択します(通常はデフォルトで選択されています)。

XProf: メモリビューア ツールのドキュメントでは、ツールのコンポーネントと表示される情報について説明しています。
ピーク時のメモリ使用量の時点での 3 つのバッファグラフを示す HLO Ops at Peak Memory Allocation セクションに注目します。バッファには、次のものが含まれます。* プログラムの入力と出力: トレーニング バッチ、オプティマイザーの状態など。* TensorCore と SparseCore の一時変数: 中間計算(アクティベーション、グラデーションなど)に必要な動的メモリ
バッファグラフにカーソルを合わせると、Op のサイズ、形状、割り当てタイプなどの詳細を確認できます。これにより、ピークメモリに影響を与えており、調整または最適化が必要な、一時的なメモリ使用量が多い、または一時的なメモリ使用量が長期間にわたる Ops、非効率的なパディングを含む大きな入力/中間/出力テンソルなどを特定して評価できます。
具体的なデバッグ手法については、E1000: デバッグをご覧ください。