'sdy' 方言

Shardy (SDY) 方言定义了基于轴的张量分片表示法,以及用于将分片附加到张量的其他 API 组件。

操作

sdy.all_gather (sdy::AllGatherOp)

沿轴执行全收集通信

语法:

operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

沿 gathering_axes 中指定的轴收集张量分块。

gathering_axes 是轴列表的列表。外部列表超出了张量的维度。每个内部列表都指定了应沿哪些轴对相应维度执行单独的汇总。它将应用于操作数 (tensor) 的分片,以获取结果 (out_sharding) 的分片。

请注意,out_sharding 不会用于确定结果的分片。相反,结果的分片由运算数和 gathering_axes 的分片决定,并且 out_sharding 必须与此推断的分片一致。

示例:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>

限制

  • 必须满足 Sdy_CollectiveOpInterface 中列出的约束条件。
  • gathering_axes 中的元素必须满足 AxisRefListAttr 中列出的约束条件。
  • gathering_axes 应用于运算数分片会得到 out_sharding

特征:SameOperandsAndResultType

接口:InferTypeOpInterfaceSdy_CollectiveOpInterface

属性:

属性MLIR 类型说明
gathering_axes::mlir::sdy::ListOfAxisRefListsAttr轴引用列表列表
out_sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
tensor 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.all_reduce (sdy::AllReduceOp)

沿轴执行全局求和通信

语法:

operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

沿 reduction_axes 中指定的轴缩减张量分块。 reduction_axes 的顺序对结果而言并不重要,但可能会影响相应副本组的顺序。

限制

  • 必须满足 Sdy_CollectiveOpInterface 中列出的约束条件。
  • reduction_axes 必须满足 AxisRefListAttr 中列出的约束条件;
  • reduction_axes 不得与运算数分片轴重叠;

特征:SameOperandsAndResultType

接口:CollectiveOpInterfaceInferTypeOpInterface

属性:

属性MLIR 类型说明
reduction_axes::mlir::sdy::AxisRefListAttr轴引用列表
out_sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
tensor 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.all_slice (sdy::AllSliceOp)

沿轴执行动态切片操作

语法:

operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

沿 slicing_axes 中指定的轴切片张量分块。sdy.all_slicesdy.all_gather 之间存在代数对偶性。

slicing_axes 是轴列表的列表。外部列表超出了张量的维度。每个内部列表都指定了应沿哪些轴对相应维度执行切片。它将应用于操作数 (tensor) 的分片,以获取结果 (out_sharding) 的分片。

请注意,out_sharding 不会用于确定结果的分片。相反,结果的分片由运算数和 slicing_axes 的分片决定,并且 out_sharding 必须与此推断的分片一致。

示例:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>

限制

  • slicing_axes 中的元素必须满足 AxisRefListAttr 中列出的约束条件。
  • 必须满足 Sdy_CollectiveOpInterface 中列出的约束条件。
  • slicing_axes 应用于运算数分片会得到 out_sharding

特征:SameOperandsAndResultType

接口:CollectiveOpInterfaceInferTypeOpInterface

属性:

属性MLIR 类型说明
slicing_axes::mlir::sdy::ListOfAxisRefListsAttr轴引用列表列表
out_sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
tensor 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.all_to_all (sdy::AllToAllOp)

沿轴执行全对全通信

语法:

operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

对于参数列表中的每个 (axes, src_dim, tgt_dim) 元组,此操作会沿着 tgt_dim 维度和 axes 中指定的轴切片张量分块,沿着这些轴分散这些分块,并沿着 src_dim 维度将它们串联起来。

此操作本质上是沿 src_dimaxes 执行全收集,然后沿 tgt_dimaxes 执行全切片的组合,即将输入张量上的轴分片维度 src_dim 的后缀附加到输出张量上的轴分片维度 tgt_dim

全局分片将应用于操作数 (tensor) 的分片,以获取结果 (out_sharding) 的分片。

请注意,out_sharding 不会用于确定结果的分片。而是由操作数 src_dimtgt_dimaxes 的分片决定,并且 out_sharding 必须与此推断的分片匹配。

