'sdy' 方言

Shardy (SDY) 方言定义了基于轴的张量分片表示法,以及用于将分片附加到张量的其他 API 组件。

操作

sdy.constant (sdy::ConstantOp)

常量运算

从常量 value 生成 output 张量。

请参阅:https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant

示例:

%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>

特征:AlwaysSpeculatableImplTrait

接口:ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影响:MemoryEffects::Effect{}

属性:

属性MLIR 类型说明
value::mlir::ElementsAttr常量矢量/张量属性

结果:

结果 说明
output 任意类型值的张量

sdy.data_flow_edge (sdy::DataFlowEdgeOp)

数据流边缘操作

语法:

operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)

某个操作 X 的数据流边定义了一组来源(每个都是 X 的运算数或 X 的块终止符的运算数)和一组目标(每个都是 X 的结果或 X 的块参数)之间的桥梁,以便所有来源和目标都应以相同的方式进行分片。

一个运算可以有多个彼此正交的数据流边。

例如:

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

此 while 运算有 n 条数据流边,第 i 条数据流边位于源 x_ireturn_value_i 和目标 y_ipred_arg_ibody_arg_i 之间。

sdy.data_flow_edge 将边缘的根目标(可以是任何目标,但最好是操作结果而不是块参数)作为输入,不应有任何其他用途。此操作不是纯粹的,因为它可以接受最初没有任何用途的输入。

sdy.data_flow_edge 还包含边缘的所有目标的可选分片,并且在传播期间,应更新该分片,而不是目标的分片(如果可以附加)。当运算具有许多边时,这非常有用,因为这样可以更高效地执行以下操作:

  • 分别通过每个边传播。
  • 分别更新每个边的拆分,而不是一次更新所有目标(例如,某个操作具有单个不可变的 TensorShardingPerValueAttr 用于结果拆分)。
  • 当来源的分片发生变化时,将每个边单独添加到工作列表中。

传播会在 sdy.data_flow_edge 的所有来源和目标之间传播分片,就像它是将来源作为运算数、目标作为结果且具有标识 sdy.op_sharding_rule 的常规运算一样。也就是说,正向传播是从来源到目标,而反向传播是从目标到来源。

我们不允许 sdy.data_flow_edge 的输入由 SdyDialect 运算定义,因此我们可以假定它由已取消注册 sdy.sharding 属性的运算定义。

特征:SameOperandsAndResultType

接口:InferTypeOpInterface

属性:

属性MLIR 类型说明
sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
input 形状为任何类型值

结果:

结果 说明
result 形状为任何类型值

sdy.manual_computation (sdy::ManualComputationOp)

使用手动集合进行多设备并行操作

语法:

