使用 XProf 对 JAX 计算进行性能剖析

XProf 是一种出色的工具,可用于获取和直观呈现程序的性能轨迹和配置文件,包括 GPU 和 TPU 上的活动。最终结果如下所示:

XProf 示例

程序化捕获

您可以使用 jax.profiler.start_tracejax.profiler.stop_trace 方法对代码进行插桩,以捕获 JAX 代码的分析器轨迹。使用要写入轨迹文件的目录调用 jax.profiler.start_trace。该目录应与用于启动 XProf 的 --logdir 目录相同。然后,您可以使用 XProf 查看轨迹。

例如,如需获取性能剖析器轨迹,请执行以下操作:

import jax

jax.profiler.start_trace("/tmp/profile-data")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

请注意 jax.block_until_ready 调用。我们使用此功能来确保轨迹捕获设备上的执行情况。如需详细了解为何需要这样做,请参阅异步调度

您还可以使用 jax.profiler.trace 上下文管理器来替代 start_tracestop_trace

import jax

with jax.profiler.trace("/tmp/profile-data"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready()

查看轨迹

捕获轨迹后,您可以使用 XProf 界面查看轨迹。

您可以使用独立的 XProf 命令直接启动性能剖析器界面,只需将该命令指向您的日志目录即可:

$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
  Log Directory: /tmp/profile-data
  Port: 8791
  Worker Service Address: 0.0.0.0:50051
  Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)

前往提供的网址(例如 http://localhost:8791/)在浏览器中查看个人资料。

左侧的“会话”下拉菜单中会显示可用的轨迹。选择您感兴趣的会话,然后在“工具”下拉菜单下选择“轨迹查看器”。您现在应该会看到执行时间轴。您可以使用 WASD 键浏览轨迹,点击或拖动可选择事件以查看更多详情。如需详细了解如何使用轨迹查看器,请参阅轨迹查看器工具文档

通过 XProf 手动捕获

以下说明介绍了如何从正在运行的程序中捕获手动触发的 N 秒轨迹。

  1. 启动 XProf 服务器:

    xprof --logdir /tmp/profile-data/
    

    您应该能够在 <http://localhost:8791/> 加载 XProf。您可以使用 --port 标志指定其他端口。

  2. 在要进行性能剖析的 Python 程序或进程中,在开头附近添加以下内容:

    import jax.profiler
    jax.profiler.start_server(9999)
    

    此命令会启动 XProf 连接到的分析器服务器。分析器服务器必须处于运行状态,然后才能继续执行下一步。使用完服务器后,您可以调用 jax.profiler.stop_server() 将其关闭。

    如果您想分析长时间运行的程序(例如长时间的训练循环)的一小段,可以将此代码放在程序开头,然后像往常一样启动程序。如果您想分析一个简短的程序(例如微基准),一种方法是在 IPython shell 中启动性能分析器服务器,并在下一步中开始捕获后使用 %run 运行该简短的程序。另一种方法是在程序开始时启动性能分析器服务器,并使用 time.sleep() 为您提供足够的时间来开始捕获。

  3. 打开 <http://localhost:8791/>,然后点击左上角的“捕获性能剖析文件”按钮。输入“localhost:9999”作为配置文件服务网址(这是您在上一步中启动的性能分析器服务器的地址)。输入要进行分析的毫秒数,然后点击“捕获”。

  4. 如果您要分析的代码尚未运行(例如,如果您在 Python shell 中启动了性能分析器服务器),请在捕获运行时运行该代码。

  5. 捕获完成后,XProf 应会自动刷新。(并非所有 XProf 性能剖析功能都与 JAX 相关联,因此最初看起来可能好像没有捕获任何内容。)在左侧的“工具”下方,选择“Trace Viewer”。

您现在应该会看到执行时间轴。您可以使用 WASD 键浏览轨迹,点击或拖动选择事件,以便在底部查看更多详细信息。如需详细了解如何使用轨迹查看器,请参阅轨迹查看器工具文档

XProf 和 TensorBoard

XProf 是 TensorBoard 中用于支持性能分析和轨迹捕获功能的底层工具。只要安装了 xprof,TensorBoard 中就会显示“Profile”标签页。使用此功能与单独启动 XProf 相同,只要启动时指向同一日志目录即可。这包括个人资料捕获、分析和查看功能。XProf 取代了之前推荐的 tensorboard_plugin_profile 功能。

$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)

