「sdy」方言

Shardy(SDY)方言は、軸ベースのテンソル シャーディング表現と、シャーディングをテンソルに接続するための追加の API コンポーネントを定義します。

運用

sdy.all_gather(sdy::AllGatherOp)

軸に沿ってオールギャザー通信を実行する

構文:

operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

gathering_axes で指定された軸に沿ってテンソルのチャンクを収集します。

gathering_axes は、軸のリストです。外側のリストはテンソルのディメンションを超えています。各内部リストには、各ディメンションで個別の集計を実行する軸を指定します。これは、オペランド(tensor)のシャーディングに適用され、結果(out_sharding)のシャーディングが取得されます。

out_sharding は、結果のシャーディングを決定するために使用されません。代わりに、結果のシャーディングはオペランドと gathering_axes のシャーディングによって決定され、out_sharding はこの推定シャーディングと一致する必要があります。

例:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>

制約:

  • Sdy_CollectiveOpInterface に記載されている制約を満たす必要があります。
  • gathering_axes の要素は、AxisRefListAttr に記載されている制約を満たす必要があります。
  • オペランド シャーディングに gathering_axes を適用すると、out_sharding が得られます。

特性: SameOperandsAndResultType

インターフェース: InferTypeOpInterfaceSdy_CollectiveOpInterface

属性:

属性MLIR タイプ説明
gathering_axes::mlir::sdy::ListOfAxisRefListsAttr軸参照リストのリスト
out_sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
tensor 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンサー

sdy.all_reduce(sdy::AllReduceOp)

軸に沿ってオール リデュース通信を実行する

構文:

operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

reduction_axes で指定された軸に沿ってテンソルのチャンクを減算します。reduction_axes の順序は結果に影響しませんが、対応するレプリカ グループの順序に影響する可能性があります。

制約:

  • Sdy_CollectiveOpInterface に記載されている制約を満たす必要があります。
  • reduction_axesAxisRefListAttr に記載されている制約を満たしている必要があります。
  • reduction_axes はオペランドのシャーディング軸と重複してはいけません。

特性: SameOperandsAndResultType

インターフェース: CollectiveOpInterfaceInferTypeOpInterface

属性:

属性MLIR タイプ説明
reduction_axes::mlir::sdy::AxisRefListAttr軸参照のリスト
out_sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
tensor 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンサー

sdy.all_slice(sdy::AllSliceOp)

軸に沿って動的スライス オペレーションを実行する

構文:

operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

slicing_axes で指定された軸に沿ってテンソルのチャンクをスライスします。sdy.all_slicesdy.all_gather の間に代数的双対性があります。

slicing_axes は、軸のリストです。外側のリストはテンソルのディメンションを超えています。各内部リストには、各ディメンションでスライスを実行する軸を指定します。これはオペランド(tensor)のシャーディングに適用され、結果(out_sharding)のシャーディングが取得されます。

out_sharding は、結果のシャーディングを決定するために使用されません。代わりに、結果のシャーディングはオペランドと slicing_axes のシャーディングによって決定され、out_sharding はこの推定シャーディングと一致する必要があります。

例:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>

制約:

  • slicing_axes の要素は、AxisRefListAttr に記載されている制約を満たす必要があります。
  • Sdy_CollectiveOpInterface に記載されている制約を満たす必要があります。
  • オペランド シャーディングに slicing_axes を適用すると、out_sharding が得られます。

特性: SameOperandsAndResultType

インターフェース: CollectiveOpInterfaceInferTypeOpInterface

属性:

属性MLIR タイプ説明
slicing_axes::mlir::sdy::ListOfAxisRefListsAttr軸参照リストのリスト
out_sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
tensor 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンサー

sdy.all_to_all(sdy::AllToAllOp)

軸に沿ってオールツーオール通信を実行する

構文:

operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

このオペレーションは、パラメータリスト内の(axes、src_dim、tgt_dim)タプルごとに、テンソルのチャンクをディメンション tgt_dimaxes で指定された軸に沿ってスライスし、それらのチャンクを軸に沿って分散し、ディメンション src_dim に沿って連結します。

