インデックス分析

このドキュメントでは、HLO オペレーションのインデックス マップをシンボル的に計算できる HLO インデックス分析について説明します。インデックス マップは、あるテンソルのインデックスを別のテンソルのインデックスにマッピングする関数です。たとえば、HLO 命令出力のインデックスから HLO 命令入力のインデックスへ、またはその逆を行います。

tensor<20xf32> から tensor<10x20x30xf32> へのブロードキャスト

p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}

出力から入力へのインデックス マップは、 \((i, j, k) \mapsto (j)\) for $i \in[0, 10]\(, \)j \in [0, 20]\( and \)k \in [0, 30]$ です。

目的

XLA GPU は、いくつかの独自のソリューションを使用して、結合、オペランド使用率、タイリング スキームを推測します(詳細は後述)。インデックス分析の目標は、そのようなユースケースに再利用可能なコンポーネントを提供することです。インデックス分析は MLIR の Affine Map インフラストラクチャ上に構築され、HLO セマンティクスが追加されています。

合体

出力の要素を計算するために入力のどの要素/スライスが読み取られるかがわかっている場合は、メモリ結合についての推論が可能になります。

オペランド使用率

XLA のオペランド使用率は、命令の各入力の使用量(その出力をすべて使用したと仮定)を示します。現在、一般的なケースでは使用率も計算されていません。インデックス分析により使用率を正確に計算できます

並べ方

タイル/スライスは、オフセット、サイズ、ストライドによってパラメータ化されたテンソルの超長方形のサブセットです。タイルの伝播は、オペレーション自体のタイリング パラメータを使用して、オペレーションのプロデューサー/コンシューマのタイル パラメータを計算する方法です。softmax と dot 用のライブラリはすでに存在します。タイルの伝播は、インデックス マップを使用して表現すると、より汎用的かつ堅牢にできます。

機能とドメイン

インデックス マップは、 \(\boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\)テンソルの \(\boldsymbol{d}\) マルチ インデックスを \(A\) テンソルの要素/範囲 \(A\) にマッピングする関数です。 \(B\)パラメータ \(\boldsymbol{s}\) は、テンソル \(B\)には存在し、テンソルには存在しないディメンションのインデックス範囲を指します。 \(A\)

たとえば、tensor<2x4x8x16xf32> から tensor<4x8xf32> にリダクションする場合、2D 出力から 4D 入力へのインデックス マップは\((d_0, d_1) \mapsto (s_0, d_0, d_1, s_1)\)になります。ここで、 \(d_i\) は出力テンソルのインデックスに対応するディメンション パラメータです。パラメータ \(s_j\)は複数の値をエンコードします。つまり、出力の \((d_0, d_1)\) 要素を計算するには、 \((s_0, d_0, d_1, s_1)\) 入力の要素が必要です。ここで、 \(s_0 \in [0, 2)\) と\(s_1 \in [0, 16)\)です。

このマッピングは、HLO 命令の属性から構築することも、融合されていない命令のマッピングを構成して、融合のインデックスを取得することもできます。マッピングにはドメインもあり、マッピングが存在するテンソルの要素を指定します。

\[ \begin{eqnarray} \boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\; &s.t.& \\ \boldsymbol{lb}_d &\leq& \boldsymbol{d} \leq \boldsymbol{ub}_d \\ \boldsymbol{lb}_s &\leq& \boldsymbol{s} \leq \boldsymbol{ub}_s \\ \boldsymbol{lb}_g &\leq& \boldsymbol{g}(\boldsymbol{d}, \boldsymbol{s}) \leq \boldsymbol{ub}_g \end{eqnarray} \]

再計算を最小限に抑えるため、シンボリック計算用のライブラリが必要です。XLA はすでに MLIR に依存しているため、シンボリック算術ライブラリを記述する代わりに mlir::AffineMap を使用します。

一般的な AffineMap は次のようになります。

(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)

AffineMap にはディメンションとシンボルという 2 種類のパラメータがあり、それぞれ \(\boldsymbol d\) と \(\boldsymbol s\) に使用できます。AffineMap にはディメンションの範囲に関するメタデータが含まれていないため、このデータを自分で指定する必要があります。

