「sdy」方言

Shardy(SDY)言語では、軸ベースのテンソル シャーディング表現と、シャーディングをテンソルに関連付ける追加の API コンポーネントを定義します。

運用

sdy.constant(sdy::ConstantOp)

定数オペレーション

定数 value から output テンソルを生成します。

参照: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant

例:

%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>

特性: AlwaysSpeculatableImplTrait

インターフェース: ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影響: MemoryEffects::Effect{}

属性:

属性MLIR タイプ説明
value::mlir::ElementsAttr定数ベクトル/テンソル属性

結果:

結果 説明
output 任意の型の値のテンサー

sdy.data_flow_edge(sdy::DataFlowEdgeOp)

データフロー エッジ オペレーション。

構文:

operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)

ある演算 X のデータフロー エッジは、一連のソース(それぞれ X のオペランドまたは X のブロック終端子のオペランド)と一連のターゲット(それぞれ X の結果または X のブロック引数)間のブリッジを定義します。これにより、すべてのソースとターゲットが同じ方法でシャーディングされるようにします。

オペレーションには、互いに直交する複数のデータフローのエッジを含めることができます。

例:

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

この while オペレーションには n 個のデータフロー エッジがあり、i 番目のデータフロー エッジはソース x_ireturn_value_i とターゲット y_ipred_arg_ibody_arg_i の間にあります。

sdy.data_flow_edge は、エッジのルート ターゲットを入力として受け取ります(任意のターゲットを使用できますが、ブロック引数ではなくオペレーションの結果を使用することをおすすめします)。このターゲットは他の用途に使用しないでください。このオペレーションは、元々使用されていなかった入力を受け取ることができるため、純粋ではありません。

また、sdy.data_flow_edge はエッジのすべてのターゲットに対してオプションのシャーディングを保持し、伝播中にターゲットのシャーディング(接続できる場合)ではなくシャーディングを更新する必要があります。これは、オペレーションにエッジが多数ある場合に便利です。次のように、はるかに効率的です。

  • 各エッジを個別に伝播します。
  • すべてのターゲットを一度に更新するのではなく、各エッジのシャーディングを個別に更新します(たとえば、オペレーションに結果シャーディング用の不変の TensorShardingPerValueAttr が 1 つあります)。
  • ソースのシャーディングが変更されたときに、各エッジをワークリストに個別に追加します。

伝播では、sdy.data_flow_edge のすべてのソースとターゲット間でシャーディングが伝播されます。これは、ソースがオペランドで、ターゲットが結果で、ID が sdy.op_sharding_rule である通常のオペレーションの場合と同じです。つまり、順方向の伝播はソースからターゲットへの伝播であり、逆方向の伝播はターゲットからソースへの伝播です。

sdy.data_flow_edge の入力を SdyDialect 演算で定義することはできないため、sdy.sharding 属性が未登録の演算によって定義されると想定できます。

特性: SameOperandsAndResultType

インターフェース: InferTypeOpInterface

属性:

属性MLIR タイプ説明
sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
input 任意の型の値の形状

結果:

結果 説明
result 任意の型の値の形状

sdy.manual_computation(sdy::ManualComputationOp)

手動グループによるマルチデバイス並列処理オペレーション

構文:

