与方言无关的分片

我们的长期目标是将 Shardy 打造成一个完全独立的组件,能够与任何 MLIR 方言搭配使用。目前,Shardy 直接依赖于 StableHLO,但我们正在努力通过各种抽象和接口来提升这一点,以使 Shardy 更加灵活。

分片规则

分片规则用于编码我们如何通过操作进行传播。由于 Shardy 现在依赖于 StableHLO,因此它会为每个 stablehlo 操作定义分片规则。此外,Shardy 还提供了 ShardingRuleOpInterface,方言所有者可以在其操作中使用该函数为自己的操作定义分片规则。只要操作实现了此接口,Shardy 便可通过该接口进行传播。

def ShardingRuleOpInterface : OpInterface<"ShardingRuleOpInterface"> {
  let methods = [
    InterfaceMethod<
      /*desc=*/[{
        Returns the sharding rule of the op.
      }],
      /*retType=*/"mlir::sdy::OpShardingRuleAttr",
      /*methodName=*/"getShardingRule"
    >,
  ];
}

数据流操作

某些运算(例如基于区域的运算)需要采用不同的方法,仅描述所有运算数和结果之间的维度对应关系的分片规则是不够的。在这些情况下,Shardy 会定义 ShardableDataFlowOpInterface,以便方言所有者可以通过其操作描述分片的传播。此接口提供了一些方法,用于通过其所有者获取每个数据流边的来源和目标,以及获取和设置边所有者的分片。

def ShardableDataFlowOpInterface :
    OpInterface<"ShardableDataFlowOpInterface"> {
  (get|set)BlockArgumentEdgeOwnerShardings;
  (get|set)OpResultEdgeOwnerShardings;
  getBlockArgumentEdgeOwners;
  getOpResultEdgeOwners;
  getEdgeSources;
  // ...
}

另请参阅数据流操作,简要了解我们如何处理数据流操作。

尚未实现的接口

未来,我们将添加更多接口和 trait,使 Shardy 更加灵活且不依赖于方言。我们在下方列出了这些问题。

常量拆分

MLIR 中的大多数张量程序都包含一个常量实例,该实例会被需要该值的任何操作重复使用。当所需的常量相同时,这种做法很有意义。不过,为了对程序进行最优的拆分,我们希望允许对常量的每次使用都有自己的拆分,并且不受其他操作对该常量的使用方式的影响。

例如,在下图中,如果 add 进行了分片,则不应影响 dividesubtract(在计算的不同部分)的分片方式。

常量拆分

我们称之为虚假依赖项:由于常量开销很低,因此使用相同常量的运算之间没有真正的依赖项。因此,用户可以决定其常量(和类似常量)运算的分片方式。这样一来,对该常量的每次使用都可以采用不同的分片,这些分片可以独立传播到常量子计算的各个副本。

为此,Shardy 用户需要定义:- your_dialect.constant -> sdy.constant 传递;- sdy::ConstantLike trait,例如 iota;- 适用于 addmultiply 等元素级运算的 mlir::Elementwise trait;- 适用于 slice/broadcast 等运算的 sdy::ConstantFoldable。从技术层面上讲,如果所有运算数/结果都是常量,则可以在编译时计算这些运算。

操作优先级

在 GSPMD 中,首先传播按元素的运算,然后是 matmul 等运算。在 Shardy 中,我们希望允许用户设置自己的操作优先级,因为我们无法预先知道其方言。因此,我们会要求他们按希望 Shardy 传播操作的顺序传递操作列表。

下图展示了 GSPMD 中如何使用优先级以正确的顺序传播操作。

操作优先级。请参阅 GSPMD 论文,了解操作优先级为何重要

如需了解操作优先级的重要性,请参阅 GSPMD 论文

与方言无关

只要您实现了之前的接口、trait 和传递,Shardy 便可用于您的方言。我们正在努力让 Shardy 变得更加灵活且不受方言影响,敬请期待后续动态。