Shardy(SDY)言語では、軸ベースのテンソル シャーディング表現と、シャーディングをテンソルに関連付ける追加の API コンポーネントを定義します。
運用
sdy.constant
(sdy::ConstantOp)
定数オペレーション
定数 value
から output
テンソルを生成します。
参照: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
例:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
特性: AlwaysSpeculatableImplTrait
インターフェース: ConditionallySpeculatable
、InferTypeOpInterface
、NoMemoryEffect (MemoryEffectOpInterface)
影響: MemoryEffects::Effect{}
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
value | ::mlir::ElementsAttr | 定数ベクトル/テンソル属性 |
結果:
結果 | 説明 |
---|---|
output |
任意の型の値のテンサー |
sdy.data_flow_edge
(sdy::DataFlowEdgeOp)
データフロー エッジ オペレーション。
構文:
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
ある演算 X のデータフロー エッジは、一連のソース(それぞれ X のオペランドまたは X のブロック終端子のオペランド)と一連のターゲット(それぞれ X の結果または X のブロック引数)間のブリッジを定義します。これにより、すべてのソースとターゲットが同じ方法でシャーディングされるようにします。
オペレーションには、互いに直交する複数のデータフローのエッジを含めることができます。
例:
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
この while オペレーションには n 個のデータフロー エッジがあり、i 番目のデータフロー エッジはソース x_i
、return_value_i
とターゲット y_i
、pred_arg_i
、body_arg_i
の間にあります。
sdy.data_flow_edge
は、エッジのルート ターゲットを入力として受け取ります(任意のターゲットを使用できますが、ブロック引数ではなくオペレーションの結果を使用することをおすすめします)。このターゲットは他の用途に使用しないでください。このオペレーションは、元々使用されていなかった入力を受け取ることができるため、純粋ではありません。
また、sdy.data_flow_edge
はエッジのすべてのターゲットに対してオプションのシャーディングを保持し、伝播中にターゲットのシャーディング(接続できる場合)ではなくシャーディングを更新する必要があります。これは、オペレーションにエッジが多数ある場合に便利です。次のように、はるかに効率的です。
- 各エッジを個別に伝播します。
- すべてのターゲットを一度に更新するのではなく、各エッジのシャーディングを個別に更新します(たとえば、オペレーションに結果シャーディング用の不変の
TensorShardingPerValueAttr
が 1 つあります)。 - ソースのシャーディングが変更されたときに、各エッジをワークリストに個別に追加します。
伝播では、sdy.data_flow_edge
のすべてのソースとターゲット間でシャーディングが伝播されます。これは、ソースがオペランドで、ターゲットが結果で、ID が sdy.op_sharding_rule
である通常のオペレーションの場合と同じです。つまり、順方向の伝播はソースからターゲットへの伝播であり、逆方向の伝播はターゲットからソースへの伝播です。
sdy.data_flow_edge
の入力を SdyDialect
演算で定義することはできないため、sdy.sharding
属性が未登録の演算によって定義されると想定できます。
特性: SameOperandsAndResultType
インターフェース: InferTypeOpInterface
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
オペランド | 説明 |
---|---|
input |
任意の型の値の形状 |
結果:
結果 | 説明 |
---|---|
result |
任意の型の値の形状 |
sdy.manual_computation
(sdy::ManualComputationOp)
手動グループによるマルチデバイス並列処理オペレーション
構文:
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
明示的な集約を使用してデバイスごとのローカルコードで記述されたリージョンにジャンプします。ここでは、論理シェイプがデバイスごとのローカル物理バッファ シェイプと一致し、集約は物理的なクロスデバイス通信と完全に一致します。
本文は manual_axes に基づいてローカルです。伝播は、manual_axes リストにない自由軸のボディを介して行われます。
特性: IsolatedFromAbove
、RecursiveMemoryEffects
、SingleBlockImplicitTerminator<ReturnOp>
、SingleBlock
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 演算のオペランド/結果ごとのテンソルのシャーディング |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 演算のオペランド/結果ごとのテンソルのシャーディング |
manual_axes | ::mlir::sdy::ManualAxesAttr |
オペランド:
オペランド | 説明 |
---|---|
tensors |
任意の型の値のランク付けされたテンソルの可変引数 |
結果:
結果 | 説明 |
---|---|
results |
任意の型の値のランク付けされたテンソルの可変引数 |
sdy.mesh
(sdy::MeshOp)
名前付きメッシュ
構文:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
新しい名前付きメッシュを定義します。モジュール内のすべてのメッシュのデバイス数は同じである必要があります(device_id が 1 つのメッシュを除く)。メッシュは、モジュールの SymbolTable
に表示される Symbol
オペレーションで、その name
によって参照できます。
特性: HasParent<ModuleOp>
インターフェース: Symbol
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
sym_name | ::mlir::StringAttr | 文字列属性 |
mesh | ::mlir::sdy::MeshAttr | 軸のメッシュとデバイスのリスト |
sdy.named_computation
(sdy::NamedComputationOp)
名前付き計算オペレーション
構文:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
計算(オペレーションのブロック)をグループ化し、名前を付けます。伝播は、すべてがインライン化されているかのように、リージョンの内外を流れます。
これは、呼び出し命令を介して他の関数に伝播する処理に使用できます。Shardy を使用する場合は、呼び出しオペレーションを sdy.named_computation
オペレーションに変換し、呼び出された関数の本体を named_computation
の本体に複製/コピーするインポート/エクスポート パスを記述する必要があります。
リージョンでの各ブロック引数の型と返される値の型は、op のオペランドの型と結果の型と同じでなければなりません。
例:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
特性: IsolatedFromAbove
、RecursiveMemoryEffects
、RecursivelySpeculatableImplTrait
、SingleBlockImplicitTerminator<ReturnOp>
、SingleBlock
インターフェース: ConditionallySpeculatable
、ShardableDataFlowOpInterface
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
name | ::mlir::StringAttr | 文字列属性 |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 演算のオペランド/結果ごとのテンソルのシャーディング |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | オペレーションのオペランド/結果ごとのテンソル シャーディング |
オペランド:
オペランド | 説明 |
---|---|
operands |
任意の型の可変引数 |
結果:
結果 | 説明 |
---|---|
「名前なし」 | 任意の型の可変引数 |
sdy.propagation_barrier
(sdy::PropagationBarrierOp)
伝播バリア オペレーション
構文:
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
このオペレーションは、入力として受け取った値を出力するアイデンティティ オペレーションのように動作します。ただし、伝播に関しては、特定の方向にのみ伝播が許可されます。
これにより、バリア演算の結果とそのオペランドの使用間でシャーディングが伝播されるのを防ぐことができます。
FORWARD
は、シャーディングがオペランドから結果にのみ流れることを意味します。BACKWARD
は、シャーディングが結果からオペランドにのみ流れることを意味します。NONE
は、このオペレーションを介してシャーディングを伝播できないことを意味します。- この op は冗長であるため、
BOTH
を指定できません。
トレイト: AlwaysSpeculatableImplTrait
、Elementwise
、SameOperandsAndResultType
インターフェース: ConditionallySpeculatable
、InferTypeOpInterface
、NoMemoryEffect (MemoryEffectOpInterface)
影響: MemoryEffects::Effect{}
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | 伝播方向の列挙型 |
オペランド:
オペランド | 説明 |
---|---|
input |
任意の型の値のランク付けされたテンソル |
結果:
結果 | 説明 |
---|---|
result |
任意の型の値のランク付けされたテンソル |
sdy.reshard
(sdy::ReshardOp)
テンソルを別のシャーディングに再シャーディングする
構文:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
指定されたシャーディングで入力テンソルを再シャーディングします。これは、入力テンソルの既存のシャーディングとは異なります。
ShardingConstraintOp と ReshardOp はどちらも、シャーディングをテンソルに接続します。有効期間は次のとおりです。
- シャーディングの伝播の前に、ユーザーによって ShardingConstraintOp が追加されます。
- シャーディング伝播は ShardingConstraintOp を使用します。シャーディングの伝播の結果に ShardingConstraintOp はありません。代わりに、必要に応じて ReshardOp を追加できます。
- パーティショナーは、ReshardOp を集約オペレーション(または ID オペレーション)に変換します。パーティショナーの結果に ReshardOp が含まれていないこと。
// TODO(b/331680067). 重複する // シャード変更オペレーションを削除するために、正規化パターンを追加しました。
特性: AlwaysSpeculatableImplTrait
、Elementwise
、SameOperandsAndResultType
インターフェース: ConditionallySpeculatable
、InferTypeOpInterface
、NoMemoryEffect (MemoryEffectOpInterface)
影響: MemoryEffects::Effect{}
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
オペランド | 説明 |
---|---|
input |
任意の型の値のテンサー |
結果:
結果 | 説明 |
---|---|
result |
任意の型の値のテンソル |
sdy.return
(sdy::ReturnOp)
sdy.return
オペレーションは、sdy
リージョンベースのオペレーションと他の Shardy リージョンベースのオペレーションにアタッチされているリージョンを終了します。可変長です。引数として、型が任意(ただし同じ種類、例: AnyTensor
)の値のリストを受け取るため、Shardy IR スタックのさまざまなレベルで再利用できます。
構文:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
トレイト: AlwaysSpeculatableImplTrait
、Terminator
インターフェース: ConditionallySpeculatable
、NoMemoryEffect (MemoryEffectOpInterface)
影響: MemoryEffects::Effect{}
オペランド:
オペランド | 説明 |
---|---|
results |
任意の型の可変引数 |
sdy.sharding_constraint
(sdy::ShardingConstraintOp)
テンソルを指定されたシャーディングに制約する
構文:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
中間テンソル(matmul の結果など)にシャーディングをアタッチして、そのテンソルまたはその使用のサブセットをシャーディングする方法を示します。
シャーディングにオープンなディメンションと制約のない軸がある場合、テンソルをオープンなディメンションに沿ってさらにシャーディングできます。
この演算では次のいずれかを実行できます。
- 使用されていない(ダングリング)- つまり、接続されたシャーディングは、入力テンソル自体がシャーディングされる方法です。
- 使用あり - 接続されたシャーディングは、シャーディング制約オペレーションの使用方法をシャーディングする方法ですが、入力テンソルの他の使用方法は異なるシャーディングを持つ場合があります(入力テンソルに他の使用方法がない場合は、使用なしの場合と同じ動作になります)。
特性: Elementwise
、SameOperandsAndResultType
インターフェース: InferTypeOpInterface
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | テンソル シャーディング |
オペランド:
オペランド | 説明 |
---|---|
input |
任意の型の値のテンサー |
結果:
結果 | 説明 |
---|---|
result |
任意の型の値のテンソル |
sdy.sharding_group
(sdy::ShardingGroupOp)
シャーディング グループ オペレーション
構文:
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
このオペレーションは、シャーディング グループ(同じシャーディングが適用されるテンソルのグループ)にテンソルを割り当てるインターフェースを提供します。伝播中に、1 つのグループ要素がシャーディングされるとすぐに、他のすべてのメンバーがまったく同じようにシャーディングされます。このオペレーションは、引数のグループ ID を受け取り、結果を返しません。代わりに、内部シャーディング グループ表現を変更して、指定された ID のグループに入力テンソルを追加します。
属性:
属性 | MLIR タイプ | 説明 |
---|---|---|
group_id | ::mlir::IntegerAttr | 64 ビットの符号なし整数属性 |
オペランド:
オペランド | 説明 |
---|---|
input |
任意の型の値のランク付けされたテンソル |
属性
AxisRefAttr
完全な軸または分割されたサブ軸への参照
構文:
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
name | ::llvm::StringRef |
name |
sub_axis_info | SubAxisInfoAttr |
DimMappingAttr
ディメンションの要素インデックスのリスト
すべての要素インデックスは [0, num_factors] の範囲内にある必要があります。空のリストは、これが null マッピングであることを示します(これは *
で解析/出力されます)。つまり、ディメンションはどの要素にもマッピングされていません。
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
factor_indices | ::llvm::ArrayRef<int64_t> |
DimensionShardingAttr
ディメンション シャーディング
テンソル ディメンションをメジャーからマイナーにシャーディングする軸名のリスト、ディメンションをさらにシャーディングできるかどうかを示すブール値、シャーディング伝播時に考慮されるこのディメンション シャーディングのパリティを示す整数(省略可)。優先度はユーザー シャーディング アノテーションから取得され、値が小さいほど優先度が高くなります。アノテーションに優先度がない場合、最も高い優先度と見なされます。
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
軸 | ::llvm::ArrayRef<AxisRefAttr> |
軸参照のリスト |
is_closed | bool |
|
priority | std::optional<int64_t> |
ManualAxesAttr
構文:
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
値 | ::llvm::ArrayRef<StringAttr> |
MeshAttr
軸のメッシュとデバイスのリスト
構文:
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
メッシュは、軸のリストと、デバイスの順序を指定するデバイス ID のリスト(省略可)です。
軸のリストが空の場合、メッシュにはサイズ 1 の無名の軸が暗黙的に設定されます。この場合、デバイス ID リストが指定されていない場合、暗黙のデバイス ID リストは [0] になります。デバイス ID リストを指定する場合は、負でない値を 1 つ含む必要があります。これを最大シャーディング ケースと呼びます。
最大シャーディング以外のすべてのケースで、デバイス ID リストが指定されている場合、軸サイズの積はデバイス数と一致する必要があります。デバイス ID リストが指定されていない場合、暗黙のデバイス ID リストは iota(product(axes)) です。簡素化のため、iota(product(axes)) と同じデバイス ID リストの指定も禁止されています。この場合、デバイス ID リストは指定しないでください。
メッシュの例を次に示します。
- 空のメッシュは、伝播中に置き換え可能なプレースホルダ メッシュを表します。<[]>
- 名前のない軸と明示的なデバイス ID を持つメッシュ。通常は、最大のシャーディングを表すために使用されます。<[], device_ids=[3]>
- 2 つの軸と暗黙のデバイス ID を持つメッシュ iota(6): <["a"=2, "b"=3]>
- 2 つの軸と、デバイスの順序を指定する明示的なデバイス ID を持つメッシュ: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
軸 | ::llvm::ArrayRef<MeshAxisAttr> |
|
device_ids | ::llvm::ArrayRef<int64_t> |
MeshAxisAttr
メッシュ内の名前付き軸
構文:
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
name | ::llvm::StringRef |
name |
サイズ | int64_t |
OpShardingRuleAttr
オペレーションのパーティショニング方法を指定します。
構文:
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
bool # is_custom_rule
>
シャーディング ルールは、オペレーションのさまざまなプロパティ(属性、オペランドの形状、結果の形状など)に応じてオペレーションをパーティショニングする方法を指定するものです。次に例を示します。
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
サイズ 1 の因子はシャーディングできませんが、この因子は完全性のために許可されています。これは、点オペレーションなどの多くのオペレーションで、オペランドと結果にわたって対応するサイズ 1 のディメンションがあるためです。
is_custom_rule
は、これが stablehlo.custom_call
オペレーションに対してユーザーが定義したルールかどうかを記述します。パーティショナーはこれらのオペレーションをパーティショニングする方法を知らないので、ユーザーが方法を指定する必要があります。カスタムルールの場合、ルールは常に保持され、削除されることはありません。is_custom_rule
は stablehlo.custom_call
オペレーションでのみ true にできます。
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
factor_sizes | ::llvm::ArrayRef<int64_t> |
|
operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
|
result_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
|
is_custom_rule | bool |
SubAxisInfoAttr
このサブ軸が完全な軸からどのように導出されたかに関する情報
構文:
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
完全な軸を n 個のサブ軸に分割すると、軸は [k_1,...,k_n] に再形成されます。i 番目のサブ軸は、左側のすべての軸サイズ m=prod(k_1,...,k_(i-1))
(事前サイズ)とサイズ k_i の積で表すことができます。したがって、サブ軸情報属性にはこれらの 2 つの数値が保持され、次のように表されます。プリサイズ m とサイズ k の場合は (m)k
です。
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
pre_size | int64_t |
|
サイズ | int64_t |
TensorMappingAttr
テンソルの各ディメンションの因子マッピング。
構文:
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
TensorShardingAttr
テンソル シャーディング
構文:
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
テンソル シャーディングは特定のメッシュにバインドされ、そのメッシュの軸名のみを参照できます。ディメンションのシャーディングでは、テンソルの各次元と、その軸(またはサブ軸)をメジャーからマイナーにシャーディングします。ディメンションをシャーディングしない他の軸はすべて、暗黙的または明示的に(複製された軸のリストに含まれている場合)複製されます。
このシャーディングがバインドされているメッシュは、対応する MeshOp
シンボルを参照するシンボル名、またはインライン化された MeshAttr
で指定できます。
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
mesh_or_ref | ::mlir::Attribute |
メッシュ属性またはフラットメッシュ シンボル参照属性 |
dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> |
|
replicated_axes | ::llvm::ArrayRef<AxisRefAttr> |
軸参照のリスト |
TensorShardingPerValueAttr
オペレーションのオペランド/結果ごとのテンソル シャーディング
構文:
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
パラメータ:
パラメータ | C++ 型 | 説明 |
---|---|---|
シャーディング | ::llvm::ArrayRef<TensorShardingAttr> |
列挙型
PropagationDirection
伝播方向の列挙型
Cases:
記号 | 値 | 文字列 |
---|---|---|
なし | 0 |
なし |
転送 | 1 |
転送 |
BACKWARD | 2 |
BACKWARD |
両方 | 3 |
両方 |