operation ::= `sdy.manual_computation` `(`operands`)`
              `in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
              `out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
              `manual_axes````=```$manual_axes
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:`
              functional-type(operands, results)

跳转到使用显式集合按设备本地代码编写的区域,其中逻辑形状与设备本地物理缓冲区形状相匹配,并且集合与物理跨设备通信完全对应。

相对于 manual_axes,body 是局部的。传播将通过身体上的任何自由轴(不在 manual_axes 列表中)进行。

trait:IsolatedFromAboveRecursiveMemoryEffectsSingleBlockImplicitTerminator<ReturnOp>SingleBlock

属性:

属性MLIR 类型说明
in_shardings::mlir::sdy::TensorShardingPerValueAttr按操作数/操作结果对张量进行分片
out_shardings::mlir::sdy::TensorShardingPerValueAttr按操作数/操作结果对张量进行分片
manual_axes::mlir::sdy::ManualAxesAttr

运算数:

Operand 说明
tensors 任意类型值的排序张量的变参

结果:

结果 说明
results 任意类型值的已排序张量的可变变异值

sdy.mesh (sdy::MeshOp)

命名网格

语法:

operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict

定义一个新的命名网格。模块中的所有网格必须包含相同数量的设备(仅限具有单个 device_id 的网格)。网格是出现在模块的 SymbolTable 中的 Symbol 操作,可由其 name 引用。

特征:HasParent<ModuleOp>

接口:Symbol

属性:

属性MLIR 类型说明
sym_name::mlir::StringAttr字符串属性
mesh::mlir::sdy::MeshAttr轴网格和设备列表

sdy.named_computation (sdy::NamedComputationOp)

命名计算操作

语法:

operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
              (`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
              (`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:` functional-type($operands, results)

对计算(即一组操作)进行分组并为其命名。传播将流入/流出该区域,就像所有内容都内嵌一样。

这可用于处理通过调用指令传播到其他函数的情况。Shardy 的任何用户都应编写一个导入/导出传递,以将其调用操作转换为 sdy.named_computation 操作,将被调函数的正文复制/复制到 named_computation 的正文中。

该区域中每个块参数和返回值的类型必须与运算元的类型和运算结果类型相同。

示例:

%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

trait:IsolatedFromAboveRecursiveMemoryEffectsRecursivelySpeculatableImplTraitSingleBlockImplicitTerminator<ReturnOp>SingleBlock

接口:ConditionallySpeculatableShardableDataFlowOpInterface

属性:

属性MLIR 类型说明
name::mlir::StringAttr字符串属性
in_shardings::mlir::sdy::TensorShardingPerValueAttr按操作数/操作结果对张量进行分片
out_shardings::mlir::sdy::TensorShardingPerValueAttr每个操作数/操作结果的张量分片

操作数:

Operand 说明
operands 任意类型的变参

结果:

结果 说明
“未命名” 任意类型的变参

sdy.propagation_barrier (sdy::PropagationBarrierOp)

传播屏障操作

语法:

operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)

此运算的运作方式与恒等运算类似,输出与其作为输入时相同的值。但在传播方面,这只会允许传播沿着特定方向流经它。

这可防止在使用屏障运算结果及其运算子时传播分片。

  • FORWARD 表示分片只能从运算对象流向结果。
  • BACKWARD 表示分片只能从结果流向运算对象。
  • NONE 表示无法通过此操作传播任何分片。
  • 无法指定 BOTH,因为此操作会重复。

trait:AlwaysSpeculatableImplTraitElementwiseSameOperandsAndResultType

接口:ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影响:MemoryEffects::Effect{}

属性:

属性MLIR 类型说明
allowed_direction::mlir::sdy::PropagationDirectionAttr传播方向枚举

操作数:

Operand 说明
input 任何类型值的排名张量

结果:

结果 说明
result 任意类型值的已排序张量

sdy.reshard (sdy::ReshardOp)

将张量重新分片为其他分片

语法:

operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)

使用指定的分片(不同于输入张量的现有分片)重新分片输入张量。

ShardingConstraintOp 和 ReshardOp 都会将分片附加到张量。其生命周期如下:

  1. 在分片传播之前,用户会添加 ShardingConstraintOp。
  2. 分片传播会使用 ShardingConstraintOp。分片传播结果中没有 ShardingConstraintOp。不过,如果需要,可以添加 ReshardOp。
  3. 分区器会将 ReshardOp 转换为集合运算(或身份运算)。分区器的结果中不应包含 ReshardOp。

// TODO(b/331680067)。添加了规范化模式,以移除多余的 // reshard 操作。

trait:AlwaysSpeculatableImplTraitElementwiseSameOperandsAndResultType

接口:ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影响:MemoryEffects::Effect{}

属性:

属性MLIR 类型说明
sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
input 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.return (sdy::ReturnOp)

sdy.return 操作会终止附加到 sdy 基于区域的操作和任何其他基于 Shardy 区域的操作的区域。它具有可变性:它接受一个类型可以是任何类型(但属于同一种类,例如 AnyTensor)的值列表作为参数,因此可以在 Shardy IR 堆栈的不同级别重复使用。

语法:

operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?

trait:AlwaysSpeculatableImplTraitTerminator

