錯誤代碼:E2001

類別:編譯時間:硬體上不支援的 RHS DataType

如果矩陣乘法 (例如 jax.lax.dot_generaljax.lax.convjax.numpy.matmul@ 運算子) 中,用於右側 (RHS) 運算元的資料型別,並非所用特定 TPU 代別原生支援的型別,就會發生這項錯誤。

錯誤訊息範例:

INTERNAL: Mosaic failed to compile TPU kernel: Unsupported matmul RHS type on target: 'vector<256x256xi8>'
...

The MLIR operation involved:
%13440 = "tpu.matmul"(%13435, %13437, %13439) <dimension_numbers = #tpu.dot_dimension_numbers<...>

XLA 後端:TPU

總覽

TPU 的矩陣乘法單元 (MXU) 原生支援所有硬體世代的 Float32 作業。

不過,硬體世代不同,對 BFloat16 和其他量化資料類型 (例如 Int4、Int8 或 Float8) 的原生支援程度也不同。當核心嘗試使用特定 TPU 代別沒有實體電路可執行的資料類型,將矩陣乘法對應至 MXU 時,就會觸發這項錯誤。

這項錯誤通常表示編譯器的「標準化」傳遞 (嘗試自動將不支援的型別轉換為支援的型別,例如透過軟體模擬) 找不到有效的轉換規則,或是因為「相容性模式」已停用而無法執行轉換。

偵錯

如要解決這個錯誤,請務必讓資料型別與硬體功能一致。您可以採取下列做法:

1. 轉換為原生型別

最可靠的修正方式是在 matmul 運算前,於核心內將運算元手動轉換為硬體支援的資料型別 (例如 TPU v4+ 上的 Float32BFloat16)。

  • 原因: Float32 是所有 TPU 世代的 MXU 原生支援的通用資料型別。
  • 取捨:這會產生 VPU (向量處理單元) 成本,也就是執行層級轉換所需的週期,但可確保核心會在目前的硬體上執行。

2. 檢查相容性模式

一般來說,編譯器會自動處理這些類型不符的問題,方法是在預設啟用的相容模式中執行。請仔細檢查 XLA 設定,確認 --xla_mosaic_compat_mode 未設為 false。

這項功能會充當「填補程式」,為硬體不支援的操作注入軟體模擬序列。

相容模式可啟用以下功能:

  • 混合精確度 MatMul:可透過自動插入轉換作業 (例如在 matmul 前將整數擴充為 Float32),混合整數運算元和浮點累加器。
  • 低精確度模擬:在特定硬體世代中,會先將不支援的型別 (例如 4-bit 浮點數 (4E2M1FN) 或 8-bit 浮點數 (8E4M3FN)) 擴充為支援的型別 (例如 BFloat16Float32),再執行模擬。

請注意,這個模式會優先考量相容性,而非最高效能,因為模擬需要額外指令來轉換資料格式,MXU 才能處理資料。

3. 升級硬體或要求支援

如果演算法嚴格要求 Int4Float8 等類型的原生效能,且不允許投放或模擬的額外負荷,則必須在原生支援的新一代 TPU 上執行。

功能要求:如果您認為硬體支援這項作業,或編譯器即使在相容模式下仍缺少有效的模擬路徑,請提出功能要求。我們通常會保證作業向前相容。因此,如果您的核心在某一代 TPU 上執行,就應該能在所有後續世代上執行,但無法保證能模擬舊世代 (因為某些舊世代的轉換成本會非常高)。