添加自定义轨迹事件

默认情况下,Trace Viewer 中的事件大多是低级内部 JAX 函数。您可以在代码中使用 jax.profiler.TraceAnnotationjax.profiler.annotate_function 添加自己的事件和函数。

配置性能剖析器选项

start_trace 方法接受一个可选的 profiler_options 参数,该参数可用于对分析器的行为进行精细控制。此参数应该是 jax.profiler.ProfileOptions 的实例。

例如,如需停用所有 Python 和主机轨迹,请运行以下命令:

import jax

options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

常规选项

  1. host_tracer_level:设置主机端活动的轨迹级别。

    支持的值:

    • 0:完全停用宿主 (CPU) 跟踪。
    • 1:仅启用用户插桩的 TraceMe 事件的跟踪。
    • 2:包含 1 级轨迹以及高级程序执行详情,例如耗时的 XLA 操作(默认)。
    • 3:包含 2 级轨迹以及更详细的低级程序执行详情,例如低成本 XLA 操作。
  2. device_tracer_level:控制是否启用设备跟踪。

    支持的值:

    • 0:停用设备跟踪。
    • 1:启用设备跟踪(默认)。
  3. python_tracer_level:控制是否启用 Python 跟踪。

    支持的值:

    • 0:停用 Python 函数调用跟踪(默认)。
    • 1:启用 Python 跟踪。

高级配置选项

TPU 选项

  1. tpu_trace_mode:指定 TPU 轨迹的模式。

    支持的值:

    • TRACE_ONLY_HOST:这意味着仅跟踪主机端 (CPU) 活动,而不收集设备 (TPU/GPU) 跟踪记录。
    • TRACE_ONLY_XLA:这意味着仅跟踪设备上的 XLA 级操作。
    • TRACE_COMPUTE:此轨迹用于跟踪设备上的计算操作。
    • TRACE_COMPUTE_AND_SYNC:此轨迹会跟踪设备上的计算操作和同步事件。

    如果未提供“tpu_trace_mode”,则 trace_mode 默认为 TRACE_ONLY_XLA

  2. tpu_num_sparse_cores_to_trace:指定要在 TPU 上跟踪的稀疏核心数量。

  3. tpu_num_sparse_core_tiles_to_trace:指定要在 TPU 上跟踪的每个稀疏核心内的 tile 数量。

  4. tpu_num_chips_to_profile_per_task:指定要为每个任务分析的 TPU 芯片数量。

GPU 选项

以下选项可用于 GPU 分析:

  • gpu_max_callback_api_events:设置 CUPTI 回调 API 收集的事件数上限。默认设置为 2*1024*1024
  • gpu_max_activity_api_events:设置 CUPTI 活动 API 收集的事件数上限。默认设置为 2*1024*1024
  • gpu_max_annotation_strings:设置可收集的注释字符串数上限。默认设置为 1024*1024
  • gpu_enable_nvtx_tracking:在 CUPTI 中启用 NVTX 跟踪。默认值为 False
  • gpu_enable_cupti_activity_graph_trace:针对 CUDA 图启用 CUPTI activity 图跟踪。默认设置为 False
  • gpu_pm_sample_counters:一个以英文逗号分隔的字符串,包含要使用 CUPTI 的 PM 抽样功能收集的 GPU 性能监控指标(例如 "sm__cycles_active.avg.pct_of_peak_sustained_elapsed")。默认情况下,PM 抽样处于停用状态。如需了解可用指标,请参阅 NVIDIA 的 CUPTI 文档
  • gpu_pm_sample_interval_us:设置 CUPTI PM 采样的采样间隔(以微秒为单位)。默认设置为 500
  • gpu_pm_sample_buffer_size_per_gpu_mb:为 CUPTI PM 采样设置每个设备的系统内存缓冲区大小(以 MB 为单位)。默认值为 64MB。支持的最大值为 4GB。
  • gpu_num_chips_to_profile_per_task:指定要为每个任务分析的 GPU 设备数量。如果未指定、设置为 0 或设置为无效值,则系统将对所有可用的 GPU 进行分析。这可用于减小轨迹收集大小。
  • gpu_dump_graph_node_mapping:如果启用,则将 CUDA 图节点映射信息转储到轨迹中。默认设置为 False

例如:

options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}

如果发现任何无法识别的键或选项值,则返回 InvalidArgumentError