Shardy 是一种基于 MLIR 的张量分区系统,适用于所有方言。该工具由 GSPMD 和 PartIR 团队联合打造,融合了这两个系统的优势,以及团队和用户的共同经验。
优势
- 通过将 GSPMD 的传播与 PartIR 的增量分区相结合,为用户提供更强的控制力和可预测性。
- 由共享体验驱动的新功能,例如对重塑的全新支持。重塑功能众所周知会产生额外的沟通,除非用户知道如何规避。
- 更好的实用性和可调试性,以提高最终用户的开发速度,例如通过使用基于轴的分片表示法。
- 使用 MLIR 的简单开源代码库,拥有更广泛的活跃贡献者(内部、外部和不同时区的贡献者),可为用户提供支持。
组件
- 分片表示法:基于轴的分片表示法,绑定到特定逻辑网格(可能有多个网格),支持限制维度分片和轴、为重塑等操作拆分轴、增量分区的优先级等。
- 编译器 API:一组编译器组件,可与分片表示法搭配使用,以影响分片传播。
- 输入/输出分片 - 将分片附加到主函数的输入或输出,以指明在将输入/输出张量传递给/从函数返回时,应采用这种分片方式。
- 分片约束条件 - 将分片附加到中间张量(例如矩阵乘法结果),以指明该张量或其用途的一部分应采用这种分片方式。
- 按/类似分片 - 按 ID 对多个张量进行分组,以指明它们应以相同方式进行分片。
- 手动计算 - 用于封装使用部分网格轴手动分区的子计算,其中为所有输入和输出指定了沿这些手动轴的分片,并且在子计算内,张量类型相对于这些分片是本地的。
- 分片传播:一种传播算法,可将用户优先级和分片约束条件与编译器成本模型和启发词语相结合:
- 用户定义的优先级,例如先执行批处理并行处理,然后再执行 ZeRO
- 基于运算的优先级,例如先执行元素级运算,然后执行矩阵乘法等。
- 更精细的启发词语,例如首选批量维度。
- SPMD 分区器:这是一个组件,通过将程序划分为 SPMD 程序,并在此过程中添加必要的数据移动/格式设置和集合操作,从而降低分片传播决策的难度。
- 短期内,初始实现将使用当前的 GSPMD SPMD 分区器。
- 从长远来看,我们计划创建一个基于 MLIR 的新 SPMD 分区器。
代码库
Shardy 项目正在积极开发中,我们希望获得开源社区的反馈。Shardy 代码位于 https://github.com/openxla/shardy