struct Range {
 int64_t lower_bound;
 int64_t upper_bound;
};

struct IndexingMap {
 mlir::AffineMap affine_map;
 std::vector<Range> dimension_ranges;
 std::vector<Range> symbol_ranges;
 llvm::DenseMap<mlir::AffineExpr, Range> expr_ranges;
};

dim_ranges はインデックス マップのディメンション パラメータの \(\boldsymbol{d}\) ボックス制約をエンコードします。これは通常、転置、reduce、要素単位、ドットなどの演算の出力テンソルの形状と一致しますが、HloConcatenateInstruction のような例外もあります。

symbol_ranges は、 \(\boldsymbol {s}\) パラメータが取り得る値をエンコードします。

上述の内容の実際の意味を理解するために、例を挙げて学習しましょう。

Unfused Ops のインデックス登録マップ

エレメントワイズ

要素単位のオペレーションでは、インデックス マップは ID です。

  p0 = f32[10, 20] parameter(0)
  p1 = f32[10, 20] parameter(1)
  add = f32[10, 20] add(p0, p1)

入力マップへの出力:

  • output -> input_0: \((d_0, d_1) \mapsto (d_0, d_1)\) $\boldsymbol{d} \in [0,9] \times [0, 19]\(, i.e. \)\boldsymbol{d} \in {\rm Dom}(output)$
  • output -> input_1: \((d_0, d_1) \mapsto (d_0, d_1)\) $\boldsymbol{d} \in {\rm Dom} (output)$

入力から出力へのマップ

  • input_i -> output: \((d_0, d_1) \mapsto (d_0, d_1)\) for $\boldsymbol{d} \in {\rm Dom}(output)$

ブロードキャスト

ブロードキャストとは、出力を入力にマッピングするときに一部のディメンションが削除され、入力と出力をマッピングするときに追加されます。

p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}

入力マップへの出力:

  • output -> input: \((d_0, d_1, d_2) \mapsto (d_1)\) $\boldsymbol{d} \in {\rm Dom}(output)$

入力から出力へのマップ

  • 入力 -> 出力: \((d_0) \mapsto (s_0, d_1, s_1)\) $\boldsymbol{d} \in {\rm Dom}(output)\( and \)\boldsymbol{s} \in [0, 9] \times [0, 29]$ の場合。

入力と出力のマッピングの右側に \(\boldsymbol s\) があります。これらは値の範囲を表す記号です。たとえば、この特定のケースでは、インデックス \(d_0\) を持つ入力のすべての要素が出力の 10x1x30 スライスにマッピングされます。

コンスタントおよびイオタ

便利なことに、インデックスには入力パラメータがないため、インデックスを計算する必要がありません。

行 / 列の入れ替え

転置のためのインデックス マップは、入力/出力ディメンションの順列です。

p0 = f32[3, 12288, 6, 128] parameter(0)
transpose = f32[3, 6, 128, 12288] transpose(p0), dimensions={0, 2, 3, 1}

入力マップへの出力:

  • output -> input: \((d_0, d_1, d_2, d_3) \mapsto (d_0, d_3, d_1, d_2)\) (\(\boldsymbol{d} \in {\rm Dom}(output)\)の場合)

入力と出力のマッピング:

  • 入力 -> 出力: \((d_0, d_1, d_2, d_3) \mapsto (d_0, d_2, d_3, d_1)\) (\(\boldsymbol{d} \in {\rm Dom}(input)\)の場合)

リバース

逆方向のインデックス登録マップは、元に戻すディメンションを $upper_bound(d_i) - d_i$ に変更します。

p0 = f32[1, 17, 9, 9] parameter(0)
reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2}

入力マップへの出力:

  • output -> 入力: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)\( for \)\boldsymbol{d} \in {\rm Dom}(output)$

入力と出力のマッピング:

  • 入力 -> 出力: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)\( for \)\boldsymbol{d} \in {\rm Dom}(input)$

