OpInterface の定義

CollectiveOpInterface(Sdy_CollectiveOpInterface

すべての集約オペレーションのインターフェース。outSharding 属性の一般的な get/set をカプセル化します。

制約:

  • オペランドにシャーディングがないと、allowMissingInputSharding() は true を返します。
  • out_sharding は、対応する型に対して有効です。
  • allowDifferentMeshes() が false を返す場合、オペランドと結果のシャーディングは同じメッシュにする必要があります。
  • オペランドと結果のシャーディングに同じランク。

メソッド:

getOutSharding

::mlir::sdy::TensorShardingAttr getOutSharding();

集約オペレーションの出力テンソルのシャーディングを返します。

setOutShardingAttr

void setOutShardingAttr(::mlir::sdy::TensorShardingAttr sharding);

集約オペレーションの出力テンソルのシャーディングを設定します。

getTensor

::mlir::TypedValue<::mlir::TensorType> getTensor();

集約オペレーションのテンサー オペランドを取得します。

getType

::mlir::Type getType();

集約オペレーションの結果の型を取得します。

allowDifferentMeshes

bool allowDifferentMeshes();

集約オペレーションで入力シャーディングと出力シャーディングに異なるメッシュを許可するかどうかを示します。

allowMissingInputSharding

bool allowMissingInputSharding();

集約オペレーションで入力にシャーディングを適用しない(暗黙的に完全に複製する)かどうかを示します。

ShardableDataFlowOpInterface(Sdy_ShardableDataFlowOpInterface

shardy がこのインターフェースを拡張するオペレーションのデータフロー エッジを介してシャーディングを伝播できるようにするオペレーション インターフェース。

ある演算 X のデータフロー エッジは、一連のソース(それぞれ X のオペランドまたは X のブロック終端子のオペランド)と一連のターゲット(それぞれ X の結果または X のブロック引数)間のブリッジを定義します。これにより、すべてのソースとターゲットが同じ方法でシャーディングされるようにします。1 つのオペレーションには、互いに直交する複数のデータフロー エッジを含めることができます。

オーナーは、シャーディーのプロパゲーションによって使用されるデータフロー エッジのユーザー指定のターゲットです。ユーザーが任意に選択できますが、静的である必要があります。

次に例を示します。

  y_1, ..., y_n = custom_op (x_1, ..., x_n)
                  ((body_arg_1,..., body_arg_n) {
                    ...
                    return return_value_1, ..., return_value_n
                  })

この custom_op には、データフロー エッジに 2 つのタイプがあります。return_value_i(ソース)と y_i(ターゲット)の間に n 個のエッジ、x_i(ソース)と body_arg_i(ターゲット)の間に n 個のエッジです。この場合、エッジ オーナーはターゲットと同じです。

複数のターゲットを持つオペレーションの例を次に示します。

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

この while オペレーションには n 個のデータフロー エッジがあり、i 番目のデータフロー エッジはソース x_ireturn_value_i とターゲット y_ipred_arg_ibody_arg_i の間にあります。

メソッド:

getBlockArgumentEdgeOwnerShardings

mlir::SmallVector<mlir::sdy::TensorShardingAttr> getBlockArgumentEdgeOwnerShardings();

すべてのブロック引数データフロー エッジ オーナーのシャーディングを返します。

setBlockArgumentEdgeOwnerShardings

void setBlockArgumentEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);

すべてのブロック引数のエッジ オーナーの shardings を設定します。

getOpResultEdgeOwnerShardings

mlir::SmallVector<mlir::sdy::TensorShardingAttr> getOpResultEdgeOwnerShardings();

すべてのオペレーション結果データフロー エッジ オーナーのシャーディングを返します。

setOpResultEdgeOwnerShardings

void setOpResultEdgeOwnerShardings(mlir::ArrayRef<mlir::sdy::TensorShardingAttr> shardings);

すべてのオペレーション結果エッジ オーナーの shardings を設定します。

transformTargetSharding

mlir::sdy::TensorShardingAttr transformTargetSharding(mlir::Value target, mlir::sdy::TensorShardingAttr sharding, mlir::sdy::DataFlowShardingTransformType transformType);

transformType に応じてターゲットの sharding を変換します。

詳しくは、DataFlowShardingTransformType をご覧ください。

getBlockArgumentEdgeOwners

mlir::ArrayRef<mlir::BlockArgument> getBlockArgumentEdgeOwners();

すべてのブロック引数のエッジ オーナーを取得します。

getOpResultEdgeOwners

mlir::ResultRange getOpResultEdgeOwners();

すべてのオペレーション結果エッジ オーナーを取得します。

getEdgeSources

mlir::SmallVector<mlir::OpOperand*> getEdgeSources(mlir::Value owner);

エッジ owner を指定して、データフロー エッジソースを取得します。

getEdgeOwnerFromTarget

mlir::Value getEdgeOwnerFromTarget(mlir::Value target);

所有者である可能性のある target を指定して、データフロー エッジの所有者 target を取得します。

getEdgeOwnerFromSource

mlir::Value getEdgeOwnerFromSource(mlir::OpOperand&source);

source を指定して、データフロー エッジのオーナー ターゲットを取得します。

getNonEdgeOwnerTargets

mlir::SmallVector<mlir::Value> getNonEdgeOwnerTargets(mlir::Value owner);

エッジ owner を指定して、データフロー エッジのオーナー以外のターゲットを取得します。

ShardingRuleOpInterface(Sdy_ShardingRuleOpInterface

OP が独自のシャーディング ルールを定義できる OP インターフェース。シャーディング ルールでは、オペレーションのさまざまなプロパティ(属性、オペランドの形状、結果の形状など)に応じてオペレーションをパーティショニングする方法が指定されます。詳細については、OpShardingRuleAttr をご覧ください。

メソッド:

getShardingRule

mlir::sdy::OpShardingRuleAttr getShardingRule();

OP のシャーディング ルールを返します。