このオペレーションは、基本的に src_dimaxes に沿ったオールギャザーと、tgt_dimaxes に沿ったオールスライスの組み合わせです。つまり、入力テンソルの軸シャーディング ディメンション src_dim の接尾辞が、出力テンソルの軸シャーディング ディメンション tgt_dim に追加されます。

オールツーオール シャーディングはオペランド(tensor)のシャーディングに適用され、結果のシャーディング(out_sharding)が取得されます。

out_sharding は、結果のシャーディングを決定するために使用されません。代わりに、結果のシャーディングはオペランドのシャーディング(src_dimtgt_dimaxes)によって決定され、out_sharding はこの推定シャーディングと一致する必要があります。

例:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>

制約:

  • Sdy_CollectiveOpInterface に記載されている制約を満たす必要があります。
  • パラメータ リストを空にすることはできません。
  • params のパラメータごとに、次の操作を行います。
    • axes の要素は AxisRefAttr の制約を満たす必要があります。
    • src_dimtgt_dim は有効なディメンション(正でテンソルのランクより小さい)にする必要があります。
    • src_dim または tgt_dim は、すべてのパラメータで一意である必要があります。
    • src_dim は、すべてのパラメータで昇順に並べ替える必要があります。
  • オペランド シャーディングで axessrc_dim から tgt_dim に移動すると、out_sharding が返されます。

特性: SameOperandsAndResultType

インターフェース: InferTypeOpInterfaceSdy_CollectiveOpInterface

属性:

属性MLIR タイプ説明
params::mlir::sdy::AlltoAllParamListAttrすべてのパラメータのリスト
out_sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
tensor 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンサー

sdy.collective_permute(sdy::CollectivePermuteOp)

軸を置き換えるために集約型の並べ替え通信を実行する

構文:

operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

入力テンソルのチャンクを各デバイスから別のデバイスに送信し、テンソルをシャーディングする軸を並べ替えまたは置き換えます。

集約シャッフルでは、各ディメンションが以前と同じようにシャーディングされるように入力シャーディングを変換できます。つまり、サイズの積が以前にテンソルをシャーディングした軸の積と一致する軸に沿ってシャーディングする必要があります。

これは、単一のディメンションまたは異なるディメンション間で軸の順序を変更したり、シャーディングされた軸を複製された軸と入れ替えたりする場合に便利です。

次の例では、シャーディングされたテンソルのサイズは tensor<1x4x2xf32> で、これは集約並べ替えによって保持されます。

例:

sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>

制約:

  • Sdy_CollectiveOpInterface に記載されている制約を満たす必要があります。
  • 入力シャーディングと出力シャーディングのメッシュが異なる場合は、これらのメッシュの軸が完全に同じで、デバイス ID の順序が異なる必要があります。
  • 各ディメンションで、out_sharding のシャーディング軸サイズの積は、対応するオペランド ディメンション シャーディングと一致している必要があります。

特性: SameOperandsAndResultType

インターフェース: CollectiveOpInterfaceInferTypeOpInterface

属性:

属性MLIR タイプ説明
out_sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
tensor 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンサー

sdy.constant(sdy::ConstantOp)

定数オペレーション

定数 value から output テンソルを生成します。

参照: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant

例:

%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>

特性: AlwaysSpeculatableImplTrait

インターフェース: ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影響: MemoryEffects::Effect{}

属性:

属性MLIR タイプ説明
value::mlir::ElementsAttr定数ベクトル/テンソル属性

結果:

結果 説明
output 任意の型の値の静的形状テンソル

sdy.data_flow_edge(sdy::DataFlowEdgeOp)

データフロー エッジ演算。

構文:

operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)

ある演算 X のデータフロー エッジは、一連のソース(それぞれ X のオペランドまたは X のブロック終端子のオペランド)と一連のターゲット(それぞれ X の結果または X のブロック引数)間のブリッジを定義します。これにより、すべてのソースとターゲットが同じ方法でシャーディングされるようになります。

1 つのオペレーションには、互いに直交する複数のデータフロー エッジを含めることができます。

次に例を示します。

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

この while オペレーションには n 個のデータフロー エッジがあり、i 番目のデータフロー エッジはソース x_ireturn_value_i とターゲット y_ipred_arg_ibody_arg_i の間にあります。

