当加速器(GPU 或 TPU)的高带宽内存 (HBM) 容量耗尽时,就会发生内存不足 (OOM) 错误。E1000 - 编译时 HBM OOM 文档和 JAX GPU 内存分配文档中详细介绍了 OOM 问题的一些常见原因和调试技巧。
本页介绍了如何使用 XProf 的内存查看器工具直观呈现 JAX 程序的内存使用情况、识别内存使用峰值实例,以及调试 OOM 错误。这涉及到下列步骤:
- 使用
jax.profiler.trace运行程序以捕获配置文件。 - 在后台启动 XProf,然后使用内存查看器工具查看内存利用率详情。
示例程序
以下 JAX 程序会导致 OOM 错误:
import jax
from jax import random
import jax.numpy as jnp
@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
return a @ a
if __name__ == "__main__":
oom()
在 TPU 机器上,此程序会失败,并显示以下错误消息:
XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>
(在 GPU 机器上,错误看起来像这样:XlaRuntimeError: RESOURCE_EXHAUSTED:
Out of memory while trying to allocate 214748364800 bytes.)
运行 XProf
安装 xprof (pip install xprof),然后启动 XProf 实例,并指定存储配置文件的目录:
xprof --logdir=/tmp/xprof/ --port=6006
前往实例(在本地机器上,位于 http://localhost:6006)。在工具下拉菜单中,选择内存查看器,然后在“内存查看器”工具窗口中,从内存类型下拉菜单中选择 HBM(通常默认处于选中状态)。

XProf:内存查看器工具文档介绍了该工具的组件和显示的信息。
重点关注峰值内存分配时的 HLO 操作部分,该部分显示了峰值内存使用点处的三个缓冲区图表。缓冲区包括:* 程序输入和输出:训练批次、优化器状态等 * TensorCore 和 SparseCore 临时变量:中间计算(如激活、梯度等)所需的动态内存
您可以将鼠标悬停在缓冲区图表上,详细了解相应操作,例如其大小、形状、分配类型等。这有助于您识别和评估可能具有高临时变量或持久临时变量的运算、具有低效填充的任何大型输入/中间/输出张量等,这些都会导致内存峰值,需要进行调整或优化。
如需了解具体的调试技巧,请参阅 E1000:调试。