(Variadic)Reduce

可変次元削減には複数の入力と複数の init があり、出力から入力へのマップは削減されたディメンションを追加します。したがって、ある意味、ブロードキャストの逆のように動作します。

p0 = f32[256,10] parameter(0)
p0_init = f32[] constant(-inf)
p1 = s32[256,10] parameter(1)
p1_init = s32[] constant(0)
reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init),
  dimensions={0}, to_apply=min

入力マップへの出力:

  • output -> input_j: \((d_0) \mapsto (s_0, d_0)\) $\boldsymbol{d} \in {\rm Dom}(output)\( and \)\boldsymbol{s} \in [0, 9]$
  • output -> init_j: \((d_0) \mapsto ()\) for $\boldsymbol{d} \in {\rm Dom}(output)$

入力から出力へのマップは次のとおりです。

  • input_i -> output_j: \((d_0, d_1) \mapsto (d_1)\) $\boldsymbol{d} \in {\rm Dom}(input)$
  • init_i -> output_j: \(() \mapsto (s_0)\) の \(\boldsymbol{s} \in [0, 9]\)

( \(i, j = 0, \ldots, INPUT\\_COUNT\))。

スライス

スライスの出力から入力へのインデックス作成は、出力のすべての要素に対して有効なインデックス付きインデックス マップになります。入力から出力へのマッピングは、入力内の要素のストライド範囲に制限されます。

p0 = f32[10, 20, 50] parameter(0)
slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0),
  slice={[5:10:1], [3:20:7], [0:50:2]}

入力マップへの出力:

  • output -> input: \((d_0, d_1, d_2) \mapsto (d_0 + 5, 7d_1 + 3, 2d_2)\) (\(\boldsymbol{d} \in {\rm Dom}(output)\)の場合)

入力と出力のマッピング:

  • 入力 -> 出力: \((d_0, d_1, d_2) \mapsto (d_0, d_1 / 7, d_2 / 2)\) (ストライドが $[1, 7, 2]$の\(\boldsymbol{d} \in [5, 9] \times [3, 19] \times [0, 49]\) )

未定: 入出力インデックス

Reshape

形状変更にはさまざまなフレーバーがあります。

図形を閉じる

これは N 次元から 1 次元への形状変更の「線形化」です。

p0 = f32[4,8] parameter(0)
reshape = f32[32] reshape(p0)

入力マップへの出力:

  • output -> input: \((d_0) \mapsto (d_0 / 8, d_0 \mod 8)\) $\boldsymbol{d} \in {\rm Dom}(output)$

入力と出力のマッピング:

  • input -> output: \((d_0, d_1) \mapsto (8 d_0 + d_1)\) $\boldsymbol{d} \in {\rm Dom}(input)$ に対応。

シェイプを開く

これは逆の「折りたたみシェイプ」オペレーションで、1 次元入力を N-D 出力に再形成します。

p0 = f32[32] parameter(0)
reshape = f32[4, 8] reshape(p0)

入力マップへの出力:

  • output -> input: \((d_0, d_1) \mapsto (8 d_0 + d_1)\) $\boldsymbol{d} \in {\rm Dom}(output)$

入力と出力のマッピング:

  • input -> output: \((d_0) \mapsto (d_0 / 8, d_0 \mod 8)\) $\boldsymbol{d} \in {\rm Dom}(input)$ に対応

一般的な形状変更

これらは、単一の展開シェイプまたは折りたたみシェイプとして表すことができない形状変更オペレーションです。2 つ以上の展開または折りたたみのシェイプで構成されるコンポジションとしてのみ表すことができます。

例 1: 線形化 - 非線形化。
p0 = f32[4,8] parameter(0)
reshape = f32[2, 4, 4] reshape(p0)

この形状変更は、tensor<4x8xf32> から tensor<32xf32> への折りたたみシェイプと、tensor<2x4x4xf32> へのシェイプ拡張の合成として表すことができます。

入力マップへの出力:

  • output -> 入力: $(d_0, d_1, d_2) \mapsto (2d_0 + (4d_1 + d_2) / 8, 4d_1 + d_2) \mod 8)$

