OpInterface 定义

ShardableDataFlowOpInterface (ShardableDataFlowOpInterface)

一个操作接口,允许分片通过扩展此接口的操作的数据流边缘传播分片。

某个操作 X 的数据流边定义了一组来源(每个都是 X 的运算数或 X 的块终止符的运算数)和一组目标(每个都是 X 的结果或 X 的块参数)之间的桥梁,以便所有来源和目标都应以相同的方式进行分片。一个操作可以有多个彼此正交的数据流边缘。

所有者是 shardy 传播所用的数据流边的用户指定目标。用户可以任意选择,但需要是静态的。

例如:

  y_1, ..., y_n = custom_op (x_1, ..., x_n)
                  ((body_arg_1,..., body_arg_n) {
                    ...
                    return return_value_1, ..., return_value_n
                  })

此 custom_op 的数据流边有两种类型,return_value_i(来源)和 y_i(目标)之间有 n 条边,x_i(来源)和 body_arg_i(目标)之间也有 n 条边。在这种情况下,边缘所有者与目标相同。

下面是一个具有多个目标的操作示例:

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

此 while 运算有 n 条数据流边,第 i 条数据流边位于源 x_ireturn_value_i 和目标 y_ipred_arg_ibody_arg_i 之间。

方法:

getBlockArgumentEdgeOwnerShardings

mlir::ArrayRef<mlir::sdy::TensorShardingAttr> getBlockArgumentEdgeOwnerShardings();

返回所有块参数数据流边缘所有者的分片。

setBlockArgumentEdgeOwnerSharding

void setBlockArgumentEdgeOwnerSharding(unsigned index, mlir::sdy::TensorShardingAttr sharding);

使用给定的 index 设置块参数边缘所有者的 sharding

setBlockArgumentEdgeOwnerShardings

void setBlockArgumentEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);

设置所有块参数边缘所有者的 shardings

getOpResultEdgeOwnerShardings

mlir::ArrayRef<mlir::sdy::TensorShardingAttr> getOpResultEdgeOwnerShardings();

返回所有操作结果数据流边缘所有者的分片。

setOpResultEdgeOwnerSharding

void setOpResultEdgeOwnerSharding(unsigned index, mlir::sdy::TensorShardingAttr sharding);

设置具有给定 index 的操作结果边缘所有者的 sharding

setOpResultEdgeOwnerShardings

void setOpResultEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);

设置所有操作结果边缘所有者的 shardings

getBlockArgumentEdgeOwners

mlir::ArrayRef<mlir::BlockArgument> getBlockArgumentEdgeOwners();

获取所有块参数边所有者。

getOpResultEdgeOwners

mlir::ResultRange getOpResultEdgeOwners();

获取所有操作结果边缘所有者。

getEdgeSources

mlir::SmallVector<mlir::Value> getEdgeSources(mlir::Value target);

给定 target 值时,获取数据流边缘来源。

getEdgeOwnerFromTarget

mlir::Value getEdgeOwnerFromTarget(mlir::Value target);

给定可能或不可能是所有者的 target,获取数据流边的所有者 target

getEdgeOwnerFromSource

mlir::Value getEdgeOwnerFromSource(mlir::OpOperand&source);

在给定 source 的情况下,获取数据流边缘的所有者目标。

getNonEdgeOwnerTargets

mlir::SmallVector<mlir::Value> getNonEdgeOwnerTargets(mlir::Value owner);

给定边 owner,获取数据流边的非所有者目标。

ShardingRuleOpInterface (ShardingRuleOpInterface)

一个操作接口,允许操作定义自己的分片规则。 分片规则指定了如何根据操作的各种属性(任何属性、运算元的形状、结果的形状等)对操作进行分区。如需了解详情,请参阅 OpShardingRuleAttr

方法:

getShardingRule

mlir::sdy::OpShardingRuleAttr getShardingRule();

返回操作的分片规则。