sdy.data_flow_edge は、エッジのオーナー(任意のターゲットですが、ブロック引数ではなく op 結果が望ましい)を入力として受け取ります。この値は他の用途に使用しないでください。このオペレーションは、元々使用されていない入力を受け取ることができるため、純粋ではありません。

sdy.data_flow_edge には、エッジのすべてのターゲットに対してオプションのシャーディングも保持されます。このシャーディングは、伝播中にターゲットのシャーディング(接続可能であれば)ではなく、更新する必要があります。これは、オペレーションにエッジが多数ある場合に便利です。次のように、はるかに効率的です。

  • 各エッジを個別に伝播します。
  • すべてのターゲットを一度に更新するのではなく、各エッジのシャーディングを個別に更新します(たとえば、オペレーションに結果シャーディング用の不変の TensorShardingPerValueAttr が 1 つあります)。
  • ソースのシャーディングが変更されたときに、各エッジを個別にワークリストに追加します。

伝播では、sdy.data_flow_edge のすべてのソースとターゲット間でシャーディングが伝播されます。これは、ソースがオペランド、ターゲットが結果、ID が sdy.op_sharding_rule の通常のオペレーションの場合と同じです。つまり、順方向の伝播はソースからターゲットへの伝播であり、逆方向の伝播はターゲットからソースへの伝播です。

sdy.data_flow_edge の入力を SdyDialect オペレーションで定義することは許可されていないため、sdy.sharding 属性が登録されていないオペレーションで定義されていると想定できます。

特性: SameOperandsAndResultType

インターフェース: InferTypeOpInterface

属性:

属性MLIR タイプ説明
sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
input 任意の型の値の形状

結果:

結果 説明
result 任意の型の値の形状

sdy.manual_computation(sdy::ManualComputationOp)

手動集約を使用したマルチデバイス並列オペレーション

構文:

operation ::= `sdy.manual_computation` `(`operands`)`
              `in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
              `out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
              `manual_axes````=```$manual_axes
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:`
              functional-type(operands, results)

明示的な集合を使用してデバイスごとのローカルコードで記述されたリージョンにジャンプします。ここで、論理シェイプはデバイスごとのローカル物理バッファ シェイプと一致し、集合は物理的なクロスデバイス通信と完全に一致します。

ボディは manual_axes に対してローカルです。伝播は、manual_axes リストにない自由軸のボディを介して行われます。

制約:

  • in_shardingsout_shardings の要素は、TensorShardingAttr に記載されている制約を満たす必要があります。
  • OP リージョンのグローバル テンソルとローカル テンソルの入力/出力の数が一致している必要があります。
  • 手動軸は、各ディメンション シャーディング内のフリー軸の前に配置する必要があります。
  • 手動軸にはパディングを追加できません。つまり、ディメンションのサイズは、対応する手動軸のサイズで割り切れる必要があります。
  • op リージョン引数/結果のグローバル シェイプとローカル シェイプは一致している必要があります。
  • 手動で分割された軸はありません。

特性: IsolatedFromAboveRecursiveMemoryEffectsSingleBlockImplicitTerminator<ReturnOp>SingleBlock

インターフェース: ShardableDataFlowOpInterface

属性:

属性MLIR タイプ説明
in_shardings::mlir::sdy::TensorShardingPerValueAttrオペレーションのオペランド/結果ごとのテンソル シャーディング
out_shardings::mlir::sdy::TensorShardingPerValueAttrオペレーションのオペランド/結果ごとのテンソル シャーディング
manual_axes::mlir::sdy::ManualAxesAttrManualComputationOp が手動である軸のリスト

オペランド:

オペランド 説明
tensors 任意の型の値のランク付けされたテンソルの可変引数

結果:

結果 説明
results 任意の型の値のランク付けされたテンソルの可変引数

sdy.mesh(sdy::MeshOp)

名前付きメッシュ

構文:

operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict

新しい名前付きメッシュを定義します。モジュール内のすべてのメッシュのデバイス数は同じである必要があります(device_id が 1 つのメッシュを除く)。メッシュは、モジュールの SymbolTable に表示される Symbol オペレーションで、その name によって参照できます。

特性: HasParent<ModuleOp>

インターフェース: Symbol

属性:

属性MLIR タイプ説明
sym_name::mlir::StringAttrstring 属性
mesh::mlir::sdy::MeshAttr軸のメッシュとデバイスのリスト

sdy.named_computation(sdy::NamedComputationOp)

名前付き計算オペレーション

構文:

operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
              (`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
              (`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:` functional-type($operands, results)

計算(オペレーションのブロック)をグループ化し、名前を付けます。すべてがインライン化されているかのように、リージョン内外への伝播が行われます。

これは、呼び出し命令を介して他の関数に伝播する処理に使用できます。Shardy を使用する場合は、呼び出しオペレーションを sdy.named_computation オペレーションに変換し、呼び出された関数本体を named_computation 本体に複製/コピーするインポート/エクスポート パスを記述する必要があります。

リージョン内の各ブロック引数と戻り値の型は、オペランドの型とオペレーションの結果型と同じである必要があります。

例:

%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

特性: IsolatedFromAboveRecursiveMemoryEffectsRecursivelySpeculatableImplTraitSingleBlockImplicitTerminator<ReturnOp>SingleBlock

インターフェース: ConditionallySpeculatableInferTypeOpInterfaceShardableDataFlowOpInterface

属性:

属性MLIR タイプ説明
name::mlir::StringAttrstring 属性
in_shardings::mlir::sdy::TensorShardingPerValueAttrオペレーションのオペランド/結果ごとのテンソル シャーディング
out_shardings::mlir::sdy::TensorShardingPerValueAttrオペレーションのオペランド/結果ごとのテンソル シャーディング

オペランド:

オペランド 説明
operands 任意の型の可変引数

結果:

結果 説明
«unnamed» 任意の型の可変引数

sdy.propagation_barrier(sdy::PropagationBarrierOp)

伝播バリアのオペレーション

構文:

operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)