\(\boldsymbol{d} \in {\rm Dom}(output)\)

入力と出力のマッピング:

  • 入力 -> 出力: $(d_0, d_1) \mapsto ((8d_0 + d_1) / 16, ((8d_0 + d_1) \mod 16) / 4, d_1 \mod 4)$

( \(\boldsymbol{d} \in {\rm Dom}(input)\))。

例 2: サブシェイプの展開と折りたたみ
p0 = f32[4, 8, 12] parameter(0)
reshape = f32[32, 3, 4] reshape(p0)

この形状の変更は、2 つの形状の変更の組み合わせとして表すことができます。1 つ目のメソッドでは、最も外側のディメンション tensor<4x8x12xf32>tensor<32x12xf32> に折りたたみ、2 つ目のメソッドでは、最も内側のディメンション tensor<32x12xf32>tensor<32x3x4xf32> に拡張しています。

入力マップへの出力:

  • output -> input: \((d_0, d_1, d_2) \mapsto (d_0 / 8, d_0 \mod 8, 4d_1 + d_2)\)( \(\boldsymbol{d} \in {\rm Dom}(output)\)の場合)

入力と出力のマッピング:

  • input -> output: \((d_0, d_1, d_2) \mapsto (8d_0 + d_1, d_2 / 4, d_2 \mod 4)\)( \(\boldsymbol{d} \in {\rm Dom}(input)\)の場合)

ビットキャスト

ビットキャスト演算は、transpose-reshape-transpose のシーケンスとして表すことができます。したがって、インデックス マップは、このシーケンスのインデックス マップを組み合わせたものです。

Concatenate

concat の出力から入力へのマッピングはすべての入力に対して定義されますが、ドメインが重複していません。つまり、一度に 1 つの入力のみが使用されます。

p0 = f32[3,50] parameter(0)
p1 = f32[3,30] parameter(1)
concat = f32[3,80] concatenate(f32[3,50] p0, f32[3,30] p1),
  dimensions={1}

入力マップへの出力:

  • 出力 -> 入力 1:

\((d_0, d_1) \mapsto (d_0, d_1)\) / \(\boldsymbol{d} \in [0, 2] \times [0, 49]\)

  • 出力 -> 入力 2:

\((d_0, d_1) \mapsto (d_0, d_1 - 50)\) $\boldsymbol{d} \in [0, 2] \times [50, 79]$

出力マップへの入力:

  • 入力 1 -> 出力: \((d_0, d_1) \mapsto (d_0, d_1)\) $\boldsymbol{d} \in {\rm Dom}(input_1)$ です。
  • 入力 2 -> 出力: \((d_0, d_1) \mapsto (d_0, d_1 + 50)\) $\boldsymbol{d} \in {\rm Dom}(input_2)$。

ドット(output-to-input が実装されている

ドットのインデックス登録マップは、Reduce のマップとよく似ています。

p0 = f32[4, 128, 256] parameter(0)
p1 = f32[4, 256, 64] parameter(1)
dot = f32[4, 128, 64] dot(p0, p1),
  lhs_batch_dims={0}, rhs_batch_dims={0},
  lhs_contracting_dims={2}, rhs_contracting_dims={1}

入力マップへの出力は次のとおりです。

  • output -> input_1: \((d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)\) (\(\boldsymbol{d} \in {\rm Dom}(output)\) と \(\boldsymbol{s} \in [0, 255]\)の場合)
  • \(\boldsymbol{d} \in {\rm Dom}(output)\) と \(\boldsymbol{s} \in [0, 255]\)の output -> input_2: \((d_0, d_1, d_2) \mapsto (d_0, s_0, d_2)\)

出力マップへの入力は次のとおりです。

  • input_1 -> output: \((d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)\) (\(\boldsymbol{d} \in {\rm Dom}(input_1)\) と \(\boldsymbol{s} \in [0, 63]\)の場合)
  • input_2 -> output: \((d_0, d_1, d_2) \mapsto (d_0, s_0, d_1)\) (\(\boldsymbol{d} \in {\rm Dom}(input_2)\) と \(\boldsymbol{s} \in [0, 127]\)の場合)

短縮ウィンドウ(未定)

パッド(未定)

Fusion 用マップのインデックス登録

fusion op のインデックス登録マップは、クラスタ内のすべての op のインデックス マップから構成されています。一部の入力は、異なるアクセス パターンで複数回読み取られることがあります。

1 つの入力、複数のインデックス マップ

\(p_0 + p_0^T\)の例を次に示します。

f {
  p0 = f32[1000, 1000] parameter(0)
  transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0}
  ROOT a0 = f32[1000, 1000] add(p0, transpose_p0)
}

