Analisis pengindeksan

Dokumen ini menjelaskan analisis pengindeksan HLO, yang memungkinkan Anda menghitung peta pengindeksan secara simbolis untuk operasi HLO. Peta pengindeksan adalah fungsi yang memetakan indeks satu tensor ke indeks tensor lain, misalnya indeks output petunjuk HLO ke indeks input petunjuk HLO atau sebaliknya.

Contoh

Untuk siaran dari tensor<20xf32> sampai tensor<10x20x30xf32>

p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}

peta pengindeksan dari output ke input adalah \((i, j, k) \mapsto (j)\) untuk $i \in [0, 10]\(, \)j \in [0, 20]\( and \)k \in [0, 30]$.

Motivasi

GPU XLA menggunakan beberapa solusi khusus untuk alasan penggabungan, penggunaan operan, dan skema penyusunan ubin (detail selengkapnya di bawah). Sasaran analisis pengindeksan adalah menyediakan komponen yang dapat digunakan kembali untuk kasus penggunaan tersebut. Analisis pengindeksan dikembangkan di infrastruktur Affine Map MLIR dan menambahkan semantik HLO.

Penggabungan

Penalaran tentang penggabungan memori menjadi mungkin dilakukan untuk kasus yang sulit, saat kita tahu elemen/irisan input apa yang dibaca untuk menghitung elemen output.

Pemanfaatan Operand

Pemanfaatan operand dalam XLA menunjukkan seberapa banyak setiap input petunjuk digunakan dengan anggapan bahwa output-nya telah digunakan sepenuhnya. Saat ini, pemanfaatan juga tidak dihitung untuk kasus generik. Analisis pengindeksan memungkinkan penghitungan penggunaan secara tepat.

Susunan persegi

Kartu/slice adalah subset hyper-persegi panjang dari tensor yang diparameterisasi oleh offset, ukuran, dan langkah. Penerapan ubin adalah cara untuk menghitung parameter ubin dari produsen/konsumen pengoperasian menggunakan parameter pemasangan ubin dari operasi itu sendiri. Sudah ada library yang melakukannya untuk softmax dan dot. Propagasi kartu dapat dibuat lebih generik dan kokoh jika dinyatakan melalui peta pengindeksan.

Fungsi dan Domain

Peta pengindeksan adalah fungsi \(\boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\) yang memetakan multi-indeks \(\boldsymbol{d}\) tensor \(A\) ke elemen/rentang tensor \(B\). Parameter ini \(\boldsymbol{s}\) mengacu pada rentang indeks dimensi yang ada di tensor \(B\), tetapi tidak di tensor \(A\).

Misalnya, jika kita melakukan pengurangan dari tensor<2x4x8x16xf32> menjadi tensor<4x8xf32>, peta pengindeksan dari output 2D ke input 4D adalah \((d_0, d_1) \mapsto (s_0, d_0, d_1, s_1)\), dengan \(d_i\) parameter dimensi yang sesuai dengan indeks tensor output. Parameter \(s_j\) mengenkode beberapa nilai, yaitu untuk menghitung \((d_0, d_1)\) elemen output, kita memerlukan \((s_0, d_0, d_1, s_1)\) elemen input, tempat \(s_0 \in [0, 2)\) dan \(s_1 \in [0, 16)\).

Pemetaan ini dapat dibuat dari atribut petunjuk HLO atau pemetaan petunjuk yang tidak digabungkan dapat disusun untuk mendapatkan pengindeksan untuk perpaduan. Pemetaan juga memiliki domain, yang menentukan elemen tensor yang pemetaannya berada.

\[ \begin{eqnarray} \boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\; &s.t.& \\ \boldsymbol{lb}_d &\leq& \boldsymbol{d} \leq \boldsymbol{ub}_d \\ \boldsymbol{lb}_s &\leq& \boldsymbol{s} \leq \boldsymbol{ub}_s \\ \boldsymbol{lb}_g &\leq& \boldsymbol{g}(\boldsymbol{d}, \boldsymbol{s}) \leq \boldsymbol{ub}_g \end{eqnarray} \]

