OpInterface 定义

CollectiveOpInterface (Sdy_CollectiveOpInterface)

所有集合运算的接口。封装了 outSharding 属性的常见 get/set。

限制

  • 运算数必须具有分片,否则 allowMissingInputSharding() 会返回 true。
  • out_sharding 相对于相应类型是有效的。
  • 如果 allowDifferentMeshes() 返回 false,则操作数和结果分片必须具有相同的网格。
  • 运算数和结果分片具有相同的排名。

方法:

getOutSharding

::mlir::sdy::TensorShardingAttr getOutSharding();

返回集合运算的输出张量分片。

setOutShardingAttr

void setOutShardingAttr(::mlir::sdy::TensorShardingAttr sharding);

设置集合运算的输出张量分片。

getTensor

::mlir::TypedValue<::mlir::TensorType> getTensor();

获取集合运算的张量运算数。

getType

::mlir::Type getType();

获取集合运算结果的类型。

allowDifferentMeshes

bool allowDifferentMeshes();

指示集合运算是否允许输入和输出分片具有不同的网格。

allowMissingInputSharding

bool allowMissingInputSharding();

指示集合运算是否允许输入没有分片,即隐式完全复制。

ShardableDataFlowOpInterface (Sdy_ShardableDataFlowOpInterface)

一种操作接口,允许 Shardy 通过扩展此接口的操作的数据流边缘传播分片。

某个操作 X 的数据流边定义了一组来源(每个都是 X 的运算数或 X 的块终止符的运算数)和一组目标(每个都是 X 的结果或 X 的块参数)之间的桥梁,以便所有来源和目标都应以相同的方式进行分片。一个运算可以有多个彼此正交的数据流边。

所有者是 shardy 传播所用的数据流边的用户指定目标。用户可以任意选择,但该值需要是静态的。

例如:

  y_1, ..., y_n = custom_op (x_1, ..., x_n)
                  ((body_arg_1,..., body_arg_n) {
                    ...
                    return return_value_1, ..., return_value_n
                  })

此 custom_op 的数据流边有两种类型,return_value_i(来源)和 y_i(目标)之间有 n 条边,x_i(来源)和 body_arg_i(目标)之间也有 n 条边。在这种情况下,边缘所有者与目标相同。

下面是一个具有多个目标的操作示例:

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

此 while 运算有 n 条数据流边,第 i 条数据流边位于源 x_ireturn_value_i 和目标 y_ipred_arg_ibody_arg_i 之间。

方法:

getBlockArgumentEdgeOwnerShardings

mlir::SmallVector<mlir::sdy::TensorShardingAttr> getBlockArgumentEdgeOwnerShardings();

返回所有块参数数据流边缘所有者的分片。

setBlockArgumentEdgeOwnerShardings

void setBlockArgumentEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);

设置所有块参数边缘所有者的 shardings

getOpResultEdgeOwnerShardings

mlir::SmallVector<mlir::sdy::TensorShardingAttr> getOpResultEdgeOwnerShardings();

返回所有操作结果数据流边缘所有者的分片。

setOpResultEdgeOwnerShardings

void setOpResultEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);

设置所有操作结果边缘所有者的 shardings

transformTargetSharding

mlir::sdy::TensorShardingAttr transformTargetSharding(mlir::Value target, mlir::sdy::TensorShardingAttr sharding, mlir::sdy::DataFlowShardingTransformType transformType);

根据 transformType 转换目标的 sharding

如需了解详情,请参阅 DataFlowShardingTransformType

getBlockArgumentEdgeOwners

mlir::ArrayRef<mlir::BlockArgument> getBlockArgumentEdgeOwners();

获取所有块参数边所有者。

getOpResultEdgeOwners

mlir::ResultRange getOpResultEdgeOwners();

获取所有操作结果边缘所有者。

getEdgeSources

mlir::SmallVector<mlir::OpOperand*> getEdgeSources(mlir::Value owner);

给定边 owner,获取数据流边源。

getEdgeOwnerFromTarget

mlir::Value getEdgeOwnerFromTarget(mlir::Value target);

给定可能或不可能是所有者的 target,获取数据流边的所有者 target

getEdgeOwnerFromSource

mlir::Value getEdgeOwnerFromSource(mlir::OpOperand&source);

给定 source 时,获取数据流边的所有者目标。

getNonEdgeOwnerTargets

mlir::SmallVector<mlir::Value> getNonEdgeOwnerTargets(mlir::Value owner);

给定边 owner,获取数据流边的非所有者目标。

ShardingRuleOpInterface (Sdy_ShardingRuleOpInterface)

操作接口,允许操作定义自己的分片规则。分片规则指定了如何根据操作的各种属性(任何属性、运算元的形状、结果的形状等)对操作进行分区。如需了解详情,请参阅 OpShardingRuleAttr

方法:

getShardingRule

mlir::sdy::OpShardingRuleAttr getShardingRule();

返回操作的分片规则。