示例:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>

限制

  • 必须满足 Sdy_CollectiveOpInterface 中列出的约束条件。
  • 参数列表不得为空。
  • 对于 params 中的每个参数:
    • axes 中的元素必须满足 AxisRefAttr 的约束条件。
    • src_dimtgt_dim 必须是有效的维度(非负且小于张量的秩)。
    • 任何 src_dimtgt_dim 在所有参数中都必须是唯一的。
    • 所有参数的 src_dim 都必须按升序排序。
  • 将操作数分片中的 axessrc_dim 移至 tgt_dim 会导致 out_sharding

特征:SameOperandsAndResultType

接口:InferTypeOpInterfaceSdy_CollectiveOpInterface

属性:

属性MLIR 类型说明
params::mlir::sdy::AlltoAllParamListAttr全局参数列表
out_sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
tensor 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.collective_permute (sdy::CollectivePermuteOp)

执行集体排列通信以替换轴

语法:

operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

将输入张量的一个分块从一台设备发送到另一台设备,以重新排列/替换分片张量的轴。

集合排序可以转换输入分片,使每个维度都必须与之前一样进行分片,即必须沿着大小乘积与之前分片张量的轴相匹配的轴进行分片。

这对于在单个维度或不同维度中重新排列轴,以及将分片轴与复制轴进行切换非常有用。

在以下示例中,分片张量大小为 tensor<1x4x2xf32>,并由集体排序保留。

示例:

sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>

限制

  • 必须满足 Sdy_CollectiveOpInterface 中列出的约束条件。
  • 如果输入和输出分片具有不同的网格,则这些网格必须具有完全相同的轴和不同的设备 ID 顺序。
  • 对于每个维度,out_sharding 中的分片轴大小的乘积必须与相应运算数维度分片的乘积一致。

特征:SameOperandsAndResultType

接口:CollectiveOpInterfaceInferTypeOpInterface

属性:

属性MLIR 类型说明
out_sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
tensor 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.constant (sdy::ConstantOp)

常量运算

从常量 value 生成 output 张量。

请参阅:https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant

示例:

%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>

特征:AlwaysSpeculatableImplTrait

接口:ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影响:MemoryEffects::Effect{}

属性:

属性MLIR 类型说明
value::mlir::ElementsAttr常量矢量/张量属性

结果:

结果 说明
output 任意类型值的静态形状张量

sdy.data_flow_edge (sdy::DataFlowEdgeOp)

数据流边缘操作。

语法:

operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)

某个操作 X 的数据流边定义了一组来源(每个都是 X 的运算数或 X 的块终止符的运算数)和一组目标(每个都是 X 的结果或 X 的块参数)之间的桥梁,以便所有来源和目标都应以相同的方式进行分片。

一个运算可以有多个彼此正交的数据流边。

例如:

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

此 while 运算有 n 条数据流边,第 i 条数据流边位于源 x_ireturn_value_i 和目标 y_ipred_arg_ibody_arg_i 之间。

sdy.data_flow_edge 将边的所有者(可以是任何目标,但最好是运算结果,而不是块参数)作为输入,该所有者不应有任何其他用途。此运算不是纯运算,因为它可以接受最初没有任何用途的输入。

sdy.data_flow_edge 还包含边缘的所有目标的可选分片,并且在传播期间,应更新该分片,而不是目标的分片(如果可以附加)。当运算具有许多边时,这非常有用,因为这样可以更高效地执行以下操作:

  • 分别通过每个边传播。
  • 分别更新每个边的分片,而不是一次更新所有目标(例如,某个操作只有一个不可变的 TensorShardingPerValueAttr 用于结果分片)。
  • 当来源的分片发生变化时,将每个边单独添加到工作列表中。

传播会在 sdy.data_flow_edge 的所有来源和目标之间传播分片,就像它是将来源作为运算数、目标作为结果且具有标识 sdy.op_sharding_rule 的常规运算一样。也就是说,正向传播是从来源到目标,而反向传播是从目标到来源。