Karena kita ingin meminimalkan penghitungan ulang, kita memerlukan library untuk komputasi simbolis. XLA sudah bergantung pada MLIR, jadi kami menggunakan mlir::AffineMap, bukan menulis library aritmetika simbolis.

Tampilan AffineMap standar

(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)

AffineMap dengan mudah memiliki dua jenis parameter: dimensi dan simbol yang masing-masing dapat kita gunakan untuk \(\boldsymbol d\) dan \(\boldsymbol s\) . AffineMap tidak berisi metadata apa pun tentang rentang dimensi, jadi kita harus menyediakan data ini sendiri.

struct Range {
 int64_t lower_bound;
 int64_t upper_bound;
};

struct IndexingMap {
 mlir::AffineMap affine_map;
 std::vector<Range> dimension_ranges;
 std::vector<Range> symbol_ranges;
 llvm::DenseMap<mlir::AffineExpr, Range> expr_ranges;
};

dim_ranges mengenkode batasan kotak inklusif untuk parameter dimensi \(\boldsymbol{d}\) peta pengindeksan, yang biasanya bertepatan dengan bentuk tensor output untuk operasi seperti transposisi, kurangi, elemen, titik, tetapi ada beberapa pengecualian seperti HloConcatenateInstruction.

symbol_ranges mengenkode kemungkinan nilai yang dapat diambil \(\boldsymbol {s}\) parameter.

Mari kita pelajari setiap contoh untuk memahami arti dari semua penjelasan di atas.

Mengindeks Maps untuk Operasi yang Tidak Digabungkan

Dasar

Untuk operasi elementwise, peta pengindeksan adalah sebuah identitas.

  p0 = f32[10, 20] parameter(0)
  p1 = f32[10, 20] parameter(1)
  add = f32[10, 20] add(p0, p1)

Output ke peta input:

  • output -> input_0: \((d_0, d_1) \mapsto (d_0, d_1)\) untuk $\boldsymbol{d} \in [0,9] \times [0, 19]\(, i.e. \)\boldsymbol{d} \in {\rm Dom}(output)$
  • output -> input_1: \((d_0, d_1) \mapsto (d_0, d_1)\) untuk $\boldsymbol{d} \in {\rm Dom} (output)$

Peta input ke output

  • input_i -> output: \((d_0, d_1) \mapsto (d_0, d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$

Siarkan

Penyiaran berarti beberapa dimensi akan dihapus saat kita memetakan output ke input dan ditambahkan saat kita memetakan input ke output.

p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}

Output ke peta input:

  • output -> input: \((d_0, d_1, d_2) \mapsto (d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$

Peta input ke output

  • input -> output: \((d_0) \mapsto (s_0, d_1, s_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)\( and \)\boldsymbol{s} \in [0, 9] \times [0, 29]$.

Perhatikan bahwa sekarang kita memiliki \(\boldsymbol s\) di sisi kanan untuk pemetaan input-ke-output. Itu adalah simbol yang mewakili rentang nilai. Misalnya, dalam kasus khusus ini, setiap elemen input dengan indeks \(d_0\) dipetakan ke irisan output berukuran 10x1x30.

Constant dan Iota

Ringkasnya, parameter ini tidak memiliki parameter input apa pun, sehingga tidak ada yang dapat dihitung pengindeksannya.

Transpose

Peta pengindeksan untuk transposisi adalah permutasi dimensi input/output.

p0 = f32[3, 12288, 6, 128] parameter(0)
transpose = f32[3, 6, 128, 12288] transpose(p0), dimensions={0, 2, 3, 1}

Output ke peta input:

  • output -> input: \((d_0, d_1, d_2, d_3) \mapsto (d_0, d_3, d_1, d_2)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\)

Peta input ke output:

  • input -> output: \((d_0, d_1, d_2, d_3) \mapsto (d_0, d_2, d_3, d_1)\) untuk \(\boldsymbol{d} \in {\rm Dom}(input)\)

Reverse Dunk

Peta pengindeksan untuk terbalik mengubah dimensi yang dikembalikan menjadi $upper_bound(d_i) - d_i$:

p0 = f32[1, 17, 9, 9] parameter(0)
reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2}

Output ke peta input:

  • output -> input: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)\( for \)\boldsymbol{d} \in {\rm Dom}(output)$

Peta input ke output:

  • input -> output: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)\( for \)\boldsymbol{d} \in {\rm Dom}(input)$

(Variadic)Pengurangan

Pengurangan variadik memiliki beberapa input dan beberapa init. Peta dari output ke input menambahkan dimensi yang dikurangi. Jadi, dalam beberapa hal, ia berperilaku seperti kebalikan dari siaran.

p0 = f32[256,10] parameter(0)
p0_init = f32[] constant(-inf)
p1 = s32[256,10] parameter(1)
p1_init = s32[] constant(0)
reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init),
  dimensions={0}, to_apply=min

