ShardableDataFlowOpInterface (ShardableDataFlowOpInterface
)
一个操作接口,允许分片通过扩展此接口的操作的数据流边缘传播分片。
某个操作 X 的数据流边定义了一组来源(每个都是 X 的运算数或 X 的块终止符的运算数)和一组目标(每个都是 X 的结果或 X 的块参数)之间的桥梁,以便所有来源和目标都应以相同的方式进行分片。一个操作可以有多个彼此正交的数据流边缘。
所有者是 shardy 传播所用的数据流边的用户指定目标。用户可以任意选择,但需要是静态的。
例如:
y_1, ..., y_n = custom_op (x_1, ..., x_n)
((body_arg_1,..., body_arg_n) {
...
return return_value_1, ..., return_value_n
})
此 custom_op 的数据流边有两种类型,return_value_i
(来源)和 y_i
(目标)之间有 n 条边,x_i
(来源)和 body_arg_i
(目标)之间也有 n 条边。在这种情况下,边缘所有者与目标相同。
下面是一个具有多个目标的操作示例:
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
此 while 运算有 n 条数据流边,第 i 条数据流边位于源 x_i
、return_value_i
和目标 y_i
、pred_arg_i
、body_arg_i
之间。
方法:
getBlockArgumentEdgeOwnerShardings
mlir::ArrayRef<mlir::sdy::TensorShardingAttr> getBlockArgumentEdgeOwnerShardings();
返回所有块参数数据流边缘所有者的分片。
setBlockArgumentEdgeOwnerSharding
void setBlockArgumentEdgeOwnerSharding(unsigned index, mlir::sdy::TensorShardingAttr sharding);
使用给定的 index
设置块参数边缘所有者的 sharding
。
setBlockArgumentEdgeOwnerShardings
void setBlockArgumentEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);
设置所有块参数边缘所有者的 shardings
。
getOpResultEdgeOwnerShardings
mlir::ArrayRef<mlir::sdy::TensorShardingAttr> getOpResultEdgeOwnerShardings();
返回所有操作结果数据流边缘所有者的分片。
setOpResultEdgeOwnerSharding
void setOpResultEdgeOwnerSharding(unsigned index, mlir::sdy::TensorShardingAttr sharding);
设置具有给定 index
的操作结果边缘所有者的 sharding
。
setOpResultEdgeOwnerShardings
void setOpResultEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);
设置所有操作结果边缘所有者的 shardings
。
getBlockArgumentEdgeOwners
mlir::ArrayRef<mlir::BlockArgument> getBlockArgumentEdgeOwners();
获取所有块参数边所有者。
getOpResultEdgeOwners
mlir::ResultRange getOpResultEdgeOwners();
获取所有操作结果边缘所有者。
getEdgeSources
mlir::SmallVector<mlir::Value> getEdgeSources(mlir::Value target);
给定 target
值时,获取数据流边缘来源。
getEdgeOwnerFromTarget
mlir::Value getEdgeOwnerFromTarget(mlir::Value target);
给定可能或不可能是所有者的 target
,获取数据流边的所有者 target
。
getEdgeOwnerFromSource
mlir::Value getEdgeOwnerFromSource(mlir::OpOperand&source);
在给定 source
的情况下,获取数据流边缘的所有者目标。
getNonEdgeOwnerTargets
mlir::SmallVector<mlir::Value> getNonEdgeOwnerTargets(mlir::Value owner);
给定边 owner
,获取数据流边的非所有者目标。
ShardingRuleOpInterface (ShardingRuleOpInterface
)
一个操作接口,允许操作定义自己的分片规则。
分片规则指定了如何根据操作的各种属性(任何属性、运算元的形状、结果的形状等)对操作进行分区。如需了解详情,请参阅 OpShardingRuleAttr
。
方法:
getShardingRule
mlir::sdy::OpShardingRuleAttr getShardingRule();
返回操作的分片规则。