类别:编译时:硬件上不支持的 RHS 数据类型
当矩阵乘法(例如,jax.lax.dot_general、jax.lax.conv、jax.numpy.matmul 或 @ 运算符)不受所用特定 TPU 代系的本地支持。
示例错误消息:
INTERNAL: Mosaic failed to compile TPU kernel: Unsupported matmul RHS type on target: 'vector<256x256xi8>'
...
The MLIR operation involved:
%13440 = "tpu.matmul"(%13435, %13437, %13439) <dimension_numbers = #tpu.dot_dimension_numbers<...>
XLA 后端:TPU
概览
TPU 的矩阵乘法单元 (MXU) 在所有硬件代系中都原生支持 Float32 操作。
不过,对 BFloat16 和其他量化数据类型(例如 Int4、Int8 或 Float8)因硬件代际而异。当内核尝试使用特定 TPU 代系没有相应物理电路来执行的数据类型将矩阵乘法映射到 MXU 时,会触发此错误。
此错误通常表示,编译器的规范化传递(尝试自动将不受支持的类型转换为受支持的类型,例如通过软件模拟)无法找到有效的转换规则,或者由于兼容性模式处于停用状态而无法找到有效的转换规则。
调试
如需解决此错误,您必须使数据类型与硬件的功能保持一致。您可以选择以下选项:
1. 转换为原生类型
最可靠的修复方法是在 matmul 操作之前,在内核中手动将操作数转换为硬件支持的数据类型(例如 TPU v4+ 上的 Float32 或 BFloat16)。
- 原因:
Float32是所有 TPU 世代的 MXU 原生支持的通用数据类型。 - 权衡:这会产生 VPU(向量处理单元)费用(即执行转换所需的周期),但可保证您的内核在当前硬件上运行。
2. 检查兼容模式
通常,编译器可以在默认启用的兼容模式下自动处理这些类型不匹配问题。仔细检查 XLA 配置,确保 --xla_mosaic_compat_mode 未设置为 false。
这相当于一个“polyfill”,可为硬件本身不支持的操作注入软件模拟序列。
兼容模式可实现以下功能:
- 混合精度 MatMul:通过自动插入转换操作(例如,在 matmul 之前将整数扩展为
Float32),允许将整数操作数与浮点累加器混合使用。 - 低精度模拟:在某些硬件代系上,通过将不支持的类型(如
4-bit浮点数 (4E2M1FN) 或8-bit浮点数 (8E4M3FN))扩展为支持的类型(如BFloat16或Float32),在执行之前模拟这些类型。
请注意,此模式优先考虑兼容性而非峰值性能,因为仿真需要额外的指令来转换数据格式,然后 MXU 才能对这些数据进行操作。
3. 升级硬件或申请支持
如果您的算法严格要求 Int4 或 Float8 等类型具有原生性能,而没有转换或模拟的开销,则需要使用原生支持的较新一代 TPU。
功能请求:如果您认为自己的硬件支持此操作,或者即使在兼容模式下,编译器也缺少有效的模拟路径,请提交功能请求。我们通常保证操作是向前兼容的。因此,如果您的内核可在某个 TPU 代系上运行,那么它应该也能在所有未来的代系上运行。但无法保证它能模拟旧版设备(对于某些旧版设备,投屏成本会非常高)。