このオペレーションは、入力として受け取った値と同じ値を出力する、ID オペレーションのように動作します。ただし、伝播に関しては、特定の方向にのみ伝播が許可されます。

これにより、バリア演算の結果とそのオペランドの使用間でシャーディングが伝播されるのを防ぐことができます。

  • FORWARD は、シャーディングがオペランドから結果にのみ流れることを意味します。
  • BACKWARD は、シャーディングが結果からオペランドにのみ流れることを意味します。
  • NONE は、このオペレーションを介してシャーディングを伝播できないことを意味します。
  • このオペレーションは冗長であるため、BOTH を指定できません。

特性: AlwaysSpeculatableImplTraitSameOperandsAndResultType

インターフェース: ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影響: MemoryEffects::Effect{}

属性:

属性MLIR タイプ説明
allowed_direction::mlir::sdy::PropagationDirectionAttr伝播方向の列挙型

オペランド:

オペランド 説明
input 任意の型の値のランク付けされたテンソル

結果:

結果 説明
result 任意の型の値のランク付けされたテンソル

sdy.reshard(sdy::ReshardOp)

テンソルを別のシャーディングに再シャーディングする

構文:

operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)

指定されたシャーディングで入力テンソルをシャーディングします。これは、入力テンソルの既存のシャーディングとは異なります。

ShardingConstraintOp と ReshardOp はどちらも、シャーディングをテンソルに接続します。有効期間は次のとおりです。

  1. シャーディング伝播の前に、ShardingConstraintOp がユーザーによって追加されます。
  2. シャーディング伝播では ShardingConstraintOp が使用されます。シャーディング伝播の結果に ShardingConstraintOp はありません。代わりに、必要に応じて ReshardOp を追加できます。
  3. パーティショナーは、ReshardOp を集約オペレーション(または ID オペレーション)に変換します。パーティショナーの結果に ReshardOp が含まれていないこと。

// TODO(b/331680067). 冗長な // シャード変更オペレーションを削除するように正規化パターンを追加しました。

特性: AlwaysSpeculatableImplTraitSameOperandsAndResultType

インターフェース: ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影響: MemoryEffects::Effect{}

属性:

属性MLIR タイプ説明
sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
input 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンサー

sdy.return(sdy::ReturnOp)

sdy.return オペレーションは、sdy リージョンベースのオペレーションと他の Shardy リージョンベースのオペレーションに接続されているリージョンを終了します。可変長です。引数として、型が任意(ただし同じ種類、例: AnyTensor)の値のリストを受け取るため、Shardy IR スタックのさまざまなレベルで再利用できます。

構文:

operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?

特性: AlwaysSpeculatableImplTraitTerminator

インターフェース: ConditionallySpeculatableNoMemoryEffect (MemoryEffectOpInterface)

影響: MemoryEffects::Effect{}

オペランド:

オペランド 説明
results 任意の型の可変引数

sdy.sharding_constraint(sdy::ShardingConstraintOp)

テンソルを指定されたシャーディングに制約する

構文:

operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)

中間テンソル(matmul の結果など)にシャーディングをアタッチして、そのテンソルまたはその使用のサブセットをシャーディングする方法を示します。

シャーディングにオープンなディメンションと制約のない軸がある場合、テンソルをオープンなディメンションに沿ってさらにシャーディングできます。

このオペレーションは次のいずれかです。

  • 使用されていない(ダングリング)- つまり、適用されたシャーディングは、入力テンソル自体がシャーディングされる方法です。
  • 使用がある - つまり、適用されたシャーディングは、シャーディング制約オペレーションの使用方法をシャーディングする方法ですが、入力テンソルの他の使用方法は異なるシャーディングを持つ場合があります(入力テンソルに他の使用方法がない場合は、使用なしの場合と同じ動作になります)。

特性: SameOperandsAndResultType

インターフェース: InferTypeOpInterface

属性:

属性MLIR タイプ説明
sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
input 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンサー

sdy.sharding_group(sdy::ShardingGroupOp)

グループ内のテンソルが同じシャーディングを持つように制約します。

構文:

operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)

このオペレーションは、シャーディング グループ(同じシャーディングが適用されるテンソルのグループ)にテンソルを割り当てるインターフェースを提供します。伝播中、1 つのグループ要素がシャーディングされるとすぐに、他のすべてのメンバーがまったく同じ方法でシャーディングされます。このオペレーションは、引数のグループ ID を受け取り、結果を返しません。代わりに、内部シャーディング グループの表現を変更して、指定された ID のグループに入力テンソルを追加します。

インターフェース: InferTypeOpInterface

属性:

属性MLIR タイプ説明
group_id::mlir::IntegerAttr64 ビット符号なし整数属性

オペランド:

オペランド 説明
input 任意の型の値のランク付けされたテンソル

属性

AllToAllParamAttr

オールツーオール パラメータ

構文:

#sdy.all_to_all_param<
  ::llvm::ArrayRef<AxisRefAttr>,   # axes
  int64_t,   # src_dim
  int64_t   # tgt_dim
>

アベイルズを実行する軸とソース/ターゲット ディメンションを含むタプル。

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<AxisRefAttr> アベイルズ マトリックスのすべての要素を計算する軸
src_dim int64_t ソース ディメンションのインデックス
tgt_dim int64_t ターゲット ディメンションのインデックス

AlltoAllParamListAttr

オールツーオール パラメータのリスト

構文:

#sdy.all_to_all_param_list<
  ::llvm::ArrayRef<AllToAllParamAttr>   # value
>

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<AllToAllParamAttr>

AxisRefAttr

完全な軸または分割サブ軸への参照

構文:

#sdy.axis_ref<
  ::llvm::StringRef,   # name
  SubAxisInfoAttr   # sub_axis_info
>

制約:

  • name は、バウンド MeshAttr に存在している必要があります。
  • sub_axis_info が存在する場合は、SubAxisInfoAttr の制約を満たす必要があります。

パラメータ:

パラメータ C++ 型 説明
name ::llvm::StringRef この軸の名前
sub_axis_info SubAxisInfoAttr サブ軸の場合の追加情報

AxisRefListAttr

軸参照のリスト

構文:

#sdy.axis_ref_list<
  ::llvm::ArrayRef<AxisRefAttr>   # value
>

