错误代码:0101

类别:运行时:程序分配失败

此错误表示 TPU 设备上的 XLA 运行时未能将已编译的 XLA 程序可执行文件加载到 TPU 的 HBM 中。

示例错误消息

XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program 'jit_embedding_pipeline_step_fn': Attempting to reserve 29.49G at the bottom of memory. That was not possible. There are 147.64M free, 0B reserved, and 147.64M reservable. Scope: unknown..: while running replica 0 and partition 34 of a replicated computation (other replicas may have failed as well).

XLA 后端:TPU

概览

此错误通常是由以下某种原因造成的:

  • 程序大小超出可用 HBM:编译后的 XLA 程序(包括其指令、静态数据和任何嵌入式常量)大于当前在加载该程序的特定 TPU 核心上可用的空闲 HBM 总量。
  • HBM 碎片化:虽然设备上的总可用 HBM 可能足够,但它不是以单个连续块的形式提供,因此无法容纳整个程序。

请务必了解 TPU 运行时如何确定内存优先级。缓冲区分配优先于已加载的程序。如果缓冲区分配失败,运行时将从 HBM 中逐出已加载的程序,以释放空间。这可能会导致之前成功加载的程序现在因 HBM 被更多数据缓冲区占用而失败,并显示 OOM 错误。

调试

  • 减少缓冲区内存占用空间:释放数据缓冲区使用的内存将为程序本身留出更多空间:
    • 减小批次大小:这是减少激活所用内存量的最有效方法之一。
    • 参数分片:对于超大型模型,请使用模型并行或分片技术(例如 FSDP 或 Megascale)将模型的参数和计算分布到多个 TPU 核心或主机上。
    • 缩短序列/上下文长度:对于处理序列数据的模型(例如,NLP 模型),缩短序列长度可以显著减少内存使用量。
    • 缓冲区捐赠:使用框架功能(例如,jax.jit(..., donate_argnums=...)) 以允许 XLA 重用输入缓冲区的内存来存储输出,从而减少内存用量峰值。
  • 减少临时变量的程序内存要求:
    • 使用 tpu_shared_memory_percent 标志减少临时变量的程序内存用量。请注意,这可能会对性能产生负面影响。
  • 优化执行策略/减少投放负荷:
    • 管理程序加载:如果您要对多个函数进行 JIT 编译,请注意每个函数都可能导致加载一个程序。尽量合理安排工作负载,以最大限度地减少同时加载的程序数量。
  • 确保没有内存泄漏:
    • 确保对 jax.Array 对象的引用不会保留过长时间。即使在程序编译完成后,保留 jax.Array 对象也可能会阻止自动取消分配。

工具

  • 启用 tpu_log_allocations_on_oom 标志后,当发生 OOM 时,分配器将转储所有当前分配的详细报告,这对于调试来说非常宝贵。
  • 分析程序:使用 JAX 内存分析器或 TensorFlow 分析器详细了解程序在一段时间内的内存用量。这有助于识别内存消耗的意外峰值。