Shardy(SDY)方言は、軸ベースのテンソル シャーディング表現と、シャーディングをテンソルに接続するための追加の API コンポーネントを定義します。
運用
sdy.all_gather(sdy::AllGatherOp)
軸に沿ってオールギャザー通信を実行する
構文:
operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
gathering_axes で指定された軸に沿ってテンソルのチャンクを収集します。
gathering_axes は、軸のリストです。外側のリストはテンソルのディメンションを超えています。各内部リストには、各ディメンションで個別の集計を実行する軸を指定します。これは、オペランド(tensor)のシャーディングに適用され、結果(out_sharding)のシャーディングが取得されます。
out_sharding は、結果のシャーディングを決定するために使用されません。代わりに、結果のシャーディングはオペランドと gathering_axes のシャーディングによって決定され、out_sharding はこの推定シャーディングと一致する必要があります。
例:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>
制約:
Sdy_CollectiveOpInterfaceに記載されている制約を満たす必要があります。gathering_axesの要素は、AxisRefListAttrに記載されている制約を満たす必要があります。- オペランド シャーディングに
gathering_axesを適用すると、out_shardingが得られます。
特性: SameOperandsAndResultType
インターフェース: InferTypeOpInterface、Sdy_CollectiveOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
gathering_axes | ::mlir::sdy::ListOfAxisRefListsAttr | 軸参照リストのリスト |
out_sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
tensor |
任意の型の値のテンサー |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値のテンサー |
sdy.all_reduce(sdy::AllReduceOp)
軸に沿ってオール リデュース通信を実行する
構文:
operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
reduction_axes で指定された軸に沿ってテンソルのチャンクを減算します。reduction_axes の順序は結果に影響しませんが、対応するレプリカ グループの順序に影響する可能性があります。
制約:
Sdy_CollectiveOpInterfaceに記載されている制約を満たす必要があります。reduction_axesはAxisRefListAttrに記載されている制約を満たしている必要があります。reduction_axesはオペランドのシャーディング軸と重複してはいけません。
特性: SameOperandsAndResultType
インターフェース: CollectiveOpInterface、InferTypeOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
reduction_axes | ::mlir::sdy::AxisRefListAttr | 軸参照のリスト |
out_sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
tensor |
任意の型の値のテンサー |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値のテンサー |
sdy.all_slice(sdy::AllSliceOp)
軸に沿って動的スライス オペレーションを実行する
構文:
operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
slicing_axes で指定された軸に沿ってテンソルのチャンクをスライスします。sdy.all_slice と sdy.all_gather の間に代数的双対性があります。
slicing_axes は、軸のリストです。外側のリストはテンソルのディメンションを超えています。各内部リストには、各ディメンションでスライスを実行する軸を指定します。これはオペランド(tensor)のシャーディングに適用され、結果(out_sharding)のシャーディングが取得されます。
out_sharding は、結果のシャーディングを決定するために使用されません。代わりに、結果のシャーディングはオペランドと slicing_axes のシャーディングによって決定され、out_sharding はこの推定シャーディングと一致する必要があります。
例:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>
制約:
slicing_axesの要素は、AxisRefListAttrに記載されている制約を満たす必要があります。Sdy_CollectiveOpInterfaceに記載されている制約を満たす必要があります。- オペランド シャーディングに
slicing_axesを適用すると、out_shardingが得られます。
特性: SameOperandsAndResultType
インターフェース: CollectiveOpInterface、InferTypeOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
slicing_axes | ::mlir::sdy::ListOfAxisRefListsAttr | 軸参照リストのリスト |
out_sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
tensor |
任意の型の値のテンサー |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値のテンサー |
sdy.all_to_all(sdy::AllToAllOp)
軸に沿ってオールツーオール通信を実行する
構文:
operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
このオペレーションは、パラメータリスト内の(axes、src_dim、tgt_dim)タプルごとに、テンソルのチャンクをディメンション tgt_dim と axes で指定された軸に沿ってスライスし、それらのチャンクを軸に沿って分散し、ディメンション src_dim に沿って連結します。
このオペレーションは、基本的に src_dim と axes に沿ったオールギャザーと、tgt_dim と axes に沿ったオールスライスの組み合わせです。つまり、入力テンソルの軸シャーディング ディメンション src_dim の接尾辞が、出力テンソルの軸シャーディング ディメンション tgt_dim に追加されます。
オールツーオール シャーディングはオペランド(tensor)のシャーディングに適用され、結果のシャーディング(out_sharding)が取得されます。
out_sharding は、結果のシャーディングを決定するために使用されません。代わりに、結果のシャーディングはオペランドのシャーディング(src_dim、tgt_dim、axes)によって決定され、out_sharding はこの推定シャーディングと一致する必要があります。
例:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>
制約:
Sdy_CollectiveOpInterfaceに記載されている制約を満たす必要があります。- パラメータ リストを空にすることはできません。
paramsのパラメータごとに、次の操作を行います。axesの要素はAxisRefAttrの制約を満たす必要があります。src_dimとtgt_dimは有効なディメンション(正でテンソルのランクより小さい)にする必要があります。src_dimまたはtgt_dimは、すべてのパラメータで一意である必要があります。src_dimは、すべてのパラメータで昇順に並べ替える必要があります。
- オペランド シャーディングで
axesをsrc_dimからtgt_dimに移動すると、out_shardingが返されます。
特性: SameOperandsAndResultType
インターフェース: InferTypeOpInterface、Sdy_CollectiveOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
params | ::mlir::sdy::AlltoAllParamListAttr | すべてのパラメータのリスト |
out_sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
tensor |
任意の型の値のテンサー |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値のテンサー |
sdy.collective_permute(sdy::CollectivePermuteOp)
軸を置き換えるために集約型の並べ替え通信を実行する
構文:
operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
入力テンソルのチャンクを各デバイスから別のデバイスに送信し、テンソルをシャーディングする軸を並べ替えまたは置き換えます。
集約シャッフルでは、各ディメンションが以前と同じようにシャーディングされるように入力シャーディングを変換できます。つまり、サイズの積が以前にテンソルをシャーディングした軸の積と一致する軸に沿ってシャーディングする必要があります。
これは、単一のディメンションまたは異なるディメンション間で軸の順序を変更したり、シャーディングされた軸を複製された軸と入れ替えたりする場合に便利です。
次の例では、シャーディングされたテンソルのサイズは tensor<1x4x2xf32> で、これは集約並べ替えによって保持されます。
例:
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>
制約:
Sdy_CollectiveOpInterfaceに記載されている制約を満たす必要があります。- 入力シャーディングと出力シャーディングのメッシュが異なる場合は、これらのメッシュの軸が完全に同じで、デバイス ID の順序が異なる必要があります。
- 各ディメンションで、
out_shardingのシャーディング軸サイズの積は、対応するオペランド ディメンション シャーディングと一致している必要があります。
特性: SameOperandsAndResultType
インターフェース: CollectiveOpInterface、InferTypeOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
out_sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
tensor |
任意の型の値のテンサー |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値のテンサー |
sdy.constant(sdy::ConstantOp)
定数オペレーション
定数 value から output テンソルを生成します。
参照: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
例:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
特性: AlwaysSpeculatableImplTrait
インターフェース: ConditionallySpeculatable、InferTypeOpInterface、NoMemoryEffect (MemoryEffectOpInterface)
影響: MemoryEffects::Effect{}
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
value | ::mlir::ElementsAttr | 定数ベクトル/テンソル属性 |
結果:
| 結果 | 説明 |
|---|---|
output |
任意の型の値の静的形状テンソル |
sdy.data_flow_edge(sdy::DataFlowEdgeOp)
データフロー エッジ演算。
構文:
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
ある演算 X のデータフロー エッジは、一連のソース(それぞれ X のオペランドまたは X のブロック終端子のオペランド)と一連のターゲット(それぞれ X の結果または X のブロック引数)間のブリッジを定義します。これにより、すべてのソースとターゲットが同じ方法でシャーディングされるようになります。
1 つのオペレーションには、互いに直交する複数のデータフロー エッジを含めることができます。
次に例を示します。
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
この while オペレーションには n 個のデータフロー エッジがあり、i 番目のデータフロー エッジはソース x_i、return_value_i とターゲット y_i、pred_arg_i、body_arg_i の間にあります。
sdy.data_flow_edge は、エッジのオーナー(任意のターゲットですが、ブロック引数ではなく op 結果が望ましい)を入力として受け取ります。この値は他の用途に使用しないでください。このオペレーションは、元々使用されていない入力を受け取ることができるため、純粋ではありません。
sdy.data_flow_edge には、エッジのすべてのターゲットに対してオプションのシャーディングも保持されます。このシャーディングは、伝播中にターゲットのシャーディング(接続可能であれば)ではなく、更新する必要があります。これは、オペレーションにエッジが多数ある場合に便利です。次のように、はるかに効率的です。
- 各エッジを個別に伝播します。
- すべてのターゲットを一度に更新するのではなく、各エッジのシャーディングを個別に更新します(たとえば、オペレーションに結果シャーディング用の不変の
TensorShardingPerValueAttrが 1 つあります)。 - ソースのシャーディングが変更されたときに、各エッジを個別にワークリストに追加します。
伝播では、sdy.data_flow_edge のすべてのソースとターゲット間でシャーディングが伝播されます。これは、ソースがオペランド、ターゲットが結果、ID が sdy.op_sharding_rule の通常のオペレーションの場合と同じです。つまり、順方向の伝播はソースからターゲットへの伝播であり、逆方向の伝播はターゲットからソースへの伝播です。
sdy.data_flow_edge の入力を SdyDialect オペレーションで定義することは許可されていないため、sdy.sharding 属性が登録されていないオペレーションで定義されていると想定できます。
特性: SameOperandsAndResultType
インターフェース: InferTypeOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
input |
任意の型の値の形状 |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値の形状 |
sdy.manual_computation(sdy::ManualComputationOp)
手動集約を使用したマルチデバイス並列オペレーション
構文:
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
明示的な集合を使用してデバイスごとのローカルコードで記述されたリージョンにジャンプします。ここで、論理シェイプはデバイスごとのローカル物理バッファ シェイプと一致し、集合は物理的なクロスデバイス通信と完全に一致します。
ボディは manual_axes に対してローカルです。伝播は、manual_axes リストにない自由軸のボディを介して行われます。
制約:
in_shardingsとout_shardingsの要素は、TensorShardingAttrに記載されている制約を満たす必要があります。- OP リージョンのグローバル テンソルとローカル テンソルの入力/出力の数が一致している必要があります。
- 手動軸は、各ディメンション シャーディング内のフリー軸の前に配置する必要があります。
- 手動軸にはパディングを追加できません。つまり、ディメンションのサイズは、対応する手動軸のサイズで割り切れる必要があります。
- op リージョン引数/結果のグローバル シェイプとローカル シェイプは一致している必要があります。
- 手動で分割された軸はありません。
特性: IsolatedFromAbove、RecursiveMemoryEffects、SingleBlockImplicitTerminator<ReturnOp>、SingleBlock
インターフェース: ShardableDataFlowOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | オペレーションのオペランド/結果ごとのテンソル シャーディング |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | オペレーションのオペランド/結果ごとのテンソル シャーディング |
manual_axes | ::mlir::sdy::ManualAxesAttr | ManualComputationOp が手動である軸のリスト |
オペランド:
| オペランド | 説明 |
|---|---|
tensors |
任意の型の値のランク付けされたテンソルの可変引数 |
結果:
| 結果 | 説明 |
|---|---|
results |
任意の型の値のランク付けされたテンソルの可変引数 |
sdy.mesh(sdy::MeshOp)
名前付きメッシュ
構文:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
新しい名前付きメッシュを定義します。モジュール内のすべてのメッシュのデバイス数は同じである必要があります(device_id が 1 つのメッシュを除く)。メッシュは、モジュールの SymbolTable に表示される Symbol オペレーションで、その name によって参照できます。
特性: HasParent<ModuleOp>
インターフェース: Symbol
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
sym_name | ::mlir::StringAttr | string 属性 |
mesh | ::mlir::sdy::MeshAttr | 軸のメッシュとデバイスのリスト |
sdy.named_computation(sdy::NamedComputationOp)
名前付き計算オペレーション
構文:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
計算(オペレーションのブロック)をグループ化し、名前を付けます。すべてがインライン化されているかのように、リージョン内外への伝播が行われます。
これは、呼び出し命令を介して他の関数に伝播する処理に使用できます。Shardy を使用する場合は、呼び出しオペレーションを sdy.named_computation オペレーションに変換し、呼び出された関数本体を named_computation 本体に複製/コピーするインポート/エクスポート パスを記述する必要があります。
リージョン内の各ブロック引数と戻り値の型は、オペランドの型とオペレーションの結果型と同じである必要があります。
例:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
特性: IsolatedFromAbove、RecursiveMemoryEffects、RecursivelySpeculatableImplTrait、SingleBlockImplicitTerminator<ReturnOp>、SingleBlock
インターフェース: ConditionallySpeculatable、InferTypeOpInterface、ShardableDataFlowOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
name | ::mlir::StringAttr | string 属性 |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | オペレーションのオペランド/結果ごとのテンソル シャーディング |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | オペレーションのオペランド/結果ごとのテンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
operands |
任意の型の可変引数 |
結果:
| 結果 | 説明 |
|---|---|
| «unnamed» | 任意の型の可変引数 |
sdy.propagation_barrier(sdy::PropagationBarrierOp)
伝播バリアのオペレーション
構文:
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
このオペレーションは、入力として受け取った値と同じ値を出力する、ID オペレーションのように動作します。ただし、伝播に関しては、特定の方向にのみ伝播が許可されます。
これにより、バリア演算の結果とそのオペランドの使用間でシャーディングが伝播されるのを防ぐことができます。
FORWARDは、シャーディングがオペランドから結果にのみ流れることを意味します。BACKWARDは、シャーディングが結果からオペランドにのみ流れることを意味します。NONEは、このオペレーションを介してシャーディングを伝播できないことを意味します。- このオペレーションは冗長であるため、
BOTHを指定できません。
特性: AlwaysSpeculatableImplTrait、SameOperandsAndResultType
インターフェース: ConditionallySpeculatable、InferTypeOpInterface、NoMemoryEffect (MemoryEffectOpInterface)
影響: MemoryEffects::Effect{}
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | 伝播方向の列挙型 |
オペランド:
| オペランド | 説明 |
|---|---|
input |
任意の型の値のランク付けされたテンソル |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値のランク付けされたテンソル |
sdy.reshard(sdy::ReshardOp)
テンソルを別のシャーディングに再シャーディングする
構文:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
指定されたシャーディングで入力テンソルをシャーディングします。これは、入力テンソルの既存のシャーディングとは異なります。
ShardingConstraintOp と ReshardOp はどちらも、シャーディングをテンソルに接続します。有効期間は次のとおりです。
- シャーディング伝播の前に、ShardingConstraintOp がユーザーによって追加されます。
- シャーディング伝播では ShardingConstraintOp が使用されます。シャーディング伝播の結果に ShardingConstraintOp はありません。代わりに、必要に応じて ReshardOp を追加できます。
- パーティショナーは、ReshardOp を集約オペレーション(または ID オペレーション)に変換します。パーティショナーの結果に ReshardOp が含まれていないこと。
// TODO(b/331680067). 冗長な // シャード変更オペレーションを削除するように正規化パターンを追加しました。
特性: AlwaysSpeculatableImplTrait、SameOperandsAndResultType
インターフェース: ConditionallySpeculatable、InferTypeOpInterface、NoMemoryEffect (MemoryEffectOpInterface)
影響: MemoryEffects::Effect{}
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
input |
任意の型の値のテンサー |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値のテンサー |
sdy.return(sdy::ReturnOp)
sdy.return オペレーションは、sdy リージョンベースのオペレーションと他の Shardy リージョンベースのオペレーションに接続されているリージョンを終了します。可変長です。引数として、型が任意(ただし同じ種類、例: AnyTensor)の値のリストを受け取るため、Shardy IR スタックのさまざまなレベルで再利用できます。
構文:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
特性: AlwaysSpeculatableImplTrait、Terminator
インターフェース: ConditionallySpeculatable、NoMemoryEffect (MemoryEffectOpInterface)
影響: MemoryEffects::Effect{}
オペランド:
| オペランド | 説明 |
|---|---|
results |
任意の型の可変引数 |
sdy.sharding_constraint(sdy::ShardingConstraintOp)
テンソルを指定されたシャーディングに制約する
構文:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
中間テンソル(matmul の結果など)にシャーディングをアタッチして、そのテンソルまたはその使用のサブセットをシャーディングする方法を示します。
シャーディングにオープンなディメンションと制約のない軸がある場合、テンソルをオープンなディメンションに沿ってさらにシャーディングできます。
このオペレーションは次のいずれかです。
- 使用されていない(ダングリング)- つまり、適用されたシャーディングは、入力テンソル自体がシャーディングされる方法です。
- 使用がある - つまり、適用されたシャーディングは、シャーディング制約オペレーションの使用方法をシャーディングする方法ですが、入力テンソルの他の使用方法は異なるシャーディングを持つ場合があります(入力テンソルに他の使用方法がない場合は、使用なしの場合と同じ動作になります)。
特性: SameOperandsAndResultType
インターフェース: InferTypeOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
| オペランド | 説明 |
|---|---|
input |
任意の型の値のテンサー |
結果:
| 結果 | 説明 |
|---|---|
result |
任意の型の値のテンサー |
sdy.sharding_group(sdy::ShardingGroupOp)
グループ内のテンソルが同じシャーディングを持つように制約します。
構文:
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
このオペレーションは、シャーディング グループ(同じシャーディングが適用されるテンソルのグループ)にテンソルを割り当てるインターフェースを提供します。伝播中、1 つのグループ要素がシャーディングされるとすぐに、他のすべてのメンバーがまったく同じ方法でシャーディングされます。このオペレーションは、引数のグループ ID を受け取り、結果を返しません。代わりに、内部シャーディング グループの表現を変更して、指定された ID のグループに入力テンソルを追加します。
インターフェース: InferTypeOpInterface
属性:
| 属性 | MLIR タイプ | 説明 |
|---|---|---|
group_id | ::mlir::IntegerAttr | 64 ビット符号なし整数属性 |
オペランド:
| オペランド | 説明 |
|---|---|
input |
任意の型の値のランク付けされたテンソル |
属性
AllToAllParamAttr
オールツーオール パラメータ
構文:
#sdy.all_to_all_param<
::llvm::ArrayRef<AxisRefAttr>, # axes
int64_t, # src_dim
int64_t # tgt_dim
>
アベイルズを実行する軸とソース/ターゲット ディメンションを含むタプル。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| 軸 | ::llvm::ArrayRef<AxisRefAttr> |
アベイルズ マトリックスのすべての要素を計算する軸 |
| src_dim | int64_t |
ソース ディメンションのインデックス |
| tgt_dim | int64_t |
ターゲット ディメンションのインデックス |
AlltoAllParamListAttr
オールツーオール パラメータのリスト
構文:
#sdy.all_to_all_param_list<
::llvm::ArrayRef<AllToAllParamAttr> # value
>
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| 値 | ::llvm::ArrayRef<AllToAllParamAttr> |
AxisRefAttr
完全な軸または分割サブ軸への参照
構文:
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
制約:
nameは、バウンドMeshAttrに存在している必要があります。sub_axis_infoが存在する場合は、SubAxisInfoAttrの制約を満たす必要があります。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| name | ::llvm::StringRef |
この軸の名前 |
| sub_axis_info | SubAxisInfoAttr |
サブ軸の場合の追加情報 |
AxisRefListAttr
軸参照のリスト
構文:
#sdy.axis_ref_list<
::llvm::ArrayRef<AxisRefAttr> # value
>
制約:
valueの要素はAxisRefAttrの制約を満たす必要があります。- 重複する軸参照や重複するサブ軸はありません。
- 隣接する 2 つの軸参照が、同じ完全な軸の連続するサブ軸ではない。つまり、1 つのサブ軸または完全な軸に統合できる。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| 値 | ::llvm::ArrayRef<AxisRefAttr> |
DimMappingAttr
ディメンションのファクタ インデックスのリスト
空のリストは、これが null マッピングであることを示します(これは * で解析/出力されます)。つまり、ディメンションがどのファクタにもマッピングされていないことを示します。
制約:
- 少なくとも 1 つのファクタ インデックスがあります。
- ファクタ インデックスは [0,
$factor_sizes] の範囲内にする必要があります。 - 複数の要因がある場合は、どの要因もサイズ 1 にすることはできません。
- 重複するファクタ インデックスはありません。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| factor_indices | ::llvm::ArrayRef<int64_t> |
このディメンションがマッピングされているファクタ |
DimensionShardingAttr
ディメンション シャーディング
テンソル ディメンションをメジャーからマイナーにシャーディングする軸名のリスト、ディメンションをさらにシャーディングできるかどうかを示すブール値、シャーディング伝播時に考慮されるこのディメンション シャーディングのパラメータの優先度を示す整数(省略可)。優先度はユーザー シャーディング アノテーションから取得され、値が小さいほど優先度が高くなります。アノテーションに優先度が指定されていない場合は、最も高い優先度が想定されます。
制約:
axesの要素は、AxisRefListAttrに記載されている制約を満たす必要があります。- ディメンション シャーディングに優先度がある場合:
- 優先度は 0 以上です。
- 閉じているディメンションには、少なくとも 1 つの軸があります。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| 軸 | ::llvm::ArrayRef<AxisRefAttr> |
軸の参照 |
| is_closed | bool |
このディメンションをさらにシャーディングできないかどうか |
| priority | std::optional<int64_t> |
ユーザー優先度ベースの伝播で使用される優先度 |
ListOfAxisRefListsAttr
軸リファレンス リストのリスト
構文:
#sdy.list_of_axis_ref_lists<
::llvm::ArrayRef<AxisRefListAttr> # value
>
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| 値 | ::llvm::ArrayRef<AxisRefListAttr> |
ManualAxesAttr
ManualComputationOp が手動である軸のリスト
構文:
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| 値 | ::llvm::ArrayRef<StringAttr> |
MeshAttr
軸のメッシュとデバイスのリスト
構文:
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
メッシュは、軸のリストと、デバイスの順序を指定するデバイス ID のリスト(省略可)です。
軸のリストが空の場合、メッシュにはサイズ 1 の無名の軸が暗黙的に設定されます。この場合、デバイス ID リストが指定されていない場合、暗黙のデバイス ID リストは [0] になります。デバイス ID リストが指定されている場合は、正の値の整数を 1 つ含める必要があります。これを最大シャーディング ケースと呼びます。
最大シャーディング以外のすべてのケースで、デバイス ID リストが指定されている場合、軸サイズの積はデバイス数と一致する必要があります。デバイス ID リストが指定されていない場合、暗黙のデバイス ID リストは iota(product(axes)) です。簡素化のため、iota(product(axes)) と同じデバイス ID リストの指定も禁止されています。この場合、デバイス ID リストは指定しないでください。
メッシュの例を次に示します。
- 空のメッシュは、伝播中に置き換え可能なプレースホルダ メッシュを表します。<[]>
- 名前のない軸と明示的なデバイス ID を持つメッシュ。通常は最大シャーディングを表すために使用されます。<[], device_ids=[3]>
- 2 つの軸と暗黙のデバイス ID を持つメッシュ iota(6): <["a"=2, "b"=3]>
- 2 つの軸と、デバイスの順序を指定する明示的なデバイス ID を持つメッシュ: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
制約:
axes内の要素の名前を重複させることはできません。device_idsが指定されている場合:- 軸サイズの積はデバイスの数と一致している必要があります。
- 要素はすべて正の値にする必要があります。
device_idsはiota(product(axis_sizes))と等しくしてはなりません。- 並べ替えられた
device_idsはiota(product(axis_sizes))である必要があります。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| 軸 | ::llvm::ArrayRef<MeshAxisAttr> |
メッシュ軸 |
| device_ids | ::llvm::ArrayRef<int64_t> |
明示的なデバイスの順序付けまたは最大デバイス ID |
MeshAxisAttr
メッシュ内の名前付き軸
構文:
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| name | ::llvm::StringRef |
name |
| サイズ | int64_t |
この軸のサイズ |
OpShardingRuleAttr
オペレーションをパーティショニングする方法を指定します。
構文:
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
::llvm::ArrayRef<int64_t>, # reduction_factors
::llvm::ArrayRef<int64_t>, # need_replication_factors
::llvm::ArrayRef<int64_t>, # permutation_factors
::llvm::ArrayRef<int64_t>, # blocked_propagation_factors
bool # is_custom_rule
>
シャーディング ルールは、オペレーションのさまざまなプロパティ(属性、オペランドの形状、結果の形状など)に応じてオペレーションをパーティショニングする方法を指定するものです。次に例を示します。
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
サイズ 1 の因子はシャーディングできませんが、この因子は完全性のために許可されています。これは、点オペレーションなどの多くのオペレーションで、オペランドと結果にわたって対応するサイズ 1 のディメンションがあるためです。
要因の種類:
reduction_factorsには、減算が必要な要素のインデックス(ドット演算の収縮ディメンションなど)が含まれます。need_replication_factorsには、完全なレプリケーションを必要とする要素のインデックス(並べ替えオペレーションの並べ替え済みディメンションなど)が含まれます。permutation_factorsには、シャーディングされている場合に集約並べ替えを必要とする要素のインデックス(パディング オペレーションのパディング ディメンションなど)が含まれます。- 他のすべての要素はパススルー要素と見なされます。つまり、マッピングされているすべてのテンソル間で同じ方法でシャーディングされている場合、通信を必要としない要素です。
blocked_propagation_factors には、シャーディングが伝播されない要素が含まれています。要素タイプとは直交しています。つまり、ブロックされた伝播要因は、任意の要因タイプにすることができます。
is_custom_rule は、これがユーザーが定義したルールかどうかを記述します。ユーザーは、カスタム呼び出しのシャーディング ルールを定義したり、標準オペレーションの事前定義済みシャーディング ルールを上書きしたりできます。カスタムルールは常に保持され、削除されることはありません。
制約:
- オペランド/結果のマッピングの数は、オペレーションのオペランド/結果の数と一致している必要があります。
- マッピングが 1 つ以上あります(オペランド/結果のないオペレーションのルールは設定できません)。
- 各
TensorMappingAttrのランクは、対応するテンソル型のランクと一致します。 - 各要因グループ(
reduction_factors、need_replication_factors、permutation_factors)の場合:- 要素は [0,
$factor_sizes] の範囲内である必要があります。 - 各グループ内およびグループ間で重複する因子インデックスがない。
- 要素は [0,
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| factor_sizes | ::llvm::ArrayRef<int64_t> |
このルール内のすべての要素のサイズ |
| operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
オペランドのマッピング |
| result_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
結果のマッピング |
| reduction_factors | ::llvm::ArrayRef<int64_t> |
削減が必要な要因 |
| need_replication_factors | ::llvm::ArrayRef<int64_t> |
完全なレプリケーションを必要とする要因 |
| permutation_factors | ::llvm::ArrayRef<int64_t> |
集約並べ替えを必要とする要因 |
| blocked_propagation_factors | ::llvm::ArrayRef<int64_t> |
シャーディングが伝播されない要素 |
| is_custom_rule | bool |
ルールが stablehlo.custom_call 用かどうか |
SubAxisInfoAttr
このサブ軸が完全な軸からどのように派生しているかに関する情報
構文:
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
完全な軸を n 個のサブ軸に分割すると、軸は [k_1,...,k_n] に再形成されます。i 番目のサブ軸は、左側のすべての軸サイズ m=prod(k_1,...,k_(i-1))(前サイズ)とサイズ k_i の積で表すことができます。したがって、サブ軸情報属性にはこれらの 2 つの数値が保持され、次のように表されます。プリサイズ m とサイズ k の場合は (m)k です。
制約:
pre-sizeは 1 以上です。sizeは 1 より大きい。pre-sizeは、全体の軸のサイズを分割する必要があります。つまり、pre-sizeとsizeの両方が全体の軸のサイズを分割し、サブ軸が全体の軸を超えないようにします。- サブ軸のサイズが対応するフル軸のサイズと等しくない。この場合は、フル軸を使用する必要があります。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| pre_size | int64_t |
このサブ軸の左側のサブ軸のサイズの積 |
| サイズ | int64_t |
このサブ軸のサイズ |
TensorMappingAttr
テンソルの各ディメンションの因子マッピング。
構文:
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
制約:
dim_mappingsの要素はDimMappingAttrの制約を満たす必要があります。- ディメンション間で重複するファクタ インデックスがない。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
ディメンション マッピング |
TensorShardingAttr
Tensor シャーディング
構文:
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
テンソル シャーディングは特定のメッシュにバインドされ、そのメッシュの軸名のみを参照できます。ディメンション シャーディングは、テンソルの各ディメンションで、どの軸(またはサブ軸)に沿ってメジャーからマイナーにシャーディングされているかを示します。ディメンションをシャーディングしない他の軸はすべて、暗黙的または明示的に(複製された軸のリストに含まれている場合)複製されます。
このシャーディングがバインドされているメッシュは、対応する MeshOp シンボルを参照するシンボル名、またはインライン化された MeshAttr で指定できます。
制約:
dim_shardingsの要素は、DimensionShardingAttrに記載されている制約を満たす必要があります。replicated_axesの要素は、AxisRefListAttrに記載されている制約を満たす必要があります。- 対応するテンソル型が
ShapedTypeでない場合は、シャーディングがランク 0 で、複製された軸がない必要があります。 - テンソルはランクを持つ必要があります。
- ディメンション シャーディングの数は、テンソルのランクと同じです。
- サイズ 0 のディメンションはシャーディングされません。
replicated_axes内のアイテムはmesh_or_refを基準に並べ替えられます(AxisRefAttr::getMeshComparatorを参照)。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| mesh_or_ref | ::mlir::Attribute |
メッシュ属性またはフラットメッシュ シンボル参照属性 |
| dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> |
ディメンション シャーディング |
| replicated_axes | ::llvm::ArrayRef<AxisRefAttr> |
軸の参照 |
TensorShardingPerValueAttr
オペランド/オペレーションの結果ごとのテンソル シャーディング
構文:
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
TensorShardingAttr のリスト(オペレーションの各オペランドまたは結果に 1 つ)。
制約:
shardingsの要素はTensorShardingAttrの制約を満たす必要があります。
パラメータ:
| パラメータ | C++ 型 | 説明 |
|---|---|---|
| shardings | ::llvm::ArrayRef<TensorShardingAttr> |
値ごとのシャーディング |
列挙型
PropagationDirection
伝播方向の列挙型
Cases:
| 記号 | 値 | 文字列 |
|---|---|---|
| なし | 0 |
なし |
| 転送 | 1 |
転送 |
| BACKWARD | 2 |
BACKWARD |
| 両方 | 3 |
両方 |