制約:

  • value の要素は AxisRefAttr の制約を満たす必要があります。
  • 重複する軸参照や重複するサブ軸はありません。
  • 隣接する 2 つの軸参照が、同じ完全な軸の連続するサブ軸ではない。つまり、1 つのサブ軸または完全な軸に統合できる。

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<AxisRefAttr>

DimMappingAttr

ディメンションのファクタ インデックスのリスト

空のリストは、これが null マッピングであることを示します(これは * で解析/出力されます)。つまり、ディメンションがどのファクタにもマッピングされていないことを示します。

制約:

  • 少なくとも 1 つのファクタ インデックスがあります。
  • ファクタ インデックスは [0, $factor_sizes] の範囲内にする必要があります。
  • 複数の要因がある場合は、どの要因もサイズ 1 にすることはできません。
  • 重複するファクタ インデックスはありません。

パラメータ:

パラメータ C++ 型 説明
factor_indices ::llvm::ArrayRef<int64_t> このディメンションがマッピングされているファクタ

DimensionShardingAttr

ディメンション シャーディング

テンソル ディメンションをメジャーからマイナーにシャーディングする軸名のリスト、ディメンションをさらにシャーディングできるかどうかを示すブール値、シャーディング伝播時に考慮されるこのディメンション シャーディングのパラメータの優先度を示す整数(省略可)。優先度はユーザー シャーディング アノテーションから取得され、値が小さいほど優先度が高くなります。アノテーションに優先度が指定されていない場合は、最も高い優先度が想定されます。

制約:

  • axes の要素は、AxisRefListAttr に記載されている制約を満たす必要があります。
  • ディメンション シャーディングに優先度がある場合:
    • 優先度は 0 以上です。
    • 閉じているディメンションには、少なくとも 1 つの軸があります。

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<AxisRefAttr> 軸の参照
is_closed bool このディメンションをさらにシャーディングできないかどうか
priority std::optional<int64_t> ユーザー優先度ベースの伝播で使用される優先度

ListOfAxisRefListsAttr

軸リファレンス リストのリスト

構文:

#sdy.list_of_axis_ref_lists<
  ::llvm::ArrayRef<AxisRefListAttr>   # value
>

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<AxisRefListAttr>

ManualAxesAttr

ManualComputationOp が手動である軸のリスト

構文:

#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr>   # value
>

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<StringAttr>

MeshAttr

軸のメッシュとデバイスのリスト

構文:

#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>,   # axes
  ::llvm::ArrayRef<int64_t>   # device_ids
>

メッシュは、軸のリストと、デバイスの順序を指定するデバイス ID のリスト(省略可)です。

軸のリストが空の場合、メッシュにはサイズ 1 の無名の軸が暗黙的に設定されます。この場合、デバイス ID リストが指定されていない場合、暗黙のデバイス ID リストは [0] になります。デバイス ID リストが指定されている場合は、正の値の整数を 1 つ含める必要があります。これを最大シャーディング ケースと呼びます。

最大シャーディング以外のすべてのケースで、デバイス ID リストが指定されている場合、軸サイズの積はデバイス数と一致する必要があります。デバイス ID リストが指定されていない場合、暗黙のデバイス ID リストは iota(product(axes)) です。簡素化のため、iota(product(axes)) と同じデバイス ID リストの指定も禁止されています。この場合、デバイス ID リストは指定しないでください。

メッシュの例を次に示します。

  • 空のメッシュは、伝播中に置き換え可能なプレースホルダ メッシュを表します。<[]>
  • 名前のない軸と明示的なデバイス ID を持つメッシュ。通常は最大シャーディングを表すために使用されます。<[], device_ids=[3]>
  • 2 つの軸と暗黙のデバイス ID を持つメッシュ iota(6): <["a"=2, "b"=3]>
  • 2 つの軸と、デバイスの順序を指定する明示的なデバイス ID を持つメッシュ: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>