Output ke peta input:

  • output -> input_j: \((d_0) \mapsto (s_0, d_0)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)\( and \)\boldsymbol{s} \in [0, 9]$
  • output -> init_j: \((d_0) \mapsto ()\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$

Peta input ke output:

  • input_i -> output_j: \((d_0, d_1) \mapsto (d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(input)$
  • init_i -> output_j: \(() \mapsto (s_0)\) untuk \(\boldsymbol{s} \in [0, 9]\)

untuk \(i, j = 0, \ldots, INPUT\\_COUNT\).

Slice

Pengindeksan dari output ke input untuk slice menghasilkan peta pengindeksan bertingkat yang valid untuk setiap elemen output. Pemetaan dari input ke output dibatasi pada rentang elemen bertingkat di input tersebut.

p0 = f32[10, 20, 50] parameter(0)
slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0),
  slice={[5:10:1], [3:20:7], [0:50:2]}

Output ke peta input:

  • output -> input: \((d_0, d_1, d_2) \mapsto (d_0 + 5, 7d_1 + 3, 2d_2)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\)

Peta input ke output:

  • input -> output: \((d_0, d_1, d_2) \mapsto (d_0, d_1 / 7, d_2 / 2)\) untuk \(\boldsymbol{d} \in [5, 9] \times [3, 19] \times [0, 49]\) dengan langkah $[1, 7, 2]$.

TBD: pengindeksan input-ke-output

Membentuk ulang

{i>Reshape<i} hadir dalam berbagai rasa.

Ciutkan bentuk

Ini adalah bentuk ulang "linearisasi" dari N-D ke 1D.

p0 = f32[4,8] parameter(0)
reshape = f32[32] reshape(p0)

Output ke peta input:

  • output -> input: \((d_0) \mapsto (d_0 / 8, d_0 \mod 8)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$

Peta input ke output:

  • input -> output: \((d_0, d_1) \mapsto (8 d_0 + d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(input)$.

Luaskan bentuk

Ini adalah kebalikan dari operasi “ciutkan bentuk”, yang membentuk ulang masukan 1D menjadi keluaran N-D.

p0 = f32[32] parameter(0)
reshape = f32[4, 8] reshape(p0)

Output ke peta input:

  • output -> input: \((d_0, d_1) \mapsto (8 d_0 + d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$

Peta input ke output:

  • input -> output: \((d_0) \mapsto (d_0 / 8, d_0 \mod 8)\) untuk $\boldsymbol{d} \in {\rm Dom}(input)$.

Pembentukan ulang generik

Ini adalah operasi bentuk ulang yang tidak dapat direpresentasikan sebagai bentuk luaskan atau ciutkan. Elemen ini hanya dapat direpresentasikan sebagai komposisi 2 atau lebih bentuk luaskan atau ciutkan.

Contoh 1: Linearisasi-delinearisasi.
p0 = f32[4,8] parameter(0)
reshape = f32[2, 4, 4] reshape(p0)

Bentuk ulang ini dapat direpresentasikan sebagai komposisi bentuk penciutan tensor<4x8xf32> ke tensor<32xf32>, lalu perluasan bentuk ke tensor<2x4x4xf32>.

Output ke peta input:

  • output -> input: $(d_0, d_1, d_2) \mapsto (2d_0 + (4d_1 + d_2) / 8, 4d_1 + d_2) \mod 8)$

untuk \(\boldsymbol{d} \in {\rm Dom}(output)\)

Peta input ke output:

  • input -> output: $(d_0, d_1) \mapsto ((8d_0 + d_1) / 16, ((8d_0 + d_1) \mod 16) / 4, d_1 \mod 4)$

untuk \(\boldsymbol{d} \in {\rm Dom}(input)\).

Contoh 2: Subbentuk yang diluaskan dan diciutkan
p0 = f32[4, 8, 12] parameter(0)
reshape = f32[32, 3, 4] reshape(p0)

Pembentukan ulang ini dapat direpresentasikan sebagai komposisi dari dua bentuk ulang. Yang pertama menciutkan dimensi terluar tensor<4x8x12xf32> menjadi tensor<32x12xf32>, dan yang kedua memperluas dimensi terdalam tensor<32x12xf32> menjadi tensor<32x3x4xf32>.

Output ke peta input:

  • output -> input: \((d_0, d_1, d_2) \mapsto (d_0 / 8, d_0 \mod 8, 4d_1 + d_2)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\)

Peta input ke output:

  • input -> output: \((d_0, d_1, d_2) \mapsto (8d_0 + d_1, d_2 / 4, d_2 \mod 4)\) untuk \(\boldsymbol{d} \in {\rm Dom}(input)\).

Bitcast

Operasi bitcast dapat direpresentasikan sebagai urutan transpose-reshape-transpose. Oleh karena itu, peta pengindeksannya hanyalah komposisi peta pengindeksan untuk urutan ini.

Gabungkan

Pemetaan output-to-input untuk concat ditentukan untuk semua input, tetapi dengan domain yang tidak tumpang-tindih, yaitu hanya satu input yang akan digunakan dalam satu waktu.

p0 = f32[3,50] parameter(0)
p1 = f32[3,30] parameter(1)
concat = f32[3,80] concatenate(f32[3,50] p0, f32[3,30] p1),
  dimensions={1}

Output ke peta input:

  • {i>output<i} -> input 1:

\((d_0, d_1) \mapsto (d_0, d_1)\) untuk \(\boldsymbol{d} \in [0, 2] \times [0, 49]\)

  • {i>output<i} -> input 2:

\((d_0, d_1) \mapsto (d_0, d_1 - 50)\) untuk $\boldsymbol{d} \in [0, 2] \times [50, 79]$

Input untuk peta output:

  • input 1 -> output: \((d_0, d_1) \mapsto (d_0, d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(input_1)$.
  • input 2 -> output: \((d_0, d_1) \mapsto (d_0, d_1 + 50)\) untuk $\boldsymbol{d} \in {\rm Dom}(input_2)$.

Titik (output-ke-input diterapkan

Peta pengindeksan untuk titik sangat mirip dengan dengan reduksi.

p0 = f32[4, 128, 256] parameter(0)
p1 = f32[4, 256, 64] parameter(1)
dot = f32[4, 128, 64] dot(p0, p1),
  lhs_batch_dims={0}, rhs_batch_dims={0},
  lhs_contracting_dims={2}, rhs_contracting_dims={1}

Output ke input memetakan:

  • output -> input_1: \((d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\) dan \(\boldsymbol{s} \in [0, 255]\)
  • output -> input_2: \((d_0, d_1, d_2) \mapsto (d_0, s_0, d_2)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\) dan \(\boldsymbol{s} \in [0, 255]\)

Input untuk output peta:

  • input_1 -> output: \((d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)\) untuk \(\boldsymbol{d} \in {\rm Dom}(input_1)\) dan \(\boldsymbol{s} \in [0, 63]\)
  • input_2 -> output: \((d_0, d_1, d_2) \mapsto (d_0, s_0, d_1)\) untuk \(\boldsymbol{d} \in {\rm Dom}(input_2)\) dan \(\boldsymbol{s} \in [0, 127]\)

Kurangi periode (TBD)

Bantalan (TBD)

Mengindeks Maps for Fusion

Peta pengindeksan untuk fusion op adalah komposisi peta pengindeksan untuk setiap operasi dalam cluster. Beberapa input dapat dibaca beberapa kali dengan pola akses yang berbeda.

Satu input, beberapa peta pengindeksan

Berikut contoh untuk \(p_0 + p_0^T\)

f {
  p0 = f32[1000, 1000] parameter(0)
  transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0}
  ROOT a0 = f32[1000, 1000] add(p0, transpose_p0)
}

Peta pengindeksan output-ke-input untuk p0 adalah $(d_0, d_1) \mapsto (d_0, d_1)\( and \)(d_0, d_1) \mapsto (d_1, d_0)$. Artinya, untuk menghitung satu elemen output, kita mungkin perlu membaca parameter input dua kali.

Satu input, peta pengindeksan yang dihapus duplikatnya

img

Ada kasus saat peta pengindeksan sebenarnya sama, meskipun tidak langsung terlihat.

f {
  p0 = f32[20, 10, 50] parameter(0)
  lhs_transpose_1 = f32[10, 20, 50] transpose(p0), dimensions={1, 0, 2}
  lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1)
  lhs_transpose_2 = f32[10, 50, 20] transpose(lhs_e), dimensions={0, 2, 1}
  rhs_transpose_1 = f32[50, 10, 20] transpose(p0), dimensions={2, 1, 0}
  rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1)
  rhs_transpose_2 = f32[10, 50, 20] transpose(rhs_log), dimensions={1, 0, 2}
  ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2)
}

Peta pengindeksan output-ke-input untuk p0 dalam kasus ini hanyalah $(d_0, d_1, d_2) \mapsto (d_2, d_0, d_1)$.

Softmax

img

Pemetaan pengindeksan output-ke-input untuk parameter 0 untuk softmax:

  • \((d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)\)
  • \((d_0, d_1, d_2)[s_0] \mapsto (d_0, d_1, s_0)\)

untuk \(\boldsymbol{d} \in {\rm Dom}(output)\) dan \(\boldsymbol{s} \in [0, 124]\) mengacu pada dimensi input terdalam.

Penyederhanaan Peta Pengindeksan

Penyederhana default untuk upstream mlir::AffineMap tidak dapat membuat asumsi apa pun tentang rentang dimensi/simbol. Oleh karena itu, pengujian ini tidak dapat menyederhanakan ekspresi dengan mod dan div secara efisien.

Kita dapat memanfaatkan pengetahuan tentang batas bawah dan atas dari sub-ekspresi di peta affine untuk lebih menyederhanakannya.

Penyederhana dapat menulis ulang ekspresi berikut.

  1. \((d_0, d_1) \mapsto (d_0 + d1 / 16, d1 \mod 16)\) untuk $\boldsymbol{d} \in [0, 6] \times [0, 14]\( becomes \)(d_0, d_1) \mapsto (d_0, d_1)$
  2. $(d_0, d_1, d_2) \mapsto ((100d_0 + 10d_1 + d_2) /100, ((100d_0 + 10d_1 + d_2) \mod 100) / 10, d_2 \mod 10)\( for \)d_0, d_2 \mod 10)\( for \)d_i\( becomes \)
  3. $(d_0, d_1, d_2) \mapsto ((16d_0 + 4d_1 + d_2) /8, (16d_0 + 4d_1 + d_2) \mod 8)\( for \)d_i \di [0, 9]\( becomes \)(d_0, d_1, d_2) (d_0, d_1, d_2) (d_0, d_1, d_2)
  4. \((d_0, d_1) \mapsto (-(-11d_0 - d_1 + 109) / 11 + 9)\) untuk $\boldsymbol{d} \in [0, 9] \times [0, 10]\( becomes \)(d_0, d_1) \mapsto (d_0)$.

Penyederhanaan peta pengindeksan memungkinkan kita memahami bahwa beberapa bentuk ulang berantai di HLO akan saling membatalkan.

p0 = f32[10, 10, 10] parameter(0)
reshape1 = f32[50, 20] reshape(p0)
reshape2 = f32[10, 10, 10] reshape(reshape1)

Setelah komposisi peta pengindeksan dan penyederhannya, kita akan mendapatkan

\((d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)\).