operation ::= `sdy.manual_computation` `(`operands`)`
              `in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
              `out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
              `manual_axes````=```$manual_axes
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:`
              functional-type(operands, results)

明示的な集約を使用してデバイスごとのローカルコードで記述されたリージョンにジャンプします。ここでは、論理シェイプがデバイスごとのローカル物理バッファ シェイプと一致し、集約は物理的なクロスデバイス通信と完全に一致します。

本文は manual_axes に基づいてローカルです。伝播は、manual_axes リストにない自由軸のボディを介して行われます。

特性: IsolatedFromAboveRecursiveMemoryEffectsSingleBlockImplicitTerminator<ReturnOp>SingleBlock

属性:

属性MLIR タイプ説明
in_shardings::mlir::sdy::TensorShardingPerValueAttr演算のオペランド/結果ごとのテンソルのシャーディング
out_shardings::mlir::sdy::TensorShardingPerValueAttr演算のオペランド/結果ごとのテンソルのシャーディング
manual_axes::mlir::sdy::ManualAxesAttr

オペランド:

オペランド 説明
tensors 任意の型の値のランク付けされたテンソルの可変引数

結果:

結果 説明
results 任意の型の値のランク付けされたテンソルの可変引数

sdy.mesh(sdy::MeshOp)

名前付きメッシュ

構文:

operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict

新しい名前付きメッシュを定義します。モジュール内のすべてのメッシュのデバイス数は同じである必要があります(device_id が 1 つのメッシュを除く)。メッシュは、モジュールの SymbolTable に表示される Symbol オペレーションで、その name によって参照できます。

特性: HasParent<ModuleOp>

インターフェース: Symbol

属性:

属性MLIR タイプ説明
sym_name::mlir::StringAttr文字列属性
mesh::mlir::sdy::MeshAttr軸のメッシュとデバイスのリスト

sdy.named_computation(sdy::NamedComputationOp)

名前付き計算オペレーション

構文:

operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
              (`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
              (`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:` functional-type($operands, results)

計算(オペレーションのブロック)をグループ化し、名前を付けます。伝播は、すべてがインライン化されているかのように、リージョンの内外を流れます。

これは、呼び出し命令を介して他の関数に伝播する処理に使用できます。Shardy を使用する場合は、呼び出しオペレーションを sdy.named_computation オペレーションに変換し、呼び出された関数の本体を named_computation の本体に複製/コピーするインポート/エクスポート パスを記述する必要があります。

リージョンでの各ブロック引数の型と返される値の型は、op のオペランドの型と結果の型と同じでなければなりません。

例:

%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

特性: IsolatedFromAboveRecursiveMemoryEffectsRecursivelySpeculatableImplTraitSingleBlockImplicitTerminator<ReturnOp>SingleBlock

インターフェース: ConditionallySpeculatableShardableDataFlowOpInterface

属性:

属性MLIR タイプ説明
name::mlir::StringAttr文字列属性
in_shardings::mlir::sdy::TensorShardingPerValueAttr演算のオペランド/結果ごとのテンソルのシャーディング
out_shardings::mlir::sdy::TensorShardingPerValueAttrオペレーションのオペランド/結果ごとのテンソル シャーディング

オペランド:

オペランド 説明
operands 任意の型の可変引数

結果:

結果 説明
「名前なし」 任意の型の可変引数

sdy.propagation_barrier(sdy::PropagationBarrierOp)

伝播バリア オペレーション

構文:

operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)

このオペレーションは、入力として受け取った値を出力するアイデンティティ オペレーションのように動作します。ただし、伝播に関しては、特定の方向にのみ伝播が許可されます。

これにより、バリア演算の結果とそのオペランドの使用間でシャーディングが伝播されるのを防ぐことができます。

  • FORWARD は、シャーディングがオペランドから結果にのみ流れることを意味します。
  • BACKWARD は、シャーディングが結果からオペランドにのみ流れることを意味します。
  • NONE は、このオペレーションを介してシャーディングを伝播できないことを意味します。
  • この op は冗長であるため、BOTH を指定できません。

トレイト: AlwaysSpeculatableImplTraitElementwiseSameOperandsAndResultType

インターフェース: ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影響: MemoryEffects::Effect{}

属性:

属性MLIR タイプ説明
allowed_direction::mlir::sdy::PropagationDirectionAttr伝播方向の列挙型

オペランド:

オペランド 説明
input 任意の型の値のランク付けされたテンソル

結果:

結果 説明
result 任意の型の値のランク付けされたテンソル

sdy.reshard(sdy::ReshardOp)

テンソルを別のシャーディングに再シャーディングする

構文:

operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)

指定されたシャーディングで入力テンソルを再シャーディングします。これは、入力テンソルの既存のシャーディングとは異なります。

ShardingConstraintOp と ReshardOp はどちらも、シャーディングをテンソルに接続します。有効期間は次のとおりです。

  1. シャーディングの伝播の前に、ユーザーによって ShardingConstraintOp が追加されます。
  2. シャーディング伝播は ShardingConstraintOp を使用します。シャーディングの伝播の結果に ShardingConstraintOp はありません。代わりに、必要に応じて ReshardOp を追加できます。
  3. パーティショナーは、ReshardOp を集約オペレーション(または ID オペレーション)に変換します。パーティショナーの結果に ReshardOp が含まれていないこと。

// TODO(b/331680067). 重複する // シャード変更オペレーションを削除するために、正規化パターンを追加しました。

特性: AlwaysSpeculatableImplTraitElementwiseSameOperandsAndResultType

インターフェース: ConditionallySpeculatableInferTypeOpInterfaceNoMemoryEffect (MemoryEffectOpInterface)

影響: MemoryEffects::Effect{}

属性:

属性MLIR タイプ説明
sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
input 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンソル

sdy.return(sdy::ReturnOp)

sdy.return オペレーションは、sdy リージョンベースのオペレーションと他の Shardy リージョンベースのオペレーションにアタッチされているリージョンを終了します。可変長です。引数として、型が任意(ただし同じ種類、例: AnyTensor)の値のリストを受け取るため、Shardy IR スタックのさまざまなレベルで再利用できます。

構文:

operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?

トレイト: AlwaysSpeculatableImplTraitTerminator

インターフェース: ConditionallySpeculatableNoMemoryEffect (MemoryEffectOpInterface)