制約:

  • axes 内の要素の名前を重複させることはできません。
  • device_ids が指定されている場合:
    • 軸サイズの積はデバイスの数と一致している必要があります。
    • 要素はすべて正の値にする必要があります。
    • device_idsiota(product(axis_sizes)) と等しくしてはなりません。
    • 並べ替えられた device_idsiota(product(axis_sizes)) である必要があります。

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<MeshAxisAttr> メッシュ軸
device_ids ::llvm::ArrayRef<int64_t> 明示的なデバイスの順序付けまたは最大デバイス ID

MeshAxisAttr

メッシュ内の名前付き軸

構文:

#sdy.mesh_axis<
  ::llvm::StringRef,   # name
  int64_t   # size
>

パラメータ:

パラメータ C++ 型 説明
name ::llvm::StringRef name
サイズ int64_t この軸のサイズ

OpShardingRuleAttr

オペレーションをパーティショニングする方法を指定します。

構文:

#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>,   # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>,   # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>,   # result_mappings
  ::llvm::ArrayRef<int64_t>,   # reduction_factors
  ::llvm::ArrayRef<int64_t>,   # need_replication_factors
  ::llvm::ArrayRef<int64_t>,   # permutation_factors
  ::llvm::ArrayRef<int64_t>,   # blocked_propagation_factors
  bool   # is_custom_rule
>

シャーディング ルールは、オペレーションのさまざまなプロパティ(属性、オペランドの形状、結果の形状など)に応じてオペレーションをパーティショニングする方法を指定するものです。次に例を示します。

