XProf を使用して OOM エラーをデバッグする

メモリ不足(OOM)エラーは、アクセラレータ(GPU または TPU)の高帯域幅メモリ(HBM)容量が不足すると発生します。OOM の問題の一般的な原因とデバッグ手法については、E1000 - コンパイル時の HBM OOM のドキュメントGPU メモリ割り当てに関する JAX のドキュメントで詳しく説明しています。

このページでは、XProf のメモリビューア ツールを使用して、JAX プログラムのメモリ使用量を可視化し、ピーク時の使用量インスタンスを特定して、OOM エラーをデバッグする方法について説明します。これには、次の手順が含まれます。

  1. jax.profiler.trace を使用してプログラムを実行し、プロファイルをキャプチャします。
  2. XProf をバックグラウンドで起動し、メモリ ビューア ツールを使用してメモリ使用率の詳細を表示します。

サンプル プログラム

次の JAX プログラムは OOM エラーになります。

import jax
from jax import random
import jax.numpy as jnp


@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
    a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
    return a @ a


if __name__ == "__main__":
    oom()

TPU マシンでは、このプログラムは次のエラーで失敗します。

XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>

(GPU マシンでは、エラーは XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 214748364800 bytes. のようになります)。

XProf を実行する

xprofpip install xprof)をインストールし、プロファイルが保存されるディレクトリを指定して XProf インスタンスを起動します。

xprof --logdir=/tmp/xprof/ --port=6006

インスタンス(ローカル マシン上の http://localhost:6006)に移動します。[ツール] プルダウンで [メモリ ビューア] を選択し、メモリ ビューア ツール ウィンドウの [メモリタイプ] プルダウンで [HBM] を選択します(通常はデフォルトで選択されています)。

上記のプログラム例の XProf メモリビューア ページ

XProf: メモリビューア ツールのドキュメントでは、ツールのコンポーネントと表示される情報について説明しています。

ピーク時のメモリ使用量の時点での 3 つのバッファグラフを示す HLO Ops at Peak Memory Allocation セクションに注目します。バッファには、次のものが含まれます。* プログラムの入力と出力: トレーニング バッチ、オプティマイザーの状態など。* TensorCore と SparseCore の一時変数: 中間計算(アクティベーション、グラデーションなど)に必要な動的メモリ

バッファグラフにカーソルを合わせると、Op のサイズ、形状、割り当てタイプなどの詳細を確認できます。これにより、ピークメモリに影響を与えており、調整または最適化が必要な、一時的なメモリ使用量が多い、または一時的なメモリ使用量が長期間にわたる Ops、非効率的なパディングを含む大きな入力/中間/出力テンソルなどを特定して評価できます。

具体的なデバッグ手法については、E1000: デバッグをご覧ください。