接口:ConditionallySpeculatableNoMemoryEffect (MemoryEffectOpInterface)

效果:MemoryEffects::Effect{}

操作数:

Operand 说明
results 任意类型的变异参数

sdy.sharding_constraint (sdy::ShardingConstraintOp)

将张量约束为指定分片

语法:

operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)

将分片附加到中间张量(例如矩阵乘法结果),以指明该张量或其用途的一部分应采用这种方式进行分片。

如果分片具有开放维度和未约束的轴,则表示张量可以沿开放维度进一步分片。

此操作可以:

  • 没有用途(悬空)- 这意味着附加的分片方式应为输入张量本身的分片方式。
  • 使用 - 这意味着附加的分片就是应如何使用分片约束操作进行分片,而其他输入张量用法可能具有不同的分片(如果输入张量没有其他用途,则行为与无用例相同)。

trait:ElementwiseSameOperandsAndResultType

接口:InferTypeOpInterface

属性:

属性MLIR 类型说明
sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
input 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.sharding_group (sdy::ShardingGroupOp)

分片组操作

语法:

operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)

此运算提供了一个接口,用于将张量分配给分片组(系统会强制要求这些张量组具有相同的分片)。在传播过程中,只要一个分组元素被分片,所有其他成员都会以完全相同的方式进行分片。此操作采用参数组 ID 且不返回任何结果,但会修改内部分片组表示法,将输入张量添加到具有指定 ID 的组中。

属性:

属性MLIR 类型说明
group_id::mlir::IntegerAttr64 位无符号整数属性

运算数:

Operand 说明
input 任何类型值的排名张量

属性

AxisRefAttr

引用全轴或拆分子轴

语法:

#sdy.axis_ref<
  ::llvm::StringRef,   # name
  SubAxisInfoAttr   # sub_axis_info
>

参数:

参数 C++ 类型 说明
name ::llvm::StringRef name
sub_axis_info SubAxisInfoAttr

DimMappingAttr

维度的因子索引列表

所有因子索引都必须在 [0, num_factors) 范围内,空列表表示这是 null 映射(使用 * 解析/输出),即该维度未映射到任何因子。

参数:

参数 C++ 类型 说明
factor_indices ::llvm::ArrayRef<int64_t>

DimensionShardingAttr

维度分片

用于按主轴到次轴对张量维度进行分片的轴名称列表、一个布尔值(表示该维度是否可以进一步分片),以及一个可选整数(表示此维度分片的优先级,分片传播期间会遵循此优先级)。优先级来自用户分片注解,值越小,优先级越高。如果注解中缺少优先级,则假定优先级为最高。

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<AxisRefAttr> 轴参考列表
is_closed bool
优先级 std::optional<int64_t>

ManualAxesAttr

语法:

#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr>   # value
>

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<StringAttr>

MeshAttr

轴网格和设备列表

语法:

#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>,   # axes
  ::llvm::ArrayRef<int64_t>   # device_ids
>

网格由一系列轴和一个可选的设备 ID 列表构成,用于指定设备排序。

如果轴列表为空,则网格具有大小为 1 的隐式未命名轴。在这种情况下,如果未提供设备 ID 列表,则隐式设备 ID 列表为 [0];如果提供了设备 ID 列表,则其中必须包含一个值为任何非负值的整数。我们将这种情况称为最大分片情形。

对于所有非最大分片情况,如果指定了设备 ID 列表,则轴大小的乘积应与设备数量一致。如果未指定设备 ID 列表,则隐式设备 ID 列表为 iota(product(axes))。 为简单起见,我们还禁止指定与 iota(product(axes)) 相同的设备 ID 列表;在这种情况下,不应指定设备 ID 列表。