p0 の出力から入力へのインデックス マップは $(d_0, d_1) \mapsto (d_0, d_1)\( and \)(d_0, d_1) \mapsto (d_1, d_0)$ となります。出力の 1 つの要素を計算するには、入力パラメータを 2 回読み取る必要があることを意味します。

1 つの入力、重複除去されたインデックス マップ

img

すぐには識別できない場合でも、インデックス マップが実際には同じである場合もあります。

f {
  p0 = f32[20, 10, 50] parameter(0)
  lhs_transpose_1 = f32[10, 20, 50] transpose(p0), dimensions={1, 0, 2}
  lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1)
  lhs_transpose_2 = f32[10, 50, 20] transpose(lhs_e), dimensions={0, 2, 1}
  rhs_transpose_1 = f32[50, 10, 20] transpose(p0), dimensions={2, 1, 0}
  rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1)
  rhs_transpose_2 = f32[10, 50, 20] transpose(rhs_log), dimensions={1, 0, 2}
  ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2)
}

この場合の p0 の出力から入力へのインデックス マップは、$(d_0, d_1, d_2) \mapsto (d_2, d_0, d_1)$ だけです。

ソフトマックス

img

出力から入力へのインデックスは、ソフトマックスの parameter 0 にマッピングされます。

  • \((d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)\)
  • \((d_0, d_1, d_2)[s_0] \mapsto (d_0, d_1, s_0)\)

for \(\boldsymbol{d} \in {\rm Dom}(output)\) と \(\boldsymbol{s} \in [0, 124]\)は、入力の最も内側のディメンションを指します。

Indexing Map Simplifier

mlir::AffineMap アップストリームのデフォルトの単純化では、次元/記号の範囲について想定することはできません。したがって、moddiv を使用して式を効率的に単純化することはできません。

アフィン マップ内のサブ式の下限と上限に関する知識を活用して、さらに簡素化できます。

単純化器は次の式を書き換えることができます。

  1. \((d_0, d_1) \mapsto (d_0 + d1 / 16, d1 \mod 16)\) $\boldsymbol{d} \in [0, 6] \times [0, 14]\( becomes \)(d_0, d_1) \mapsto (d_0, d_1)$
  2. $(d_0, d_1, d_2) \mapsto ((100d_0 + 10d_1 + d_2) /100, ((100d_0 + 10d_1 + d_2) \mod 100) / 10, d_2 d_0,\( becomes \)\mod 10)\( for \)
  3. $(d_0, d_1, d_2) \mapsto ((16d_0 + 4d_1 + d_2) /8, (16d_0 + 4d_1 + d_2) \mod 8)\( for \)d_i \in [0, 9]\( becomes \)(d_0, + 2_1, d_2)
  4. \((d_0, d_1) \mapsto (-(-11d_0 - d_1 + 109) / 11 + 9)\) : $\boldsymbol{d} \in [0, 9] \times [0, 10]\( becomes \)(d_0, d_1) \mapsto (d_0)$.

インデックス登録マップの簡素化により、HLO のチェーンされた形状変更の一部が相互にキャンセルされることを理解できます。

p0 = f32[10, 10, 10] parameter(0)
reshape1 = f32[50, 20] reshape(p0)
reshape2 = f32[10, 10, 10] reshape(reshape1)

インデックス登録マップの構成と簡素化が完了すると

\((d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)\).