類別:執行階段:程式分配失敗
這項錯誤表示 TPU 裝置上的 XLA 執行階段無法將已編譯的 XLA 程式可執行檔載入 TPU 的 HBM。
錯誤訊息示例:
XlaRuntimeError: RESOURCE_EXHAUSTED: Error loading program 'jit_embedding_pipeline_step_fn': Attempting to reserve 29.49G at the bottom of memory. That was not possible. There are 147.64M free, 0B reserved, and 147.64M reservable. Scope: unknown..: while running replica 0 and partition 34 of a replicated computation (other replicas may have failed as well).
XLA 後端:TPU
總覽
這項錯誤的常見原因如下:
- 程式大小超出可用 HBM:編譯的 XLA 程式(包括指令、靜態資料和任何內嵌常數) 大於特定 TPU 核心上目前可用的 HBM 總量,因此無法載入程式。
- HBM 碎片化:裝置上的 HBM 總可用空間可能足夠,但無法提供足夠大的連續區塊來容納整個程式。
請務必瞭解 TPU 執行階段如何決定記憶體優先順序。緩衝區分配的優先順序高於載入的程式。如果緩衝區分配失敗,執行階段會從 HBM 逐出已載入的程式,藉此釋出空間。這可能會導致先前成功載入的程式現在因 OOM 錯誤而失敗,因為 HBM 現在佔用更多資料緩衝區。
偵錯
- 減少緩衝區記憶體用量:釋出資料緩衝區使用的記憶體,可為程式本身留下更多空間:
- 減少批次大小:這是減少啟用記憶體用量的最有效方法之一。
- 參數分片:對於非常大型的模型,請使用模型平行化或分片技術 (例如 FSDP 或 Megascale),將模型的參數和運算作業分配到多個 TPU 核心或主機。
- 縮短序列/內容長度:對於處理序列資料的模型 (例如NLP 模型),縮短序列長度可大幅減少記憶體用量。
- 緩衝區捐贈:使用架構功能 (例如
jax.jit(..., donate_argnums=...)),讓 XLA 重複使用輸入緩衝區的記憶體來儲存輸出內容,進而減少記憶體用量峰值。
- 減少程式的暫時性記憶體需求:
- 使用
tpu_shared_memory_percent旗標,減少程式的暫時性記憶體用量。請注意,這可能會對效能造成負面影響。
- 使用
- 最佳化執行策略/減少放送負載:
- 管理程式載入:如果您要 JIT 編譯多個函式,請注意每個函式都可能導致程式載入。請盡量安排工作負載,減少同時載入的程式數量。
- 確保沒有記憶體流失:
- 請確保對
jax.Array物件的參照不會保留超過預期時間。即使程式編譯完成,保留jax.Array物件也可能會導致系統無法自動取消分配。
- 請確保對
工具
- 啟用
tpu_log_allocations_on_oom標記,當發生 OOM 時,分配器會傾印所有目前分配的詳細報告,這對偵錯來說非常寶貴。 - 分析程式:使用 JAX 記憶體分析器或 TensorFlow 分析器,詳細瞭解程式的記憶體用量隨時間變化的情況。這有助於找出記憶體用量意外激增的情況。