影響: MemoryEffects::Effect{}

オペランド:

オペランド 説明
results 任意の型の可変引数

sdy.sharding_constraint(sdy::ShardingConstraintOp)

テンソルを指定されたシャーディングに制約する

構文:

operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)

中間テンソル(matmul の結果など)にシャーディングをアタッチして、そのテンソルまたはその使用のサブセットをシャーディングする方法を示します。

シャーディングにオープンなディメンションと制約のない軸がある場合、テンソルをオープンなディメンションに沿ってさらにシャーディングできます。

この演算では次のいずれかを実行できます。

  • 使用されていない(ダングリング)- つまり、接続されたシャーディングは、入力テンソル自体がシャーディングされる方法です。
  • 使用あり - 接続されたシャーディングは、シャーディング制約オペレーションの使用方法をシャーディングする方法ですが、入力テンソルの他の使用方法は異なるシャーディングを持つ場合があります(入力テンソルに他の使用方法がない場合は、使用なしの場合と同じ動作になります)。

特性: ElementwiseSameOperandsAndResultType

インターフェース: InferTypeOpInterface

属性:

属性MLIR タイプ説明
sharding::mlir::sdy::TensorShardingAttrテンソル シャーディング

オペランド:

オペランド 説明
input 任意の型の値のテンサー

結果:

結果 説明
result 任意の型の値のテンソル

sdy.sharding_group(sdy::ShardingGroupOp)

シャーディング グループ オペレーション

構文:

operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)

このオペレーションは、シャーディング グループ(同じシャーディングが適用されるテンソルのグループ)にテンソルを割り当てるインターフェースを提供します。伝播中に、1 つのグループ要素がシャーディングされるとすぐに、他のすべてのメンバーがまったく同じようにシャーディングされます。このオペレーションは、引数のグループ ID を受け取り、結果を返しません。代わりに、内部シャーディング グループ表現を変更して、指定された ID のグループに入力テンソルを追加します。

属性:

属性MLIR タイプ説明
group_id::mlir::IntegerAttr64 ビットの符号なし整数属性

オペランド:

オペランド 説明
input 任意の型の値のランク付けされたテンソル

属性

AxisRefAttr

完全な軸または分割されたサブ軸への参照

構文:

#sdy.axis_ref<
  ::llvm::StringRef,   # name
  SubAxisInfoAttr   # sub_axis_info
>

パラメータ:

パラメータ C++ 型 説明
name ::llvm::StringRef name
sub_axis_info SubAxisInfoAttr

DimMappingAttr

ディメンションの要素インデックスのリスト

すべての要素インデックスは [0, num_factors] の範囲内にある必要があります。空のリストは、これが null マッピングであることを示します(これは * で解析/出力されます)。つまり、ディメンションはどの要素にもマッピングされていません。

パラメータ:

パラメータ C++ 型 説明
factor_indices ::llvm::ArrayRef<int64_t>

DimensionShardingAttr

ディメンション シャーディング

テンソル ディメンションをメジャーからマイナーにシャーディングする軸名のリスト、ディメンションをさらにシャーディングできるかどうかを示すブール値、シャーディング伝播時に考慮されるこのディメンション シャーディングのパリティを示す整数(省略可)。優先度はユーザー シャーディング アノテーションから取得され、値が小さいほど優先度が高くなります。アノテーションに優先度がない場合、最も高い優先度と見なされます。

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<AxisRefAttr> 軸参照のリスト
is_closed bool
priority std::optional<int64_t>

ManualAxesAttr

構文:

#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr>   # value
>

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<StringAttr>

MeshAttr

軸のメッシュとデバイスのリスト

構文:

#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>,   # axes
  ::llvm::ArrayRef<int64_t>   # device_ids
>

メッシュは、軸のリストと、デバイスの順序を指定するデバイス ID のリスト(省略可)です。

軸のリストが空の場合、メッシュにはサイズ 1 の無名の軸が暗黙的に設定されます。この場合、デバイス ID リストが指定されていない場合、暗黙のデバイス ID リストは [0] になります。デバイス ID リストを指定する場合は、負でない値を 1 つ含む必要があります。これを最大シャーディング ケースと呼びます。

最大シャーディング以外のすべてのケースで、デバイス ID リストが指定されている場合、軸サイズの積はデバイス数と一致する必要があります。デバイス ID リストが指定されていない場合、暗黙のデバイス ID リストは iota(product(axes)) です。簡素化のため、iota(product(axes)) と同じデバイス ID リストの指定も禁止されています。この場合、デバイス ID リストは指定しないでください。