我们不允许 sdy.data_flow_edge 的输入由 SdyDialect 运算定义,因此我们可以假定它由已取消注册 sdy.sharding 属性的运算定义。

特征:SameOperandsAndResultType

接口:InferTypeOpInterface

属性:

属性MLIR 类型说明
sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
input 形状为任何类型值

结果:

结果 说明
result 形状为任何类型值

sdy.manual_computation (sdy::ManualComputationOp)

使用手动集合进行多设备并行操作

语法:

operation ::= `sdy.manual_computation` `(`operands`)`
              `in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
              `out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
              `manual_axes````=```$manual_axes
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:`
              functional-type(operands, results)

跳转到使用显式集合按设备本地代码编写的区域,其中逻辑形状与设备本地物理缓冲区形状相匹配,并且集合与物理跨设备通信完全对应。

相对于 manual_axes,body 是局部的。传播将通过身体上的任何自由轴(不在 manual_axes 列表中)进行。

限制

  • in_shardingsout_shardings 中的元素必须满足 TensorShardingAttr 中列出的约束条件。
  • 运算区域的全局和本地张量输入/输出的数量必须一致。
  • 手动轴必须位于每个维度分片中的所有自由轴之前。
  • 手动轴无法添加内边距。也就是说,尺寸必须能被相应的手动轴尺寸整除。
  • 操作区域参数/结果的全局和局部形状必须匹配。
  • 不会拆分任何手动轴。

trait:IsolatedFromAboveRecursiveMemoryEffectsSingleBlockImplicitTerminator<ReturnOp>SingleBlock

接口:ShardableDataFlowOpInterface

属性:

属性MLIR 类型说明
in_shardings::mlir::sdy::TensorShardingPerValueAttr按操作数/操作结果对张量进行分片
out_shardings::mlir::sdy::TensorShardingPerValueAttr按操作数/操作结果对张量进行分片
manual_axes::mlir::sdy::ManualAxesAttrManualComputationOp 为手动计算的轴的列表

运算数:

Operand 说明
tensors 任意类型值的排序张量的变参

结果:

结果 说明
results 任意类型值的排序张量的变参

sdy.mesh (sdy::MeshOp)

命名网格

语法:

operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict

定义一个新的命名网格。模块中的所有网格都必须包含相同数量的设备(仅限具有单个 device_id 的网格)。网格是模块的 SymbolTable 中显示的 Symbol 操作,可由其 name 引用。

特征:HasParent<ModuleOp>

接口:Symbol

属性:

属性MLIR 类型说明
sym_name::mlir::StringAttr字符串属性
mesh::mlir::sdy::MeshAttr轴网格和设备列表

sdy.named_computation (sdy::NamedComputationOp)

命名计算操作

语法:

operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
              (`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
              (`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:` functional-type($operands, results)

对计算(即一组操作)进行分组并为其命名。传播将流入/流出该区域,就像所有内容都内嵌一样。

这可用于处理通过调用指令传播到其他函数的情况。Shardy 的任何用户都应编写一个导入/导出传递,以将其调用操作转换为 sdy.named_computation 操作,将被调函数的正文复制/复制到 named_computation 的正文中。

该区域中每个块参数和返回值的类型必须与运算元的类型和运算结果类型相同。

示例:

%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

trait:IsolatedFromAboveRecursiveMemoryEffectsRecursivelySpeculatableImplTraitSingleBlockImplicitTerminator<ReturnOp>SingleBlock

接口:ConditionallySpeculatableInferTypeOpInterfaceShardableDataFlowOpInterface

属性:

属性MLIR 类型说明
name::mlir::StringAttr字符串属性
in_shardings::mlir::sdy::TensorShardingPerValueAttr按操作数/操作结果对张量进行分片
out_shardings::mlir::sdy::TensorShardingPerValueAttr按操作数/操作结果对张量进行分片

运算数:

Operand 说明
operands 任意类型的变参

结果:

结果 说明
«unnamed» 任意类型的变参

sdy.propagation_barrier (sdy::PropagationBarrierOp)

传播屏障操作

语法:

operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)

