背景
我们假定读者至少熟悉分片表示的基础知识,其中介绍了如何在 Shardy 中表示张量的分片。本文档介绍了如何在程序中使用分片表示法,例如将分片附加到程序的特定张量。
分片传播是指在给定部分张量的分片约束条件的情况下,为程序中的每个张量确定分片的过程。Shardy 的编译器 API 提供了多种影响/控制分片传播的方法。此外,它还允许用户将手动分片的计算插入其程序中。
目标
本文档介绍了 Shardy 中此类 API 组件的设计,并说明了它们的行为和不变性。请注意,虽然此 API 用于控制分片传播,但本文档不会讨论任何有关传播行为或设计方式的内容。
概览
输入/输出分片 - 将分片附加到主函数的输入或输出,以指明在将输入/输出张量传递给/从函数返回时,应采用这种分片方式。
分片约束条件 - 将分片附加到中间张量(例如矩阵乘法结果),以指明应采用这种方式对该张量或其用例的一部分进行分片。
分片组 - 按 ID 对多个张量进行分组,以指明它们应以相同方式进行分片。
手动计算 - 用于封装使用部分网格轴手动分区的子计算,其中为所有输入和输出指定了沿这些手动轴的分片,并且在子计算内,张量类型相对于这些分片是本地的。
详细设计
输入/输出分片
允许用户为主函数的输入和输出指定分片。
在 MLIR 中,可以将属性附加到函数参数和结果,因此用户可以通过这种方式将分片属性附加到函数。
例如:
@mesh_xy = <["x"=2, "y"=2]>
// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
{sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
%arg1: tensor<8x16xf32>)
-> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
...
}
分片约束条件
允许用户将分片附加到程序中的中间张量,这会告知分区器应如何对该张量或其用例的一部分进行分片。
这是一个 MLIR 运算,它接受张量作为输入,并附加了分片属性。该操作可以:
- 没有用途(悬空)- 这意味着附加的分片方式应为张量本身的分片方式。
- 有用途 - 这意味着附加的分片方式是分片约束操作的用途应采用的分片方式,而输入张量的其他用途可能采用不同的分片方式(如果输入张量没有其他用途,则行为与没有用途的情况相同)。传播将确定张量本身的分片,并在必要时对其进行重新分片。
它可以具有开放维度分片,这意味着操作数可以沿可用轴进一步分片。
@mesh_xy = <["x"=2, "y"=2]>
%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>
分片组
如果两个或更多张量之间没有数据依赖项或没有强数据依赖项,而用户知道这些张量应以相同或类似的方式进行分区,Shardy API 提供了一种指定此关系的方法。这样,用户就可以自由地明确指定张量应彼此分区。
为此,我们引入了分片组的概念,其中每个组包含与同一分片组 ID 关联的任意数量的说明。分片组会强制要求同一组中的分片相同。
例如,在假设的用户程序(如以下所示)中,我们希望将程序的输出分片与程序的输入完全相同,同时这两者之间没有数据依赖项。
如果我们运行此程序,分片传播将无法推断张量 %1
和 %2
的分片,它们最终会被复制。不过,通过附加一个 shard_group
属性(表明输入 %0
和输出 %2
位于同一 shard_group
中),我们允许将分片 @mesh_xy,
[{"x"},{"y"}]>
从输入 %0
传播到输出 %2
,进而传播到图的其余部分(此处为广播常量 %1
)。我们可以使用 sdy.sharding_group
运算为组分配值。
@mesh_xy = <["x"=2, "y"=2]>
module @"jit_zeros_like" {
func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
%0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
%1 = stablehlo.constant dense<0> : tensor<8x2xi64>
%2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
return %2 : tensor<8x2xi64>
}
}
在上面的这个简单示例中,我们也可以为输出明确指定与输入相同的分片,这将会产生相同的效果,因为我们已经知道要为输入分配哪个分片,但在更现实的情况下,我们使用分片是为了让多个张量的分片保持同步,而无需了解其中任何一个张量的分片,Shardy 会负责处理其余事宜,并找到最适合分配给它们的分片。
手动计算
用户可能希望明确控制计算的各部分如何分区以及使用哪些集合。例如,有些用户希望手动(通过前端 API)应用集体 matmul,而不是推迟到编译器。我们提供了一个手动计算 API,可供他们执行此操作。
这是 MLIR 操作,其中包含一个用于手动子计算的区域。 用户可以使用网格轴的一部分(可能包括全部)为此子计算指定输入/输出分片。相对于指定的网格轴(也称为手动轴),子计算将是局部/手动;相对于未指定的轴(也称为自由轴),子计算将是全局/未分区。在传播期间,子计算可以沿自由轴进一步分片,就像在此操作之外进行计算一样。
例如:
@mesh_name = <["data"=2, "model"=2]>
%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
out_shardings=[<@mesh_name, [{"data"}, {?}]>]
manual_axes={"data"}
(%arg1: tensor<8x32xf32>) {
// body
return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
不变量
所有
in_shardings
、out_shardings
和manual_axes
都必须引用同一网格。manual_axes
会按网格排序。必须在所有入/出分片中明确使用
manual_axes
,也就是说,对于每个分片,所有手动轴都必须分片维度或明确复制。如果输入/输出分片中存在自由轴(不属于
manual_axes
的任何网格轴),则该轴必须小于同一维度分片中的任何手动轴(在上述示例中,维度分片{"model", "data"}
无效)。计算的区域/正文是本地计算(例如,包括用户指定的集合)。它必须相对于沿手动轴进行的入/出分片而言是本地的(请参阅上文中的备注)。
嵌套手动计算
您可以将多个手动计算嵌套在一起,前提是每个计算都使用自己的一组唯一的手动轴。