%0 = stablehlo.add %arg0, %arg1 {
    sdy.sharding_rule = #sdy.op_sharding_rule<
        ([i, j],[i, j])->([i, j])
        {i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>

サイズ 1 の因子はシャーディングできませんが、この因子は完全性のために許可されています。これは、点オペレーションなどの多くのオペレーションで、オペランドと結果にわたって対応するサイズ 1 のディメンションがあるためです。

要因の種類:

  • reduction_factors には、減算が必要な要素のインデックス(ドット演算の収縮ディメンションなど)が含まれます。
  • need_replication_factors には、完全なレプリケーションを必要とする要素のインデックス(並べ替えオペレーションの並べ替え済みディメンションなど)が含まれます。
  • permutation_factors には、シャーディングされている場合に集約並べ替えを必要とする要素のインデックス(パディング オペレーションのパディング ディメンションなど)が含まれます。
  • 他のすべての要素はパススルー要素と見なされます。つまり、マッピングされているすべてのテンソル間で同じ方法でシャーディングされている場合、通信を必要としない要素です。

blocked_propagation_factors には、シャーディングが伝播されない要素が含まれています。要素タイプとは直交しています。つまり、ブロックされた伝播要因は、任意の要因タイプにすることができます。

is_custom_rule は、これがユーザーが定義したルールかどうかを記述します。ユーザーは、カスタム呼び出しのシャーディング ルールを定義したり、標準オペレーションの事前定義済みシャーディング ルールを上書きしたりできます。カスタムルールは常に保持され、削除されることはありません。

制約:

  • オペランド/結果のマッピングの数は、オペレーションのオペランド/結果の数と一致している必要があります。
  • マッピングが 1 つ以上あります(オペランド/結果のないオペレーションのルールは設定できません)。
  • TensorMappingAttr のランクは、対応するテンソル型のランクと一致します。
  • 各要因グループ(reduction_factorsneed_replication_factorspermutation_factors)の場合:
    • 要素は [0, $factor_sizes] の範囲内である必要があります。
    • 各グループ内およびグループ間で重複する因子インデックスがない。

パラメータ:

パラメータ C++ 型 説明
factor_sizes ::llvm::ArrayRef<int64_t> このルール内のすべての要素のサイズ
operand_mappings ::llvm::ArrayRef<TensorMappingAttr> オペランドのマッピング
result_mappings ::llvm::ArrayRef<TensorMappingAttr> 結果のマッピング
reduction_factors ::llvm::ArrayRef<int64_t> 削減が必要な要因
need_replication_factors ::llvm::ArrayRef<int64_t> 完全なレプリケーションを必要とする要因
permutation_factors ::llvm::ArrayRef<int64_t> 集約並べ替えを必要とする要因
blocked_propagation_factors ::llvm::ArrayRef<int64_t> シャーディングが伝播されない要素
is_custom_rule bool ルールが stablehlo.custom_call 用かどうか

SubAxisInfoAttr

このサブ軸が完全な軸からどのように派生しているかに関する情報

構文:

#sdy.sub_axis_info<
  int64_t,   # pre_size
  int64_t   # size
>

完全な軸を n 個のサブ軸に分割すると、軸は [k_1,...,k_n] に再形成されます。i 番目のサブ軸は、左側のすべての軸サイズ m=prod(k_1,...,k_(i-1))(前サイズ)とサイズ k_i の積で表すことができます。したがって、サブ軸情報属性にはこれらの 2 つの数値が保持され、次のように表されます。プリサイズ m とサイズ k の場合は (m)k です。

制約:

  • pre-size は 1 以上です。
  • size は 1 より大きい。
  • pre-size は、全体の軸のサイズを分割する必要があります。つまり、pre-sizesize の両方が全体の軸のサイズを分割し、サブ軸が全体の軸を超えないようにします。
  • サブ軸のサイズが対応するフル軸のサイズと等しくない。この場合は、フル軸を使用する必要があります。

パラメータ:

パラメータ C++ 型 説明
pre_size int64_t このサブ軸の左側のサブ軸のサイズの積
サイズ int64_t このサブ軸のサイズ

TensorMappingAttr

テンソルの各ディメンションの因子マッピング。

構文:

#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr>   # dim_mappings
>

制約:

  • dim_mappings の要素は DimMappingAttr の制約を満たす必要があります。
  • ディメンション間で重複するファクタ インデックスがない。

パラメータ:

パラメータ C++ 型 説明
dim_mappings ::llvm::ArrayRef<DimMappingAttr> ディメンション マッピング

TensorShardingAttr

Tensor シャーディング

構文:

#sdy.sharding<
  ::mlir::Attribute,   # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>,   # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr>   # replicated_axes
>

テンソル シャーディングは特定のメッシュにバインドされ、そのメッシュの軸名のみを参照できます。ディメンション シャーディングは、テンソルの各ディメンションで、どの軸(またはサブ軸)に沿ってメジャーからマイナーにシャーディングされているかを示します。ディメンションをシャーディングしない他の軸はすべて、暗黙的または明示的に(複製された軸のリストに含まれている場合)複製されます。

このシャーディングがバインドされているメッシュは、対応する MeshOp シンボルを参照するシンボル名、またはインライン化された MeshAttr で指定できます。

制約:

  • dim_shardings の要素は、DimensionShardingAttr に記載されている制約を満たす必要があります。
  • replicated_axes の要素は、AxisRefListAttr に記載されている制約を満たす必要があります。
  • 対応するテンソル型が ShapedType でない場合は、シャーディングがランク 0 で、複製された軸がない必要があります。
  • テンソルはランクを持つ必要があります。
  • ディメンション シャーディングの数は、テンソルのランクと同じです。
  • サイズ 0 のディメンションはシャーディングされません。
  • replicated_axes 内のアイテムは mesh_or_ref を基準に並べ替えられます(AxisRefAttr::getMeshComparator を参照)。

パラメータ:

パラメータ C++ 型 説明
mesh_or_ref ::mlir::Attribute メッシュ属性またはフラットメッシュ シンボル参照属性
dim_shardings ::llvm::ArrayRef<DimensionShardingAttr> ディメンション シャーディング
replicated_axes ::llvm::ArrayRef<AxisRefAttr> 軸の参照

TensorShardingPerValueAttr

オペランド/オペレーションの結果ごとのテンソル シャーディング

構文:

#sdy.sharding_per_value<
  ::llvm::ArrayRef<TensorShardingAttr>   # shardings
>

TensorShardingAttr のリスト(オペレーションの各オペランドまたは結果に 1 つ)。

制約:

  • shardings の要素は TensorShardingAttr の制約を満たす必要があります。

パラメータ:

パラメータ C++ 型 説明
shardings ::llvm::ArrayRef<TensorShardingAttr> 値ごとのシャーディング

列挙型

PropagationDirection

伝播方向の列挙型

Cases:

記号 文字列
なし 0 なし
転送 1 転送
BACKWARD 2 BACKWARD
両方 3 両方