此运算的运作方式与恒等运算类似,输出与作为输入的值相同的值。但在传播方面,这只会允许传播沿着特定方向流经它。

这可防止在使用屏障运算的结果及其运算子之间传播分片。

  • FORWARD 表示分片只能从运算对象流向结果。
  • BACKWARD 表示分片只能从结果流向运算对象。
  • NONE 表示无法通过此操作传播任何分片。
  • 无法指定 BOTH,因为此操作会重复。

trait:AlwaysSpeculatableImplTraitSameOperandsAndResultType

接口:ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影响:MemoryEffects::Effect{}

属性:

属性MLIR 类型说明
allowed_direction::mlir::sdy::PropagationDirectionAttr传播方向枚举

运算数:

Operand 说明
input 任意类型值的排名张量

结果:

结果 说明
result 任意类型值的排名张量

sdy.reshard (sdy::ReshardOp)

将张量重新分片到其他分片

语法:

operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)

使用指定的分片(不同于输入张量的现有分片)重新分片输入张量。

ShardingConstraintOp 和 ReshardOp 都会将分片附加到张量。其生命周期如下:

  1. 在分片传播之前,用户会添加 ShardingConstraintOp。
  2. 分片传播会使用 ShardingConstraintOp。分片传播结果中没有 ShardingConstraintOp。不过,如果需要,可以添加 ReshardOp。
  3. 分区器会将 ReshardOp 转换为集合运算(或身份运算)。分区器的结果中不应包含 ReshardOp。

// TODO(b/331680067). 添加了规范化模式以移除多余的 // reshard 操作。

trait:AlwaysSpeculatableImplTraitSameOperandsAndResultType

接口:ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影响:MemoryEffects::Effect{}

属性:

属性MLIR 类型说明
sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
input 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.return (sdy::ReturnOp)

sdy.return 操作会终止附加到 sdy 基于区域的操作和任何其他基于 Shardy 区域的操作的区域。它是可变参数的:它接受值列表作为参数,值的类型可以是任意类型(但属于同一类,例如 AnyTensor),因此可以在 Shardy IR 堆栈的各个级别重复使用。

语法:

operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?

trait:AlwaysSpeculatableImplTraitTerminator

接口:ConditionallySpeculatableNoMemoryEffect (MemoryEffectOpInterface)

影响:MemoryEffects::Effect{}

运算数:

Operand 说明
results 任意类型的变参

sdy.sharding_constraint (sdy::ShardingConstraintOp)

将张量限制为指定的分片

语法:

operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)

将分片附加到中间张量(例如矩阵乘法结果),以指明该张量或其用途的一部分应采用这种方式进行分片。

如果分片具有开放维度和未约束的轴,则表示张量可以沿开放维度进一步分片。

此操作可以:

  • 没有用途(悬空)- 这意味着附加的分片方式应为输入张量本身的分片方式。
  • 有用途 - 这意味着附加的分片方式是分片约束操作的用途应采用的分片方式,而输入张量的其他用途可能采用不同的分片方式(如果输入张量没有其他用途,则行为与没有用途的情况相同)。

特征:SameOperandsAndResultType

接口:InferTypeOpInterface

属性:

属性MLIR 类型说明
sharding::mlir::sdy::TensorShardingAttr张量分片

运算数:

Operand 说明
input 任意类型值的张量

结果:

结果 说明
result 任意类型值的张量

sdy.sharding_group (sdy::ShardingGroupOp)

限制组中的张量具有相同的分片。

语法:

operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)

此运算提供了一个接口,用于将张量分配给分片组(系统会强制要求这些张量组具有相同的分片)。在传播过程中,只要一个分组元素被分片,所有其他成员都会以完全相同的方式进行分片。此操作会接受实参组 ID,但不会返回任何结果,而是会修改内部分片组表示法,以将输入张量添加到具有给定 ID 的组。

接口:InferTypeOpInterface

属性:

