类别:编译时:HBM OOM
此错误表示程序所需的高带宽内存 (HBM) 比 TPU 设备上实际可用的 HBM 多。
示例错误消息:
RESOURCE_EXHAUSTED: TPU TensorCore Hbm usage: 34.82G, SparseCore Hbm usage 174.10G, exceeding available bytes: 95.74G
RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 49.34G of 32.00G hbm. Exceeded hbm capacity by 17.34G.
XLA 后端:TPU
概览
XLA 会执行检查,以确保所有必需的静态分配的总大小适合设备的 HBM。
编译器会针对以下几种类型的分配管理 TPU 的固定 HBM 容量:
- 程序输入和输出:训练批次、优化器状态等。
- TPU 临时变量:中间计算(例如激活、梯度等)所需的动态内存。
- 已编译的二进制文件:TensorCore (TC) 和 SparseCore (SC) 的机器代码。
- 系统开销:为 XLA 运行时预留的空间(例如,旧版 TPU 中的 infeed 缓冲区)。
- 常量:嵌入在 HLO IR 中的常量值分配在 HBM 上。
- 编译器内部:程序级和每个 HLO 的分配(例如网格中节点的路由信息)。
当 XLA 编译器无法将上述所有分配都放入设备 HBM 中时,会发生此错误。
调试
仔细分析错误消息和日志,确定以下哪个类别的 HBM OOM 最能准确描述您的错误:
- “TC Hbm usage: X, SC Hbm usage Y”:如果错误明确细分了 HBM 用量,则 TensorCore (TC) + SparseCore (SC) 的总用量超过了 HBM 限制。→ 跳转到情形 1. 平衡 TC 和 SC HBM 使用率。
- “Ran out of memory in memory space HBM”:检查日志,查看 HBM 上最大分配的枚举。
- 如果存在一个或多个出乎意料的大张量(例如,超过 HBM 限制的 50%)→ 跳转到方案 2。因意外的大量分配而导致内存不足。
- 如果日志中没有意外的大张量 → 跳至情形 3。因总分配量过大而导致内存不足。
场景 1。平衡 TC 和 SC HBM 使用率
如果错误明确细分了用量,例如 “TC Hbm usage: X, SC Hbm usage Y”,则表示 TensorCore (TC) + SparseCore (SC) 的总用量超过了 HBM 限制。比较这两个值以确定瓶颈:
- SparseCore 使用率高
- 优化 HBM 堆栈使用情况:HBM 堆栈内存消耗量与
feature_width、max_unique_nz_per_row和logical_replica_count成正比。 您可以通过调整--xla_sc_num_serialized_tables_to_optimize_hbm标志来减少峰值堆栈使用量,该标志可序列化表的处理。但代价是并行性降低。 - 检查填充开销:SparseCore 将嵌入表对齐到 32B(8 个浮点数)。特征宽度较小的表(例如,< 8 个浮点数)会产生大量的填充开销,从而浪费 HBM。
- 减少堆使用量:较高的
maximum_parallel_iterations值会增加预取到 HBM 堆中的输入数据量。降低此值可以释放大量内存。 - 验证分片:确保嵌入表在所有芯片上都已正确进行 mod 分片。请参阅限制如何转化为表格。
- 如需了解更多相关信息,请参阅 SC:性能和内存瓶颈。
- 优化 HBM 堆栈使用情况:HBM 堆栈内存消耗量与
- TensorCore 使用率高
- 前往场景 2。
- 平衡
- 如果两者单独来看都不算过高,但总和过高,则说明您已达到芯片的容量上限。您必须尝试降低这两个组件的用量。 遵循所有三个部分中的建议。
情景 2。因分配的内存过大而导致内存不足
如果您看到错误消息 "Ran out of memory in memory space HBM",并且日志中存在一个或多个意外的大型分配(超过 HBM 限制的 50%),则几乎不可能是硬件容量问题。这通常是配置错误。检查大型分配的 XLA 标签(如果有),以获取有关其 JAX 源代码的提示。
- 移除调试制品
- 在大规模运行中使用 jax.debug.print() 可能会强制编译器在 HBM 中具体化完整的张量以将其转移到 CPU,从而破坏融合并增加峰值内存使用量。移除所有剩余的
jax.debug.print()。
- 在大规模运行中使用 jax.debug.print() 可能会强制编译器在 HBM 中具体化完整的张量以将其转移到 CPU,从而破坏融合并增加峰值内存使用量。移除所有剩余的
- 修正低效的网格形状或分片
- 错误的网格形状或缺少分片注解可能会导致编译器默认采用复制,从而迫使编译器尝试将非常大的张量拟合到单个芯片上。
- 检查大型分配的形状,并验证分片是否由 XLA 正确指定和传播。
方案 3。因总分配量而导致内存不足
如果您看到错误消息“Ran out of memory in memory space HBM”(内存空间 HBM 中内存不足),并且日志中没有意外的大张量,则表示程序因分配的总和超过 HBM 限制而耗尽容量。在这种情况下,直观呈现内存配置文件通常有助于确定导致内存使用量达到峰值的特定缓冲区。如需有关如何识别峰值内存贡献者的分步指南,请参阅使用 XProf 调试 OOM 错误。
确定一些主要贡献者后,请按照以下步骤优化内存占用。
方案 3.A 调整配置
您通常可以通过调整以下配置来解决 OOM 问题:
- 减小批次大小:中间激活和梯度所需的内存与批次大小成正比。减小批次大小通常有助于减少内存用量。
- 捐赠输入缓冲区:使用
jax.jit时,请为模型参数指定 donate_argnums。这允许 XLA 使用输出覆盖输入内存。 - 启用混合精度 (bfloat16):如果模型架构和质量要求允许,请对程序中最大的张量使用 bfloat16 或量化(int8 等)。请注意,此更改可能会影响模型行为,因此应仔细考虑。
场景 3.B 优化架构和分片
如果配置更改不够充分,则模型拓扑可能对于当前的硬件设置而言过大。
- 使用更新的 TPU 代系:更新的 TPU 通常每个芯片提供更多 HBM;如果可用,请切换到更新的 TPU 代系。
- 在更大的芯片拓扑上运行:如果模型权重对于现有拓扑来说过大,您可以尝试将它们分片到更多芯片上。
- 实现高级分片技术:
- 探索更高级的数据、张量或流水线并行处理方法。
- 为中间值和输出指定分片提示。
- 使用 JAX 主机分流:主机分流技术允许用户将大型张量分流到主机 CPU 内存(例如激活分流和优化器状态分流)。
场景 3.C 检查张量填充和对齐
低效的张量形状是导致 TPU 上出现 OOM 的常见原因,但往往不会发出任何警告。为了在 TPU 上获得最佳性能,XLA 会填充张量维度,通常将最次要的维度填充为 128 的倍数,将次次要的维度填充为 8 的倍数。这种填充会影响输入数组和中间张量(HLO 临时变量),可能会显著增加内存用量,尤其是在维度大小较小的情况下。请参阅阵列布局。
- 审核大型缓冲区的形状:(在采用默认布局的 TPU v5 上)
- 将鼠标悬停在 Xprof 内存查看器中的缓冲区上会显示缓冲区详细信息卡片,其中包含缓冲区详细信息(包括填充信息)。
- 示例:形状为
(129, 1024)的张量可能会填充为(256, 1024),导致近 50% 的内存浪费。 - 更正:形状为
(128, 1024)时不需要填充,内存浪费为 0%。
- 对齐维度:确保所有大型张量维度(批次大小、嵌入维度、隐藏大小)都是 128 的倍数。请注意,此更改可能会影响模型行为,因此应仔细考虑。
场景 3.D 调整影响 XLA 的关键内存标志
您可以调整关键内存标志,以在性能和更低的内存用量之间做出权衡。不过,此策略应作为最后的手段,因为它可能会对性能产生不利影响。
场景 3.E 调整 XLA 重新实物化传递/手动检查点
如果模型接近于适合内存,您可以将 jax.checkpoint 修饰器与 jax.grad 搭配使用,以手动控制在正向传递中保存哪些中间变量,以及在反向传递中重新计算哪些中间变量,从而用计算周期换取 HBM。
或者,您可以强制 XLA::Rematerialization 传递优先考虑节省内存,但可能会导致编译速度变慢:
| 标志 | 说明 | 影响 / 权衡 |
|---|---|---|
--xla_tpu_max_hbm_size_mib |
手动设置重新实物化传递所使用的 HBM 大小上限。 | 强制编译器更加努力地将程序纳入小于实际物理 HBM 的限制中。 |
--xla_tpu_rematerialization_algo=PEAK_PRIORITY |
在内存用量达到峰值时集中精力。 | 与默认算法相比,在大幅减少内存方面可能更高效。 |
--xla_tpu_rematerialization_max_block_size_limit=32 |
控制一个块中可一次性重新实物化的指令数上限。 | 增加此值可以节省内存,但会显著增加编译时间。 |
--xla_tpu_rematerialization_block_effort_factor=10.0 |
定义了在搜索要重新具体化的块时花费的精力(编译时间)。 | 值越高,搜索内存节省的范围越广,但代价是编译时间增加。 |
--xla_tpu_pre_fusion_remat=true |
在融合传递之前启用额外的重新实物化传递。 | 可以节省更多内存,但会增加编译时间,并可能影响数值稳定性。 |
请注意,更改 XLA 标志应作为最后的手段,因为这可能会对性能产生不利影响。
方案 3.F 使用高级性能分析工具
使用 XProf 调试 OOM 错误提供了一个教程,介绍了如何使用 XProf 内存查看器直观呈现编译器对 HBM 使用情况的看法。
借助此工具,您可以查看峰值内存分配和缓冲区生命周期,这对于准确了解在峰值利用率时哪些内容会消耗 HBM 至关重要。如需了解常规性能剖析设置,请参阅 Xprof 使用入门和 TensorBoard 性能剖析。