Compiler API

背景

我们假定读者至少熟悉分片表示的基础知识,其中介绍了如何在 Shardy 中表示张量的分片。本文档介绍了如何在程序中使用分片表示法,例如将分片附加到程序的特定张量。

分片传播是指在给定部分张量的分片约束条件的情况下,为程序中的每个张量确定分片的过程。Shardy 的编译器 API 提供了多种影响/控制分片传播的方法。此外,它还允许用户将手动分片的计算插入其程序中。

目标

本文档介绍了 Shardy 中此类 API 组件的设计,并说明了它们的行为和不变性。请注意,虽然此 API 用于控制分片传播,但本文档不会讨论任何有关传播行为或设计方式的内容。

概览

  • 输入/输出分片 - 将分片附加到主函数的输入或输出,以指明在将输入/输出张量传递给/从函数返回时,应采用这种分片方式。

  • 分片约束条件 - 将分片附加到中间张量(例如矩阵乘法结果),以指明应采用这种方式对该张量或其用例的一部分进行分片。

  • 分片组 - 按 ID 对多个张量进行分组,以指明它们应以相同方式进行分片。

  • 手动计算 - 用于封装使用部分网格轴手动分区的子计算,其中为所有输入和输出指定了沿这些手动轴的分片,并且在子计算内,张量类型相对于这些分片是本地的。

详细设计

输入/输出分片

允许用户为主函数的输入和输出指定分片。

在 MLIR 中,可以将属性附加到函数参数和结果,因此用户可以通过这种方式将分片属性附加到函数。

例如:

@mesh_xy = <["x"=2, "y"=2]>

// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
            {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
            %arg1: tensor<8x16xf32>)
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
  ...
}

分片约束条件

允许用户将分片附加到程序中的中间张量,这会告知分区器应如何对该张量或其用例的一部分进行分片。

这是一个 MLIR 运算,它接受张量作为输入,并附加了分片属性。该操作可以:

  • 没有用途(悬空)- 这意味着附加的分片方式应为张量本身的分片方式。
  • 有用途 - 这意味着附加的分片方式是分片约束操作的用途应采用的分片方式,而输入张量的其他用途可能采用不同的分片方式(如果输入张量没有其他用途,则行为与没有用途的情况相同)。传播将确定张量本身的分片,并在必要时对其进行重新分片。

它可以具有开放维度分片,这意味着操作数可以沿可用轴进一步分片。

@mesh_xy = <["x"=2, "y"=2]>

%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>

分片组

如果两个或更多张量之间没有数据依赖项或没有强数据依赖项,而用户知道这些张量应以相同或类似的方式进行分区,Shardy API 提供了一种指定此关系的方法。这样,用户就可以自由地明确指定张量应彼此分区。

为此,我们引入了分片组的概念,其中每个组包含与同一分片组 ID 关联的任意数量的说明。分片组会强制要求同一组中的分片相同。

例如,在假设的用户程序(如以下所示)中,我们希望将程序的输出分片与程序的输入完全相同,同时这两者之间没有数据依赖项。

如果我们运行此程序,分片传播将无法推断张量 %1%2 的分片,它们最终会被复制。不过,通过附加一个 shard_group 属性(表明输入 %0 和输出 %2 位于同一 shard_group 中),我们允许将分片 @mesh_xy, [{"x"},{"y"}]> 从输入 %0 传播到输出 %2,进而传播到图的其余部分(此处为广播常量 %1)。我们可以使用 sdy.sharding_group 运算为组分配值。

@mesh_xy = <["x"=2, "y"=2]>

module @"jit_zeros_like" {
  func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
    %0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
    %1 = stablehlo.constant dense<0> : tensor<8x2xi64>
    %2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
    return %2 : tensor<8x2xi64>
  }
}

在上面的这个简单示例中,我们也可以为输出明确指定与输入相同的分片,这将会产生相同的效果,因为我们已经知道要为输入分配哪个分片,但在更现实的情况下,我们使用分片是为了让多个张量的分片保持同步,而无需了解其中任何一个张量的分片,Shardy 会负责处理其余事宜,并找到最适合分配给它们的分片。

手动计算

用户可能希望明确控制计算的各部分如何分区以及使用哪些集合。例如,有些用户希望手动(通过前端 API)应用集体 matmul,而不是推迟到编译器。我们提供了一个手动计算 API,可供他们执行此操作。

这是 MLIR 操作,其中包含一个用于手动子计算的区域。 用户可以使用网格轴的一部分(可能包括全部)为此子计算指定输入/输出分片。相对于指定的网格轴(也称为手动轴),子计算将是局部/手动;相对于未指定的轴(也称为自由轴),子计算将是全局/未分区。在传播期间,子计算可以沿自由轴进一步分片,就像在此操作之外进行计算一样。

例如:

@mesh_name = <["data"=2, "model"=2]>

%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
    in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
    out_shardings=[<@mesh_name, [{"data"}, {?}]>]
    manual_axes={"data"}
    (%arg1: tensor<8x32xf32>) {
  // body
  return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

不变量

  1. 所有 in_shardingsout_shardingsmanual_axes 都必须引用同一网格。manual_axes 会按网格排序。

  2. 必须在所有入/出分片中明确使用 manual_axes,也就是说,对于每个分片,所有手动轴都必须分片维度或明确复制。

  3. 如果输入/输出分片中存在自由轴(不属于 manual_axes 的任何网格轴),则该轴必须小于同一维度分片中的任何手动轴(在上述示例中,维度分片 {"model", "data"} 无效)。

  4. 计算的区域/正文是本地计算(例如,包括用户指定的集合)。它必须相对于沿手动轴进行的入/出分片而言是本地的(请参阅上文中的备注)。

嵌套手动计算

您可以将多个手动计算嵌套在一起,前提是每个计算都使用自己的一组唯一的手动轴。