属性MLIR 类型说明
group_id::mlir::IntegerAttr64 位无符号整数属性

运算数:

Operand 说明
input 任意类型值的排名张量

属性

AllToAllParamAttr

全局参数

语法:

#sdy.all_to_all_param<
  ::llvm::ArrayRef<AxisRefAttr>,   # axes
  int64_t,   # src_dim
  int64_t   # tgt_dim
>

一个元组,包含要执行全局连接的轴和源/目标维度。

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<AxisRefAttr> 要执行全对全操作的轴
src_dim int64_t 来源维度索引
tgt_dim int64_t 目标维度索引

AlltoAllParamListAttr

所有对所有参数的列表

语法:

#sdy.all_to_all_param_list<
  ::llvm::ArrayRef<AllToAllParamAttr>   # value
>

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<AllToAllParamAttr>

AxisRefAttr

对完整轴或分屏子轴的引用

语法:

#sdy.axis_ref<
  ::llvm::StringRef,   # name
  SubAxisInfoAttr   # sub_axis_info
>

限制

  • name 必须存在于绑定的 MeshAttr 中。
  • 如果存在 sub_axis_info,则它必须满足 SubAxisInfoAttr 的约束条件。

参数:

参数 C++ 类型 说明
name ::llvm::StringRef 此轴的名称
sub_axis_info SubAxisInfoAttr 如果这是子轴,则提供其他信息

AxisRefListAttr

轴引用列表

语法:

#sdy.axis_ref_list<
  ::llvm::ArrayRef<AxisRefAttr>   # value
>

限制

  • value 中的元素必须满足 AxisRefAttr 的约束条件。
  • 没有重复的轴引用或彼此重叠的子轴。
  • 任何两个相邻的轴引用都不能是同一全轴的连续子轴,即它们可以合并为一个子轴或全轴。

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<AxisRefAttr>

DimMappingAttr

维度的因子指数列表

空列表表示这是 null 映射(使用 * 解析/输出),即维度未映射到任何因子。

限制

  • 至少有一个因子索引。
  • 因子索引必须在 [0, $factor_sizes) 范围内。
  • 如果有多个因素,则其中任何一个因素的大小都不能为 1。
  • 没有重复的因子索引。

参数:

参数 C++ 类型 说明
factor_indices ::llvm::ArrayRef<int64_t> 此维度映射到的因素

DimensionShardingAttr

维度分片

用于按主轴到次轴对张量维度进行分片的轴名称列表、一个布尔值(表示该维度是否可以进一步分片),以及一个可选整数(表示此维度分片的优先级,分片传播期间会遵循此优先级)。优先级来自用户分片注解,值越低,优先级越高。如果注解中缺少优先级,则系统会假定优先级为最高。

限制

  • axes 中的元素必须满足 AxisRefListAttr 中列出的约束条件。
  • 如果维度分片具有优先级:
    • 优先级大于或等于 0。
    • 如果维度已关闭,则至少有一个轴。

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<AxisRefAttr> 轴引用
is_closed bool 此维度是否无法进一步分片
优先级 std::optional<int64_t> 基于用户优先级传播期间使用的优先级

ListOfAxisRefListsAttr

轴引用列表

语法:

#sdy.list_of_axis_ref_lists<
  ::llvm::ArrayRef<AxisRefListAttr>   # value
>

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<AxisRefListAttr>

ManualAxesAttr

ManualComputationOp 为手动计算的轴的列表

语法:

#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr>   # value
>

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<StringAttr>

MeshAttr

轴网格和设备列表

语法:

#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>,   # axes
  ::llvm::ArrayRef<int64_t>   # device_ids
>

网格是轴列表和可选的设备 ID 列表,用于指定设备排序。

如果轴列表为空,则网格具有大小为 1 的隐式未命名轴。在这种情况下,如果未提供设备 ID 列表,则隐式设备 ID 列表为 [0];如果提供了设备 ID 列表,则其中必须包含一个值为任何非负值的整数。我们将这种情况称为“最大分片”情况。

