Compiler API

背景

読者は、少なくともシャーディング表現の基本を理解していることを前提としています。これは、テンソルのシャーディングを Shardy で表現する方法について説明しています。このドキュメントでは、シャーディング表現をプログラムで使用する方法(プログラムの特定のテンサーにシャーディングを適用する方法など)について説明します。

シャーディング伝播は、テンソルのサブセットに対するシャーディング制約が指定されているプログラム内のすべてのテンソルに対してシャーディングを決定するプロセスです。Shardy のコンパイラ API には、シャーディング伝播に影響を与えたり、制御したりするための複数の方法が公開されています。また、ユーザーは手動でシャーディングされた計算をプログラムに挿入できます。

目標

このドキュメントでは、Shardy のこのような API コンポーネントの設計と、その動作と不変性について説明します。この API はシャーディング伝播の制御に使用されますが、このドキュメントでは伝播の動作や設計方法については説明しません。

概要

  • 入力/出力シャーディング - メイン関数の入力または出力にシャーディングを接続し、関数に渡すとき/関数から返すときに入力/出力テンソルをシャーディングする方法であることを示します。

  • シャーディング制約 - 中間テンソル(matmul の結果など)にシャーディングを適用して、そのテンソルまたはその使用のサブセットをシャーディングする方法を示します。

  • シャーディング グループ - 複数のテンサーを ID でグループ化し、同じ方法でシャーディングする必要があることを示します。

  • 手動計算 - メッシュ軸のサブセットを使用して手動でパーティショニングされたサブ計算を囲みます。ここで、これらの手動軸に沿ったシャーディングがすべての入力と出力に指定され、サブ計算内でテンソル型はこれらのシャーディングに対してローカルになります。

詳細な設計

入出力シャーディング

ユーザーがメイン関数の入出力のシャーディングを指定できるようにします。

MLIR では、関数の引数と結果に属性を適用できるため、ユーザーはシャーディング属性を関数に適用できます。

例:

@mesh_xy = <["x"=2, "y"=2]>

// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
            {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
            %arg1: tensor<8x16xf32>)
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
  ...
}

シャーディング制約

ユーザーはプログラムの中間テンソルにシャーディングを適用できます。これにより、そのテンソルまたはその使用のサブセットをシャーディングする方法がパーティショナーに通知されます。

これは、テンソルを入力として受け取り、シャーディング属性が付加された MLIR オペレーションです。オペレーションは次のいずれかです。

  • 使用されていない(ダングリング) - つまり、アタッチされたシャーディングは、テンソル自体がシャーディングされる方法です。
  • 使用がある - つまり、適用されたシャーディングは、シャーディング制約オペレーションの使用方法をシャーディングする方法ですが、入力テンソルの他の使用方法は異なるシャーディングを持つ場合があります(入力テンソルに他の使用方法がない場合は、使用なしの場合と同じ動作になります)。伝播では、テンソル自体のシャーディングが決定され、必要に応じてシャーディングが再作成されます。

オープン ディメンション シャーディングを使用できます。つまり、オペランドは使用可能な軸に沿ってさらにシャーディングできます。

@mesh_xy = <["x"=2, "y"=2]>

%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>

シャーディング グループ

2 つ以上のテンソル間にデータ依存関係や強いデータ依存関係がないが、それらのテンソルを同じ方法または類似の方法でパーティショニングする必要がある場合、Shardy API ではこの関係を指定できます。これにより、テンソルを互いにパーティショニングする必要があることをユーザーが自由に明示的に指定できます。

これを実現するために、シャード グループという概念を導入します。各グループには、同じシャード グループ ID に関連付けられた任意の数の命令が含まれます。シャーディング グループは、同じグループ内のシャーディングが同じになるようにします。

たとえば、次のような架空のユーザー プログラムでは、プログラムの出力をプログラムの入力とまったく同じ方法でシャーディングし、2 つの間にデータ依存関係がないようにします。

このプログラムを実行すると、シャーディング伝播はテンソル %1%2 のシャーディングを推論できず、最終的に複製されます。ただし、入力 %0 と出力 %2 が同じ shard_group 内にあることを示す shard_group 属性を接続すると、シャーディング @mesh_xy, [{"x"},{"y"}]> を入力 %0 から出力 %2 に伝播し、さらに残りのグラフに伝播できます。ここでは、定数 %1 がブロードキャストされます。sdy.sharding_group オペレーションを使用して、グループに値を割り当てることができます。

@mesh_xy = <["x"=2, "y"=2]>

module @"jit_zeros_like" {
  func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
    %0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
    %1 = stablehlo.constant dense<0> : tensor<8x2xi64>
    %2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
    return %2 : tensor<8x2xi64>
  }
}

上記の単純な例では、入力と同じシャーディングを出力に明示的に指定することもできます。これにより、入力に割り当てるシャーディングが事前にわかっているため、同じ効果が得られます。ただし、より現実的なケースでは、シャーディングを使用して、複数のテンサーのシャーディングを同期させます。シャーディングは必ずしも必要ありませんが、Shardy が残りの処理を行い、割り当てる最適なシャーディングを見つけます。

手動計算

ユーザーは、計算の一部をどのように分割するか、どのような集合を使用するかを明示的に制御したい場合があります。たとえば、コンパイラにデリゲートするのではなく、集約 matmul を手動で(フロントエンド API から)適用したいというユーザーもいます。Google は、そのようにすることを可能にする手動計算 API を提供しています。

これは、手動サブ計算用の単一リージョンを持つ MLIR オペレーションです。ユーザーは、メッシュ軸のサブセット(すべてを含む場合もあります)を使用して、このサブ計算への入出力シャーディングを指定します。サブ計算は、指定されたメッシュ軸(手動軸)に対してローカル/手動で、指定されていない軸(フリー軸)に対してグローバル/パーティショニングなしになります。このオペレーションの外部で計算できる方法と同様に、このサブ計算は伝播中に自由軸に沿ってさらにシャーディングできます。

例:

@mesh_name = <["data"=2, "model"=2]>

%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
    in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
    out_shardings=[<@mesh_name, [{"data"}, {?}]>]
    manual_axes={"data"}
    (%arg1: tensor<8x32xf32>) {
  // body
  return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

不変量

  1. すべての in_shardingsout_shardingsmanual_axes は同じメッシュを参照する必要があります。manual_axes はメッシュに対して並べ替えられます。

  2. manual_axes は、インバウンド シャーディングとアウトバウンド シャーディングですべて明示的に使用する必要があります。つまり、シャーディングごとに、すべての手動軸でディメンションをシャーディングするか、明示的にレプリケートする必要があります。

  3. いずれかのインバウンド/アウトバウンド シャーディングにフリー 軸(manual_axes にないメッシュ軸)が存在する場合、その軸は同じディメンション シャーディング内の手動軸よりもマイナーである必要があります(上記の例では、ディメンション シャーディング {"model", "data"} は無効です)。

  4. 計算のリージョン/ボディはローカル計算です(ユーザー指定の集合など)。手動軸に沿ったインバウンド/アウトバウンド シャーディングに関してローカルである必要があります(上記の注記を参照)。

手動計算のネスト

複数の手動計算を互いにネストできます。ただし、各計算が独自の手動軸セットで動作している必要があります。