メッシュの例を次に示します。

  • 空のメッシュは、伝播中に置き換え可能なプレースホルダ メッシュを表します。<[]>
  • 名前のない軸と明示的なデバイス ID を持つメッシュ。通常は、最大のシャーディングを表すために使用されます。<[], device_ids=[3]>
  • 2 つの軸と暗黙のデバイス ID を持つメッシュ iota(6): <["a"=2, "b"=3]>
  • 2 つの軸と、デバイスの順序を指定する明示的なデバイス ID を持つメッシュ: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>

パラメータ:

パラメータ C++ 型 説明
::llvm::ArrayRef<MeshAxisAttr>
device_ids ::llvm::ArrayRef<int64_t>

MeshAxisAttr

メッシュ内の名前付き軸

構文:

#sdy.mesh_axis<
  ::llvm::StringRef,   # name
  int64_t   # size
>

パラメータ:

パラメータ C++ 型 説明
name ::llvm::StringRef name
サイズ int64_t

OpShardingRuleAttr

オペレーションのパーティショニング方法を指定します。

構文:

#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>,   # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>,   # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>,   # result_mappings
  bool   # is_custom_rule
>

シャーディング ルールは、オペレーションのさまざまなプロパティ(属性、オペランドの形状、結果の形状など)に応じてオペレーションをパーティショニングする方法を指定するものです。次に例を示します。

%0 = stablehlo.add %arg0, %arg1 {
    sdy.sharding_rule = #sdy.op_sharding_rule<
        ([i, j],[i, j])->([i, j])
        {i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>

サイズ 1 の因子はシャーディングできませんが、この因子は完全性のために許可されています。これは、点オペレーションなどの多くのオペレーションで、オペランドと結果にわたって対応するサイズ 1 のディメンションがあるためです。

is_custom_rule は、これが stablehlo.custom_call オペレーションに対してユーザーが定義したルールかどうかを記述します。パーティショナーはこれらのオペレーションをパーティショニングする方法を知らないので、ユーザーが方法を指定する必要があります。カスタムルールの場合、ルールは常に保持され、削除されることはありません。is_custom_rulestablehlo.custom_call オペレーションでのみ true にできます。

パラメータ:

パラメータ C++ 型 説明
factor_sizes ::llvm::ArrayRef<int64_t>
operand_mappings ::llvm::ArrayRef<TensorMappingAttr>
result_mappings ::llvm::ArrayRef<TensorMappingAttr>
is_custom_rule bool

SubAxisInfoAttr

このサブ軸が完全な軸からどのように導出されたかに関する情報

構文:

#sdy.sub_axis_info<
  int64_t,   # pre_size
  int64_t   # size
>

完全な軸を n 個のサブ軸に分割すると、軸は [k_1,...,k_n] に再形成されます。i 番目のサブ軸は、左側のすべての軸サイズ m=prod(k_1,...,k_(i-1))(事前サイズ)とサイズ k_i の積で表すことができます。したがって、サブ軸情報属性にはこれらの 2 つの数値が保持され、次のように表されます。プリサイズ m とサイズ k の場合は (m)k です。

パラメータ:

パラメータ C++ 型 説明
pre_size int64_t
サイズ int64_t

TensorMappingAttr

テンソルの各ディメンションの因子マッピング。

構文:

#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr>   # dim_mappings
>

パラメータ:

パラメータ C++ 型 説明
dim_mappings ::llvm::ArrayRef<DimMappingAttr>

TensorShardingAttr

テンソル シャーディング

構文:

#sdy.sharding<
  ::mlir::Attribute,   # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>,   # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr>   # replicated_axes
>

テンソル シャーディングは特定のメッシュにバインドされ、そのメッシュの軸名のみを参照できます。ディメンションのシャーディングでは、テンソルの各次元と、その軸(またはサブ軸)をメジャーからマイナーにシャーディングします。ディメンションをシャーディングしない他の軸はすべて、暗黙的または明示的に(複製された軸のリストに含まれている場合)複製されます。

このシャーディングがバインドされているメッシュは、対応する MeshOp シンボルを参照するシンボル名、またはインライン化された MeshAttr で指定できます。

パラメータ:

パラメータ C++ 型 説明
mesh_or_ref ::mlir::Attribute メッシュ属性またはフラットメッシュ シンボル参照属性
dim_shardings ::llvm::ArrayRef<DimensionShardingAttr>
replicated_axes ::llvm::ArrayRef<AxisRefAttr> 軸参照のリスト

TensorShardingPerValueAttr

オペレーションのオペランド/結果ごとのテンソル シャーディング

構文:

#sdy.sharding_per_value<
  ::llvm::ArrayRef<TensorShardingAttr>   # shardings
>

パラメータ:

パラメータ C++ 型 説明
シャーディング ::llvm::ArrayRef<TensorShardingAttr>

列挙型

PropagationDirection

伝播方向の列挙型

Cases:

記号 文字列
なし 0 なし
転送 1 転送
BACKWARD 2 BACKWARD
両方 3 両方