对于所有非最大分片情况,如果指定了设备 ID 列表,则轴大小的乘积应与设备数量一致。如果未指定设备 ID 列表,则隐式设备 ID 列表为 iota(product(axes))。为简单起见,我们还禁止指定与 iota(product(axes)) 相同的设备 ID 列表;在这种情况下,不应指定设备 ID 列表。

以下是一些网格的示例:

  • 空网格表示可在传播期间替换的占位符网格:<[]>
  • 网格具有未命名的轴和显式设备 ID,通常用于表示最大分片:<[], device_ids=[3]>
  • 具有两个轴和隐式设备 ID 的网格 iota(6):<["a"=2, "b"=3]>
  • 具有两个轴和显式设备 ID 来指定设备排序的网格:<["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>

限制

  • axes 中的元素名称不得重复。
  • 如果指定了 device_ids
    • 轴大小的乘积必须与设备数量一致。
    • 其所有元素都必须为非负数。
    • device_ids 不应等于 iota(product(axis_sizes))
    • 已排序的 device_ids 必须为 iota(product(axis_sizes))

参数:

参数 C++ 类型 说明
::llvm::ArrayRef<MeshAxisAttr> 网格轴
device_ids ::llvm::ArrayRef<int64_t> 显式设备排序或设备 ID 上限

MeshAxisAttr

网格中的命名轴

语法:

#sdy.mesh_axis<
  ::llvm::StringRef,   # name
  int64_t   # size
>

参数:

参数 C++ 类型 说明
name ::llvm::StringRef name
size int64_t 此轴的大小

OpShardingRuleAttr

指定操作的分区方式。

语法:

#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>,   # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>,   # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>,   # result_mappings
  ::llvm::ArrayRef<int64_t>,   # reduction_factors
  ::llvm::ArrayRef<int64_t>,   # need_replication_factors
  ::llvm::ArrayRef<int64_t>,   # permutation_factors
  ::llvm::ArrayRef<int64_t>,   # blocked_propagation_factors
  bool   # is_custom_rule
>

分片规则指定了如何根据操作的各种属性(任何属性、运算元的形状、结果的形状等)对操作进行分区。例如:

%0 = stablehlo.add %arg0, %arg1 {
    sdy.sharding_rule = #sdy.op_sharding_rule<
        ([i, j],[i, j])->([i, j])
        {i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>

请注意,我们允许大小为 1 的因子,即使它们无法分片也是如此,这主要是为了完整起见,因为许多运算(例如点运算)都具有大小为 1 的维度,这些维度在运算数和结果之间对应。

因子类型

  • reduction_factors 包含需要缩减的因子的索引,例如点运算中的收缩维度。
  • need_replication_factors 包含需要完全复制的因素(例如排序操作中的已排序维度)的索引。
  • permutation_factors 包含需要集体排序的因子的索引(如果这些因子已分片),例如填充操作中的填充维度。
  • 所有其他因素都被视为透传因子,即如果在映射到它们的所有张量中以相同方式进行分片,则不需要进行任何通信。

blocked_propagation_factors 包含不允许传播分片的因素。它与因子类型正交。也就是说,被阻塞的传播因子可以是任何因子类型。

is_custom_rule 用于描述这是否为用户定义的规则。用户可以为自定义调用定义分片规则,也可以覆盖标准操作的预定义分片规则。自定义规则始终保留/永不移除。

限制

  • 运算数/结果映射的数量必须与运算的运算数/结果数量一致。
  • 至少有一个映射(不能为没有运算对象/运算结果的运算提供规则)。
  • 每个 TensorMappingAttr 的秩与相应张量类型的秩相匹配。
  • 对于每组因子(reduction_factorsneed_replication_factorspermutation_factors):
    • 元素必须在 [0, $factor_sizes] 范围内。
    • 每个组内和各组之间均无重复的因子索引。

参数:

参数 C++ 类型 说明
factor_sizes ::llvm::ArrayRef<int64_t> 此规则中所有因子的大小
operand_mappings ::llvm::ArrayRef<TensorMappingAttr> 运算数映射
result_mappings ::llvm::ArrayRef<TensorMappingAttr> 结果映射
reduction_factors ::llvm::ArrayRef<int64_t> 需要减少的因素
need_replication_factors ::llvm::ArrayRef<int64_t> 需要完全复制的因素
permutation_factors ::llvm::ArrayRef<int64_t> 需要使用 collective-permute 的因子
blocked_propagation_factors ::llvm::ArrayRef<int64_t> 不会传播分片的因素
is_custom_rule bool 规则是否适用于 stablehlo.custom_call

SubAxisInfoAttr

有关此子轴如何从完整轴派生的相关信息

语法:

#sdy.sub_axis_info<
  int64_t,   # pre_size
  int64_t   # size
>

将完整轴拆分为 n 个子轴时,该轴会重塑为 [k_1,...,k_n],并且第 i 个子轴可以表示为其左侧所有轴大小 m=prod(k_1,...,k_(i-1))(也称为“前大小”)和大小 k_i 的乘积。因此,sub-axis-info 属性会存储这两个数字,并表示为:(m)k(预设大小为 m,大小为 k)。

限制

  • pre-size 至少为 1。
  • size 大于 1。
  • pre-size 必须对全轴的大小进行除法,即 pre-sizesize 都对全轴的大小进行除法,并且子轴不超出全轴。
  • 子轴的大小不等于相应全轴的大小,在这种情况下,应改用全轴。

参数:

参数 C++ 类型 说明
pre_size int64_t 此子轴左侧子轴大小的乘积
size int64_t 此子轴的大小

TensorMappingAttr

张量每个维度的因子映射。

语法:

#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr>   # dim_mappings
>

限制

  • dim_mappings 中的元素必须满足 DimMappingAttr 中的约束条件。
  • 各个维度之间没有重复的因子索引。

参数:

参数 C++ 类型 说明
dim_mappings ::llvm::ArrayRef<DimMappingAttr> 维度映射

TensorShardingAttr

张量分片

语法:

#sdy.sharding<
  ::mlir::Attribute,   # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>,   # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr>   # replicated_axes
>

张量分片绑定到特定网格,并且只能引用该网格中的轴名称。维度分片可告诉我们张量每个维度沿哪些轴(或子轴)进行分片,从主轴到次轴。所有其他不对维度进行分片的轴都会隐式或显式(如果它们显示在复制轴列表中)复制。

此分片绑定的网格可以通过符号名称(引用相应的 MeshOp 符号)或内嵌的 MeshAttr 指定。

限制

  • dim_shardings 中的元素必须满足 DimensionShardingAttr 中列出的约束条件。
  • replicated_axes 中的元素必须满足 AxisRefListAttr 中列出的约束条件。
  • 如果相应的张量类型不是 ShapedType,则分片必须具有秩 0 且没有复制轴。
  • 张量应具有 rank。
  • 维度分片数量等于张量的秩。
  • 大小为 0 的维度不会分片。
  • replicated_axes 中的项相对于 mesh_or_ref 是有序的(请参阅 AxisRefAttr::getMeshComparator)。

参数:

参数 C++ 类型 说明
mesh_or_ref ::mlir::Attribute 网格属性或平面网格符号引用属性
dim_shardings ::llvm::ArrayRef<DimensionShardingAttr> 维度分片
replicated_axes ::llvm::ArrayRef<AxisRefAttr> 轴引用

TensorShardingPerValueAttr

按运算的运算数/结果对张量进行分片

语法:

#sdy.sharding_per_value<
  ::llvm::ArrayRef<TensorShardingAttr>   # shardings
>

TensorShardingAttr 列表,每个操作数/结果对应一个 TensorShardingAttr

限制

  • shardings 中的元素必须满足 TensorShardingAttr 的约束条件。

参数:

参数 C++ 类型 说明
分片 ::llvm::ArrayRef<TensorShardingAttr> 按值分片

枚举

PropagationDirection

传播方向枚举

支持请求:

符号 字符串
0
前进 1 前进
向后 2 向后
双方 3 双方