Shardy

Shardy は、すべての方言に対応した MLIR ベースのテンソル分割システムです。GSPMD チームと PartIR チームの両方のコラボレーションから構築されたこのシステムは、両方のシステムの長所と、両方のチームとユーザーの共有された経験を組み込んでいます。

利点

  • GSPMD の伝播と PartIR の増分パーティショニングを組み合わせることで、ユーザーの制御と予測可能性が向上します。
  • 共有された経験に基づく新機能(たとえば、ユーザーが回避方法を知らない限り、余分なコミュニケーションが発生することが多いシェイプ変更の新しいサポート)。
  • 軸ベースのシャーディング表現を使用するなど、ユーザビリティとデバッグ可能性を向上させてエンドユーザーの速度を向上させます。
  • MLIR を使用するシンプルなオープンソース コードベース。ユーザーをサポートする幅広いアクティブなコントリビューター(内部、外部、さまざまなタイムゾーン)が参加しています。

コンポーネント

  • シャーディング表現: 特定の論理メッシュ(複数のメッシュの中から)にバインドされる軸ベースのシャーディング表現。制約付きディメンション シャーディングと軸、再シェイプなどのオペレーション用の軸の分割、増分パーティショニングの優先度などをサポートします。
  • コンパイラ API: シャーディング表現とともに使用してシャーディング伝播に影響を与えることができる一連のコンパイラ コンポーネント。
    • 入力 / 出力のシャーディング - メイン関数の入力または出力にシャーディングを接続して、関数に渡すときや関数から返すときに入力 / 出力テンソルをシャーディングする方法であることを示します。
    • シャーディング制約 - 中間テンソル(matmul の結果など)にシャーディングを適用して、そのテンソルまたはその使用のサブセットをシャーディングする方法を示します。
    • シャーディング アズ/シャーディング ライク - 複数のテンサーを ID でグループ化し、同じ方法でシャーディングする必要があることを示します。
    • 手動計算 - メッシュ軸のサブセットを使用して手動でパーティショニングされたサブ計算を囲みます。ここで、これらの手動軸に沿ったシャーディングがすべての入力と出力に指定され、サブ計算内でテンソル型はこれらのシャーディングに対してローカルになります。
  • シャーディング伝播: ユーザーの優先度とシャーディング制約を、コンパイラの費用モデルとヒューリスティクスと組み合わせた伝播アルゴリズム。
    • ユーザー定義の優先度(バッチ並列処理してから ZeRO など)
    • オペレーションベースの優先度(要素ごとのオペレーションが先、次に matmul など)。
    • よりきめ細かいヒューリスティック(バッチ ディメンションを優先するなど)。
  • SPMD パーティショナー: プログラムを SPMD プログラムにパーティショニングし、必要なデータ移動/フォーマットと集約オペレーションをプロセスに追加することで、シャーディング伝播の決定を軽減するコンポーネント。
    • 短期的には、最初の実装では現在の GSPMD SPMD パーティショナーが使用されます。
    • 長期的には、新しい MLIR ベースの SPMD パーティショナーを作成することを計画しています。

コード リポジトリ

Shardy プロジェクトは現在も開発中であり、オープンソース コミュニティからのフィードバックをお待ちしています。Shardy コードは https://github.com/openxla/shardy で入手できます。