HLO 計算をダンプする

HLO ダンプは、計算のさまざまな段階における HLO モジュールのテキスト表現です。デバッグに役立ち、バグレポートに含める必要があることがよくあります。通常、これは HLO 命令とそのプロパティを一覧表示する、人間が読めるテキスト ファイルです。HLO モジュールは次のようにダンプされることがあります。

  • HloProto: より構造化された機械可読形式のプロトコル バッファ ファイル。
  • HloSnapshot: HLO モジュールとその入力。HLO を再生する場合、ランダムなデータではなく、特定の計算に供給される実際の入力が必要になることがあります。

XLA フラグを使用して、ダンプを指定して取得できます。ほとんどの場合、環境変数で設定できます。JAX には、HLO ダンプをプログラムで出力する方法もあります。

ローカルでの実行

環境変数の使用

必要なフラグを使用して XLA_FLAGS 環境変数を設定し、ダンプを取得できます。これは、JAX、TensorFlow、PyTorch/XLA で機能します。

HLO モジュールやその他のデバッグ情報を特定のディレクトリにダンプするには、--xla_dump_to フラグを指定してプログラムを実行します。

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH"

たとえば、パスとして /tmp または /tmp/xladump を使用できます。

デフォルトでは、最適化パイプラインの最初と最後に HLO モジュールがテキストとしてダンプされます。

形式を明示的に指定することもできます。

  1. テキストダンプ
XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=DIRECTORY_PATH"
  1. HLO protos
XLA_FLAGS="--xla_dump_hlo_as_proto --xla_dump_to=DIRECTORY_PATH"
  1. HLO スナップショット
XLA_FLAGS="--xla_dump_hlo_snapshots --xla_dump_to=DIRECTORY_PATH"
  1. graphviz サーバーを使用したグラフのレンダリング(小さなグラフでのみ適切に機能します)
XLA_FLAGS="--xla_dump_hlo_as_url --xla_dump_to=DIRECTORY_PATH"
  1. グラフを HTML ファイルにレンダリングする(小規模なグラフでのみ適切に動作する)
XLA_FLAGS="--xla_dump_hlo_as_html --xla_dump_to=DIRECTORY_PATH"

大きなグラフの場合は、interactive_graphviz を使用してグラフの一部を可視化できます。

Dump Specific Intermediate Passes

標準の事前最適化 / 最終最適化された HLO に加えて、特定のコンパイラパス後の HLO の状態をダンプすることもできます。

XLA_FLAGS="--xla_dump_hlo_pass_re=regex --xla_dump_to=DIRECTORY_PATH"

名前が正規表現(regex)に一致するパスの HLO モジュールがダンプされます。たとえば、SPMD パーティショニングに関連するパスの結果として得られる HLO を確認するには、次のコマンドを使用します。

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=spmd|propagation"

XLA パスごとに結果をダンプするには(多くのファイルが生成されます)、次のように設定します。

XLA_FLAGS="--xla_dump_to=DIRECTORY_PATH --xla_dump_hlo_pass_re=.*"

JAX 固有のオプション

JAX でプログラム的に

フラグや環境変数を渡す代わりに、JAX の lower API と compile API を使用して HLO をプログラムでダンプすることもできます。

最適化されていない元の低レベル HLO をローカルで取得します。

jax.jit(f).lower(*args).as_text('hlo')

HLO コンパイル パス中にファイルにダンプするには、次のように指定します。

compilation_args = {
    'xla_dump_to': DIRECTORY_PATH,
    'xla_dump_hlo_pass_re': 'spmd|propagation', # or some other pass filter
    ...
    }

jax.jit(f).lower(*args).compile(compilation_args)

jaxpr をダンプする

jaxprs は、プログラム トレースの JAX の中間表現です。これをダンプするには、次の環境変数を設定します。

JAX_DUMP_IR_TO="DIRECTORY_PATH" JAX_DUMP_IR_MODES=jaxpr

詳しくは、JAX ドキュメントのステージアウトされた計算のエクスポートとシリアル化: デバッグをご覧ください。

Google Colab

環境変数

ノートブックの最初に実行されるセルに(環境変数とコマンドライン フラグは通常、モジュールのインポート時や XLA バックエンドの初期化時など、一度だけ処理されるため)、上記で説明した XLA_FLAGSos.environ とともに追加します。例:

import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=DIRECTORY_PATH"

これにより、計算が DIRECTORY_PATH(例: /tmp)にダンプされます。Colab で、左側のサイドバーにある [ファイル] ブラウザに移動して、このディレクトリを表示してアクセスします。

ローカル実行セクションで説明したすべてのフラグを使用できます。

JAX 固有のオプション

ローカル実行と同様に、ライブのインタラクティブなイントロスペクションでは、計算の事前最適化された HLO を直接出力できます。

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

c = jax.jit(f).lower(3.).compiler_ir('hlo')

print(c.as_hlo_text())

計算の最適化された HLO を直接出力することもできます。

def optimized_HLO(f, *args, platform=None):
    print(jax.jit(f).lower(*args).compile().as_text())

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

optimized_HLO(f, 1.0)

すべての計算/小規模な計算のダンプ

すべての小さなコンパイルを含むダンプ内のすべてを表示する場合は、JAX 環境変数を設定します。

JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0

モザイク

Mosaic は、Pallas TPU バックエンドと試験運用版の Pallas GPU バックエンド用のコンパイラです。モザイク計算をダンプするには、次のフラグを設定します。

--xla_mosaic_dump_to=/tmp/mosaic_dumps

または、TPU の初期化引数を環境変数として設定します。

export LIBTPU_INIT_ARGS="--xla_mosaic_dump_to=/tmp/mosaic_dumps"

詳細については、Pallas と Mosaic に関する JAX のドキュメントをご覧ください。

HLO ダンプの詳細

適切な計算を見つける

通常、多くの計算がダンプされます。ダンプされたファイルには、ログに記録された JAX、Tensorflow、PyTorch/XLA の「計算名」が明示的に付けられるため、関連する HLO ファイルを簡単に特定できます。次に例を示します。

1624325116260738.module_0065.pmap__unnamed_wrapped_function_.186875.before_optimizations.txt

それ以外の場合は、ripgrep を使用して、特定のシンボルまたは計算を保持するモジュールをすばやく特定できます。

ヒント: バグレポートには、関心のある 3 つのダンプされた before/after/buffer-assignment ファイルを含めてください。

HLO 変換

HLOProto とテキスト形式の間で変換できる hlo-opt というツール。これは、ある形式があるものの、デバッグ用に別の形式が必要な場合に便利です。

使用方法: XLA ツール ドキュメント: hlo-opt

もう一度再生

ダンプされた計算を、指定された XLA バックエンドで、フェイクデータまたは入力スナップショットを使用して実行(再生)できます。これは、XLA で問題を再現、反復、デバッグするのに便利な方法です。

次のコマンドでは、フェイクデータを使用します。HLO スナップショットを保存している場合は、代わりにそれらを渡すことができ、スナップショットのデータが使用されます。スナップショットの実行中に引き続きフェイクデータを使用するには、--force_fake_data フラグを渡します。

CPU バックエンド:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=cpu
 /tmp/xladump/module_4561.before_optimizations.txt

GPU バックエンド:

bazel run -c opt //xla/hlo/tools:run_hlo_module -- --platform=CUDA
 /tmp/xladump/module_4561.before_optimizations.txt