XProf は、GPU と TPU のアクティビティなど、プログラムのパフォーマンス トレースとプロファイルを取得して可視化するのに最適な方法です。最終的な結果は次のようになります。

プログラムによるキャプチャ
jax.profiler.start_trace メソッドと jax.profiler.stop_trace メソッドを使用して、JAX コードのプロファイラ トレースをキャプチャするようにコードをインストルメント化できます。トレース ファイルを書き込むディレクトリを指定して jax.profiler.start_trace を呼び出します。これは、XProf の起動に使用したのと同じ --logdir ディレクトリである必要があります。次に、XProf を使用してトレースを表示できます。
たとえば、プロファイラ トレースを取得するには、次のようにします。
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
jax.block_until_ready 呼び出しに注意してください。これは、オンデバイス実行がトレースでキャプチャされるようにするために使用されます。これがなぜ必要なのかについては、非同期ディスパッチをご覧ください。
start_trace と stop_trace の代わりに jax.profiler.trace コンテキスト マネージャーを使用することもできます。
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
トレースの表示
トレースをキャプチャしたら、XProf UI を使用して表示できます。
ログ ディレクトリを指定して、スタンドアロンの XProf コマンドを使用すると、プロファイラ UI を直接起動できます。
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
指定された URL(http://localhost:8791/)をクリックして、プロフィールを表示します。
利用可能なトレースが左側の [セッション] プルダウン メニューに表示されます。目的のセッションを選択し、[ツール] プルダウンで [トレース ビューア] を選択します。実行のタイムラインが表示されます。WASD キーを使用してトレースを移動し、クリックまたはドラッグしてイベントを選択すると、詳細が表示されます。トレースビューアの使用方法の詳細については、Trace Viewer ツールのドキュメントをご覧ください。
XProf による手動キャプチャ
実行中のプログラムから手動でトリガーされた N 秒間のトレースをキャプチャする手順は次のとおりです。
XProf サーバーを起動します。
xprof --logdir /tmp/profile-data/<http://localhost:8791/>で XProf を読み込めるはずです。--portフラグを使用して別のポートを指定できます。プロファイリングする Python プログラムまたはプロセスで、次のコードを先頭付近に追加します。
import jax.profiler jax.profiler.start_server(9999)これにより、XProf が接続するプロファイラ サーバーが起動します。次のステップに進む前に、プロファイラ サーバーが実行されている必要があります。サーバーの使用が完了したら、
jax.profiler.stop_server()を呼び出してシャットダウンできます。長時間実行されるプログラム(長いトレーニング ループなど)のスニペットをプロファイリングする場合は、プログラムの先頭にこのコードを配置して、通常どおりにプログラムを開始します。短いプログラム(マイクロベンチマークなど)をプロファイリングする場合は、IPython シェルでプロファイラ サーバーを起動し、次のステップでキャプチャを開始してから
%runで短いプログラムを実行する方法があります。別の方法として、プログラムの開始時にプロファイラ サーバーを起動し、time.sleep()を使用してキャプチャを開始するのに十分な時間を確保することもできます。<http://localhost:8791/>を開き、左上の [CAPTURE PROFILE] ボタンをクリックします。プロファイル サービス URL として「localhost:9999」と入力します(これは、前の手順で起動したプロファイラ サーバーのアドレスです)。プロファイリングするミリ秒数を入力し、[キャプチャ] をクリックします。プロファイリングするコードがまだ実行されていない場合(Python シェルでプロファイラ サーバーを起動した場合など)、キャプチャの実行中にコードを実行します。
キャプチャが終了すると、XProf が自動的に更新されます。(XProf のプロファイリング機能の一部は JAX にフックされていないため、最初は何もキャプチャされていないように見えることがあります)。左側の [ツール] で、[トレース ビューア] を選択します。
実行のタイムラインが表示されます。WASD キーを使用してトレースを移動し、クリックまたはドラッグしてイベントを選択すると、下部に詳細が表示されます。トレースビューアの使用方法の詳細については、トレースビューア ツールのドキュメントをご覧ください。
XProf と TensorBoard
XProf は、Tensorboard のプロファイリングとトレース キャプチャ機能を支える基盤となるツールです。xprof がインストールされている限り、TensorBoard に [Profile] タブが表示されます。同じログ ディレクトリを指すように起動する限り、これを使用することは XProf を個別に起動することと同じです。これには、プロファイルのキャプチャ、分析、表示機能が含まれます。XProf は、以前に推奨されていた tensorboard_plugin_profile 機能に代わるものです。
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
カスタム トレース イベントを追加する
デフォルトでは、トレース ビューアのイベントは主に低レベルの内部 JAX 関数です。コードで jax.profiler.TraceAnnotation と jax.profiler.annotate_function を使用して、独自のイベントと関数を追加できます。
プロファイラ オプションの構成
start_trace メソッドは、プロファイラの動作をきめ細かく制御できるオプションの profiler_options パラメータを受け取ります。このパラメータは jax.profiler.ProfileOptions のインスタンスである必要があります。
たとえば、すべての Python トレースとホスト トレースを無効にするには、次のようにします。
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
全般オプション
host_tracer_level: ホストサイド アクティビティのトレースレベルを設定します。サポートされる値:
0: ホスト(CPU)のトレースを完全に無効にします。1: ユーザーが計測した TraceMe イベントのみのトレースを有効にします。2: レベル 1 のトレースに加えて、高コストの XLA 演算などの高レベルのプログラム実行の詳細を含めます(デフォルト)。3: レベル 2 のトレースに加えて、安価な XLA オペレーションなど、より詳細な低レベルのプログラム実行の詳細が含まれます。
device_tracer_level: デバイス トレースが有効かどうかを制御します。サポートされる値:
0: デバイス トレースを無効にします。1: デバイス トレースを有効にします(デフォルト)。
python_tracer_level: Python トレースが有効かどうかを制御します。サポートされる値:
0: Python 関数呼び出しのトレースを無効にします(デフォルト)。1: Python トレースを有効にします。
高度な構成のオプション
TPU のオプション
tpu_trace_mode: TPU トレースのモードを指定します。サポートされる値:
TRACE_ONLY_HOST: ホスト側(CPU)のアクティビティのみがトレースされ、デバイス(TPU/GPU)のトレースは収集されません。TRACE_ONLY_XLA: デバイス上の XLA レベルのオペレーションのみがトレースされることを意味します。TRACE_COMPUTE: デバイス上のコンピューティング オペレーションをトレースします。TRACE_COMPUTE_AND_SYNC: デバイス上のコンピューティング オペレーションと同期イベントの両方をトレースします。
「tpu_trace_mode」が指定されていない場合、trace_mode のデフォルトは
TRACE_ONLY_XLAです。tpu_num_sparse_cores_to_trace: TPU でトレースするスパース コアの数を指定します。tpu_num_sparse_core_tiles_to_trace: TPU でトレースする各スパースコア内のタイルの数を指定します。tpu_num_chips_to_profile_per_task: タスクごとにプロファイリングする TPU チップの数を指定します。
GPU オプション
GPU プロファイリングでは、次のオプションを使用できます。
gpu_max_callback_api_events: CUPTI コールバック API によって収集されるイベントの最大数を設定します。デフォルトは2*1024*1024です。gpu_max_activity_api_events: CUPTI アクティビティ API によって収集されるイベントの最大数を設定します。デフォルトは2*1024*1024です。gpu_max_annotation_strings: 収集できるアノテーション文字列の最大数を設定します。デフォルトは1024*1024です。gpu_enable_nvtx_tracking: CUPTI で NVTX トラッキングを有効にします。デフォルトはFalseです。gpu_enable_cupti_activity_graph_trace: CUDA グラフの CUPTI アクティビティ グラフのトレースを有効にします。デフォルトはFalseです。gpu_pm_sample_counters: CUPTI の PM サンプリング機能を使用して収集する GPU パフォーマンス モニタリング指標のカンマ区切り文字列(例:"sm__cycles_active.avg.pct_of_peak_sustained_elapsed")。PM サンプリングはデフォルトで無効になっています。使用可能な指標については、NVIDIA の CUPTI ドキュメントをご覧ください。gpu_pm_sample_interval_us: CUPTI PM サンプリングのサンプリング間隔をマイクロ秒単位で設定します。デフォルトは500です。gpu_pm_sample_buffer_size_per_gpu_mb: CUPTI PM サンプリング用に、デバイスごとのシステム メモリ バッファサイズを MB 単位で設定します。デフォルトは 64 MB です。サポートされている最大値は 4 GB です。gpu_num_chips_to_profile_per_task: タスクごとにプロファイリングする GPU デバイスの数を指定します。指定しない場合、0 に設定した場合、または無効な値を設定した場合、使用可能なすべての GPU がプロファイリングされます。これは、トレース収集のサイズを小さくするために使用できます。gpu_dump_graph_node_mapping: 有効にすると、CUDA グラフノード マッピング情報がトレースにダンプされます。デフォルトはFalseです。
次に例を示します。
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
認識されないキーまたはオプション値が見つかった場合は、InvalidArgumentError を返します。