类别:运行时:程序分配失败
此错误表示 TPU 设备上的 XLA 运行时未能将已编译的 XLA 程序可执行文件加载到 TPU 的 HBM 中。
示例错误消息:
XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program 'jit_embedding_pipeline_step_fn': Attempting to reserve 29.49G at the bottom of memory. That was not possible. There are 147.64M free, 0B reserved, and 147.64M reservable. Scope: unknown..: while running replica 0 and partition 34 of a replicated computation (other replicas may have failed as well).
XLA 后端:TPU
概览
此错误通常是由以下某种原因造成的:
- 程序大小超出可用 HBM:编译后的 XLA 程序(包括其指令、静态数据和任何嵌入式常量)大于当前在加载该程序的特定 TPU 核心上可用的空闲 HBM 总量。
- HBM 碎片化:虽然设备上的总可用 HBM 可能足够,但它不是以单个连续块的形式提供,因此无法容纳整个程序。
请务必了解 TPU 运行时如何确定内存优先级。缓冲区分配优先于已加载的程序。如果缓冲区分配失败,运行时将从 HBM 中逐出已加载的程序,以释放空间。这可能会导致之前成功加载的程序现在因 HBM 被更多数据缓冲区占用而失败,并显示 OOM 错误。
调试
- 减少缓冲区内存占用空间:释放数据缓冲区使用的内存将为程序本身留出更多空间:
- 减小批次大小:这是减少激活所用内存量的最有效方法之一。
- 参数分片:对于超大型模型,请使用模型并行或分片技术(例如 FSDP 或 Megascale)将模型的参数和计算分布到多个 TPU 核心或主机上。
- 缩短序列/上下文长度:对于处理序列数据的模型(例如,NLP 模型),缩短序列长度可以显著减少内存使用量。
- 缓冲区捐赠:使用框架功能(例如,
jax.jit(..., donate_argnums=...)) 以允许 XLA 重用输入缓冲区的内存来存储输出,从而减少内存用量峰值。
- 减少临时变量的程序内存要求:
- 使用
tpu_shared_memory_percent标志减少临时变量的程序内存用量。请注意,这可能会对性能产生负面影响。
- 使用
- 优化执行策略/减少投放负荷:
- 管理程序加载:如果您要对多个函数进行 JIT 编译,请注意每个函数都可能导致加载一个程序。尽量合理安排工作负载,以最大限度地减少同时加载的程序数量。
- 确保没有内存泄漏:
- 确保对
jax.Array对象的引用不会保留过长时间。即使在程序编译完成后,保留jax.Array对象也可能会阻止自动取消分配。
- 确保对
工具
- 启用
tpu_log_allocations_on_oom标志后,当发生 OOM 时,分配器将转储所有当前分配的详细报告,这对于调试来说非常宝贵。 - 分析程序:使用 JAX 内存分析器或 TensorFlow 分析器详细了解程序在一段时间内的内存用量。这有助于识别内存消耗的意外峰值。