XProf 是一种出色的工具,可用于获取和直观呈现程序的性能轨迹和配置文件,包括 GPU 和 TPU 上的活动。最终结果如下所示:

程序化捕获
您可以使用 jax.profiler.start_trace 和 jax.profiler.stop_trace 方法对代码进行插桩,以捕获 JAX 代码的分析器轨迹。使用要写入轨迹文件的目录调用 jax.profiler.start_trace。该目录应与用于启动 XProf 的 --logdir 目录相同。然后,您可以使用 XProf 查看轨迹。
例如,如需获取性能剖析器轨迹,请执行以下操作:
import jax
jax.profiler.start_trace("/tmp/profile-data")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
请注意 jax.block_until_ready 调用。我们使用此功能来确保轨迹捕获设备上的执行情况。如需详细了解为何需要这样做,请参阅异步调度。
您还可以使用 jax.profiler.trace 上下文管理器来替代 start_trace 和 stop_trace:
import jax
with jax.profiler.trace("/tmp/profile-data"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
查看轨迹
捕获轨迹后,您可以使用 XProf 界面查看轨迹。
您可以使用独立的 XProf 命令直接启动性能剖析器界面,只需将该命令指向您的日志目录即可:
$ xprof --port=8791 /tmp/profile-data
Attempting to start XProf server:
Log Directory: /tmp/profile-data
Port: 8791
Worker Service Address: 0.0.0.0:50051
Hide Capture Button: False
XProf at http://localhost:8791/ (Press CTRL+C to quit)
前往提供的网址(例如 http://localhost:8791/)在浏览器中查看个人资料。
左侧的“会话”下拉菜单中会显示可用的轨迹。选择您感兴趣的会话,然后在“工具”下拉菜单下选择“轨迹查看器”。您现在应该会看到执行时间轴。您可以使用 WASD 键浏览轨迹,点击或拖动可选择事件以查看更多详情。如需详细了解如何使用轨迹查看器,请参阅轨迹查看器工具文档。
通过 XProf 手动捕获
以下说明介绍了如何从正在运行的程序中捕获手动触发的 N 秒轨迹。
启动 XProf 服务器:
xprof --logdir /tmp/profile-data/您应该能够在
<http://localhost:8791/>加载 XProf。您可以使用--port标志指定其他端口。在要进行性能剖析的 Python 程序或进程中,在开头附近添加以下内容:
import jax.profiler jax.profiler.start_server(9999)此命令会启动 XProf 连接到的分析器服务器。分析器服务器必须处于运行状态,然后才能继续执行下一步。使用完服务器后,您可以调用
jax.profiler.stop_server()将其关闭。如果您想分析长时间运行的程序(例如长时间的训练循环)的一小段,可以将此代码放在程序开头,然后像往常一样启动程序。如果您想分析一个简短的程序(例如微基准),一种方法是在 IPython shell 中启动性能分析器服务器,并在下一步中开始捕获后使用
%run运行该简短的程序。另一种方法是在程序开始时启动性能分析器服务器,并使用time.sleep()为您提供足够的时间来开始捕获。打开
<http://localhost:8791/>,然后点击左上角的“捕获性能剖析文件”按钮。输入“localhost:9999”作为配置文件服务网址(这是您在上一步中启动的性能分析器服务器的地址)。输入要进行分析的毫秒数,然后点击“捕获”。如果您要分析的代码尚未运行(例如,如果您在 Python shell 中启动了性能分析器服务器),请在捕获运行时运行该代码。
捕获完成后,XProf 应会自动刷新。(并非所有 XProf 性能剖析功能都与 JAX 相关联,因此最初看起来可能好像没有捕获任何内容。)在左侧的“工具”下方,选择“Trace Viewer”。
您现在应该会看到执行时间轴。您可以使用 WASD 键浏览轨迹,点击或拖动选择事件,以便在底部查看更多详细信息。如需详细了解如何使用轨迹查看器,请参阅轨迹查看器工具文档。
XProf 和 TensorBoard
XProf 是 TensorBoard 中用于支持性能分析和轨迹捕获功能的底层工具。只要安装了 xprof,TensorBoard 中就会显示“Profile”标签页。使用此功能与单独启动 XProf 相同,只要启动时指向同一日志目录即可。这包括个人资料捕获、分析和查看功能。XProf 取代了之前推荐的 tensorboard_plugin_profile 功能。
$ tensorboard --logdir=/tmp/profile-data
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.19.0 at http://localhost:6006/ (Press CTRL+C to quit)
添加自定义轨迹事件
默认情况下,Trace Viewer 中的事件大多是低级内部 JAX 函数。您可以在代码中使用 jax.profiler.TraceAnnotation 和 jax.profiler.annotate_function 添加自己的事件和函数。
配置性能剖析器选项
start_trace 方法接受一个可选的 profiler_options 参数,该参数可用于对分析器的行为进行精细控制。此参数应该是 jax.profiler.ProfileOptions 的实例。
例如,如需停用所有 Python 和主机轨迹,请运行以下命令:
import jax
options = jax.profiler.ProfileOptions()
options.python_tracer_level = 0
options.host_tracer_level = 0
jax.profiler.start_trace("/tmp/profile-data", profiler_options=options)
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
常规选项
host_tracer_level:设置主机端活动的轨迹级别。支持的值:
0:完全停用宿主 (CPU) 跟踪。1:仅启用用户插桩的 TraceMe 事件的跟踪。2:包含 1 级轨迹以及高级程序执行详情,例如耗时的 XLA 操作(默认)。3:包含 2 级轨迹以及更详细的低级程序执行详情,例如低成本 XLA 操作。
device_tracer_level:控制是否启用设备跟踪。支持的值:
0:停用设备跟踪。1:启用设备跟踪(默认)。
python_tracer_level:控制是否启用 Python 跟踪。支持的值:
0:停用 Python 函数调用跟踪(默认)。1:启用 Python 跟踪。
高级配置选项
TPU 选项
tpu_trace_mode:指定 TPU 轨迹的模式。支持的值:
TRACE_ONLY_HOST:这意味着仅跟踪主机端 (CPU) 活动,而不收集设备 (TPU/GPU) 跟踪记录。TRACE_ONLY_XLA:这意味着仅跟踪设备上的 XLA 级操作。TRACE_COMPUTE:此轨迹用于跟踪设备上的计算操作。TRACE_COMPUTE_AND_SYNC:此轨迹会跟踪设备上的计算操作和同步事件。
如果未提供“tpu_trace_mode”,则 trace_mode 默认为
TRACE_ONLY_XLA。tpu_num_sparse_cores_to_trace:指定要在 TPU 上跟踪的稀疏核心数量。tpu_num_sparse_core_tiles_to_trace:指定要在 TPU 上跟踪的每个稀疏核心内的 tile 数量。tpu_num_chips_to_profile_per_task:指定要为每个任务分析的 TPU 芯片数量。
GPU 选项
以下选项可用于 GPU 分析:
gpu_max_callback_api_events:设置 CUPTI 回调 API 收集的事件数上限。默认设置为2*1024*1024。gpu_max_activity_api_events:设置 CUPTI 活动 API 收集的事件数上限。默认设置为2*1024*1024。gpu_max_annotation_strings:设置可收集的注释字符串数上限。默认设置为1024*1024。gpu_enable_nvtx_tracking:在 CUPTI 中启用 NVTX 跟踪。默认值为False。gpu_enable_cupti_activity_graph_trace:针对 CUDA 图启用 CUPTI activity 图跟踪。默认设置为False。gpu_pm_sample_counters:一个以英文逗号分隔的字符串,包含要使用 CUPTI 的 PM 抽样功能收集的 GPU 性能监控指标(例如"sm__cycles_active.avg.pct_of_peak_sustained_elapsed")。默认情况下,PM 抽样处于停用状态。如需了解可用指标,请参阅 NVIDIA 的 CUPTI 文档。gpu_pm_sample_interval_us:设置 CUPTI PM 采样的采样间隔(以微秒为单位)。默认设置为500。gpu_pm_sample_buffer_size_per_gpu_mb:为 CUPTI PM 采样设置每个设备的系统内存缓冲区大小(以 MB 为单位)。默认值为 64MB。支持的最大值为 4GB。gpu_num_chips_to_profile_per_task:指定要为每个任务分析的 GPU 设备数量。如果未指定、设置为 0 或设置为无效值,则系统将对所有可用的 GPU 进行分析。这可用于减小轨迹收集大小。gpu_dump_graph_node_mapping:如果启用,则将 CUDA 图节点映射信息转储到轨迹中。默认设置为False。
例如:
options = ProfileOptions()
options.advanced_configuration = {"tpu_trace_mode" : "TRACE_ONLY_HOST", "tpu_num_sparse_cores_to_trace" : 2}
如果发现任何无法识别的键或选项值,则返回 InvalidArgumentError。