以下是网格的一些示例:

  • 空网格表示可以在传播期间替换的占位符网格:<[]>
  • 网格具有未命名的轴和显式设备 ID,通常用于表示最大分片:<[], device_ids=[3]>
  • 具有两个轴和隐式设备 ID 的网格 iota(6):<["a"=2, "b"=3]>
  • 具有两个轴和显式设备 ID 来指定设备排序的网格:<["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<MeshAxisAttr>
device_ids ::llvm::ArrayRef<int64_t>

MeshAxisAttr

网格中的命名轴

语法:

#sdy.mesh_axis<
  ::llvm::StringRef,   # name
  int64_t   # size
>

参数:

参数 C++ 类型 说明
name ::llvm::StringRef name
size int64_t

OpShardingRuleAttr

指定操作的分区方式。

语法:

#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>,   # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>,   # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>,   # result_mappings
  bool   # is_custom_rule
>

分片规则指定了如何根据操作的各种属性(任何属性、运算元的形状、结果的形状等)对操作进行分区。例如:

%0 = stablehlo.add %arg0, %arg1 {
    sdy.sharding_rule = #sdy.op_sharding_rule<
        ([i, j],[i, j])->([i, j])
        {i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>

请注意,即使不能分片,我们也允许大小为 1 的因子,这主要是为了实现完整性,因为点操作等许多操作的大小为 1,其维度与运算数和结果相对应。

is_custom_rule 描述该规则是否是用户针对 stablehlo.custom_call 操作定义的规则。分区程序不知道如何对这些操作进行分区,因此用户必须告知其具体方法。如果是自定义规则,则该规则始终会保留/永远不会移除。is_custom_rule 只能针对 stablehlo.custom_call 操作为 true。

参数:

参数 C++ 类型 说明
factor_sizes ::llvm::ArrayRef<int64_t>
operand_mappings ::llvm::ArrayRef<TensorMappingAttr>
result_mappings ::llvm::ArrayRef<TensorMappingAttr>
is_custom_rule bool

SubAxisInfoAttr

有关此子轴如何从全轴派生的相关信息

语法:

#sdy.sub_axis_info<
  int64_t,   # pre_size
  int64_t   # size
>

将完整轴拆分为 n 个子轴时,轴会重塑为 [k_1,...,k_n],并且第 i 个子轴可以表示为其左侧所有轴大小 m=prod(k_1,...,k_(i-1))(也称为“前大小”)和大小 k_i 的乘积。因此,sub-axis-info 属性会存储这两个数字,并表示为:(m)k(预设大小为 m,大小为 k)。

参数:

参数 C++ 类型 说明
pre_size int64_t
size int64_t

TensorMappingAttr

张量每个维度的因子映射。

语法:

#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr>   # dim_mappings
>

参数:

参数 C++ 类型 说明
dim_mappings ::llvm::ArrayRef<DimMappingAttr>

TensorShardingAttr

张量分片

语法:

#sdy.sharding<
  ::mlir::Attribute,   # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>,   # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr>   # replicated_axes
>

张量分片绑定到特定网格,并且只能引用该网格中的轴名称。维度分片可告诉我们张量的每个维度是沿哪些轴(或子轴)从主轴到次轴进行分片的。不对维度进行分片的所有其他轴都会隐式或显式复制(如果它们出现在复制的轴列表中)。

此分片绑定到的网格可以由符号名称(引用相应的 MeshOp 符号)或内嵌 MeshAttr 指定。

参数:

参数 C++ 类型 说明
mesh_or_ref ::mlir::Attribute 网格属性或平面网格符号引用属性
dim_shardings ::llvm::ArrayRef<DimensionShardingAttr>
replicated_axes ::llvm::ArrayRef<AxisRefAttr> 轴引用列表

TensorShardingPerValueAttr

按操作数/操作结果对张量进行分片

语法:

#sdy.sharding_per_value<
  ::llvm::ArrayRef<TensorShardingAttr>   # shardings
>

参数:

参数 C++ 类型 说明
分片 ::llvm::ArrayRef<TensorShardingAttr>

枚举

PropagationDirection

传播方向枚举

支持请求:

符号 字符串
0
FORWARD 1 FORWARD
BACKWARD 2 BACKWARD
双方 3 双方