使用 XProf 调试 OOM 错误

当加速器(GPU 或 TPU)的高带宽内存 (HBM) 容量耗尽时,就会发生内存不足 (OOM) 错误。E1000 - 编译时 HBM OOM 文档JAX GPU 内存分配文档中详细介绍了 OOM 问题的一些常见原因和调试技巧。

本页介绍了如何使用 XProf 的内存查看器工具直观呈现 JAX 程序的内存使用情况、识别内存使用峰值实例,以及调试 OOM 错误。这涉及到下列步骤:

  1. 使用 jax.profiler.trace 运行程序以捕获配置文件。
  2. 在后台启动 XProf,然后使用内存查看器工具查看内存利用率详情。

示例程序

以下 JAX 程序会导致 OOM 错误:

import jax
from jax import random
import jax.numpy as jnp


@jax.profiler.trace("/tmp/xprof")
@jax.jit
def oom():
    a = random.normal(random.PRNGKey(1), (327680, 327680), dtype=jnp.bfloat16)
    return a @ a


if __name__ == "__main__":
    oom()

在 TPU 机器上,此程序会失败,并显示以下错误消息:

XlaRuntimeError: RESOURCE_EXHAUSTED: Allocation (size=107374182400) would exceed memory (size=17179869184) :: #allocation7 [shape = 'u8[327680,327680]{1,0:T(8,128)(4,1)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of xor_convert_fusion@{}'] :: <no-hlo-instruction>

(在 GPU 机器上,错误看起来像这样:XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 214748364800 bytes.

运行 XProf

安装 xprof (pip install xprof),然后启动 XProf 实例,并指定存储配置文件的目录:

xprof --logdir=/tmp/xprof/ --port=6006

前往实例(在本地机器上,位于 http://localhost:6006)。在工具下拉菜单中,选择内存查看器,然后在“内存查看器”工具窗口中,从内存类型下拉菜单中选择 HBM(通常默认处于选中状态)。

上述示例程序的 XProf 内存查看器页面

XProf:内存查看器工具文档介绍了该工具的组件和显示的信息。

重点关注峰值内存分配时的 HLO 操作部分,该部分显示了峰值内存使用点处的三个缓冲区图表。缓冲区包括:* 程序输入和输出:训练批次、优化器状态等 * TensorCore 和 SparseCore 临时变量:中间计算(如激活、梯度等)所需的动态内存

您可以将鼠标悬停在缓冲区图表上,详细了解相应操作,例如其大小、形状、分配类型等。这有助于您识别和评估可能具有高临时变量或持久临时变量的运算、具有低效填充的任何大型输入/中间/输出张量等,这些都会导致内存峰值,需要进行调整或优化。

如需了解具体的调试技巧,请参阅 E1000:调试