Dokumen ini menjelaskan analisis pengindeksan HLO, yang memungkinkan Anda menghitung peta pengindeksan secara simbolis untuk operasi HLO. Peta pengindeksan adalah fungsi yang memetakan indeks satu tensor ke indeks tensor lain, misalnya indeks output petunjuk HLO ke indeks input petunjuk HLO atau sebaliknya.
Contoh
Untuk siaran dari tensor<20xf32>
sampai tensor<10x20x30xf32>
p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}
peta pengindeksan dari output ke input adalah \((i, j, k) \mapsto (j)\) untuk $i \in [0, 10]\(, \)j \in [0, 20]\( and \)k \in [0, 30]$.
Motivasi
GPU XLA menggunakan beberapa solusi khusus untuk alasan penggabungan, penggunaan operan, dan skema penyusunan ubin (detail selengkapnya di bawah). Sasaran analisis pengindeksan adalah menyediakan komponen yang dapat digunakan kembali untuk kasus penggunaan tersebut. Analisis pengindeksan dikembangkan di infrastruktur Affine Map MLIR dan menambahkan semantik HLO.
Penggabungan
Penalaran tentang penggabungan memori menjadi mungkin dilakukan untuk kasus yang sulit, saat kita tahu elemen/irisan input apa yang dibaca untuk menghitung elemen output.
Pemanfaatan Operand
Pemanfaatan operand dalam XLA menunjukkan seberapa banyak setiap input petunjuk digunakan dengan anggapan bahwa output-nya telah digunakan sepenuhnya. Saat ini, pemanfaatan juga tidak dihitung untuk kasus generik. Analisis pengindeksan memungkinkan penghitungan penggunaan secara tepat.
Susunan persegi
Kartu/slice adalah subset hyper-persegi panjang dari tensor yang diparameterisasi oleh offset, ukuran, dan langkah. Penerapan ubin adalah cara untuk menghitung parameter ubin dari produsen/konsumen pengoperasian menggunakan parameter pemasangan ubin dari operasi itu sendiri. Sudah ada library yang melakukannya untuk softmax dan dot. Propagasi kartu dapat dibuat lebih generik dan kokoh jika dinyatakan melalui peta pengindeksan.
Fungsi dan Domain
Peta pengindeksan adalah fungsi \(\boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\) yang memetakan multi-indeks \(\boldsymbol{d}\) tensor \(A\) ke elemen/rentang tensor \(B\). Parameter ini \(\boldsymbol{s}\) mengacu pada rentang indeks dimensi yang ada di tensor \(B\), tetapi tidak di tensor \(A\).
Misalnya, jika kita melakukan pengurangan dari tensor<2x4x8x16xf32>
menjadi
tensor<4x8xf32>
, peta pengindeksan dari output 2D ke input 4D adalah
\((d_0, d_1) \mapsto (s_0, d_0, d_1, s_1)\), dengan \(d_i\) parameter dimensi
yang sesuai dengan indeks tensor output. Parameter \(s_j\)
mengenkode beberapa nilai, yaitu untuk menghitung \((d_0, d_1)\) elemen output, kita
memerlukan \((s_0, d_0, d_1, s_1)\) elemen input, tempat \(s_0 \in [0, 2)\) dan
\(s_1 \in [0, 16)\).
Pemetaan ini dapat dibuat dari atribut petunjuk HLO atau pemetaan petunjuk yang tidak digabungkan dapat disusun untuk mendapatkan pengindeksan untuk perpaduan. Pemetaan juga memiliki domain, yang menentukan elemen tensor yang pemetaannya berada.
\[ \begin{eqnarray} \boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\; &s.t.& \\ \boldsymbol{lb}_d &\leq& \boldsymbol{d} \leq \boldsymbol{ub}_d \\ \boldsymbol{lb}_s &\leq& \boldsymbol{s} \leq \boldsymbol{ub}_s \\ \boldsymbol{lb}_g &\leq& \boldsymbol{g}(\boldsymbol{d}, \boldsymbol{s}) \leq \boldsymbol{ub}_g \end{eqnarray} \]
Karena kita ingin meminimalkan penghitungan ulang, kita memerlukan library untuk komputasi simbolis. XLA sudah bergantung pada MLIR, jadi kami menggunakan mlir::AffineMap, bukan menulis library aritmetika simbolis.
Tampilan AffineMap
standar
(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)
AffineMap
dengan mudah memiliki dua jenis parameter: dimensi dan simbol
yang masing-masing dapat kita gunakan untuk \(\boldsymbol d\) dan \(\boldsymbol s\) .
AffineMap
tidak berisi metadata apa pun tentang rentang dimensi, jadi kita harus menyediakan data ini sendiri.
struct Range {
int64_t lower_bound;
int64_t upper_bound;
};
struct IndexingMap {
mlir::AffineMap affine_map;
std::vector<Range> dimension_ranges;
std::vector<Range> symbol_ranges;
llvm::DenseMap<mlir::AffineExpr, Range> expr_ranges;
};
dim_ranges
mengenkode batasan kotak inklusif untuk parameter dimensi \(\boldsymbol{d}\) peta pengindeksan, yang biasanya bertepatan dengan bentuk tensor output untuk operasi seperti transposisi, kurangi, elemen, titik, tetapi ada beberapa pengecualian seperti HloConcatenateInstruction.
symbol_ranges
mengenkode kemungkinan nilai yang dapat diambil \(\boldsymbol {s}\) parameter.
Mari kita pelajari setiap contoh untuk memahami arti dari semua penjelasan di atas.
Mengindeks Maps untuk Operasi yang Tidak Digabungkan
Dasar
Untuk operasi elementwise, peta pengindeksan adalah sebuah identitas.
p0 = f32[10, 20] parameter(0)
p1 = f32[10, 20] parameter(1)
add = f32[10, 20] add(p0, p1)
Output ke peta input:
- output -> input_0: \((d_0, d_1) \mapsto (d_0, d_1)\) untuk $\boldsymbol{d} \in [0,9] \times [0, 19]\(, i.e. \)\boldsymbol{d} \in {\rm Dom}(output)$
- output -> input_1: \((d_0, d_1) \mapsto (d_0, d_1)\) untuk $\boldsymbol{d} \in {\rm Dom} (output)$
Peta input ke output
- input_i -> output: \((d_0, d_1) \mapsto (d_0, d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$
Siarkan
Penyiaran berarti beberapa dimensi akan dihapus saat kita memetakan output ke input dan ditambahkan saat kita memetakan input ke output.
p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}
Output ke peta input:
- output -> input: \((d_0, d_1, d_2) \mapsto (d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$
Peta input ke output
- input -> output: \((d_0) \mapsto (s_0, d_1, s_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)\( and \)\boldsymbol{s} \in [0, 9] \times [0, 29]$.
Perhatikan bahwa sekarang kita memiliki \(\boldsymbol s\) di sisi kanan untuk pemetaan input-ke-output. Itu adalah simbol yang mewakili rentang nilai. Misalnya, dalam kasus khusus ini, setiap elemen input dengan indeks \(d_0\) dipetakan ke irisan output berukuran 10x1x30.
Constant dan Iota
Ringkasnya, parameter ini tidak memiliki parameter input apa pun, sehingga tidak ada yang dapat dihitung pengindeksannya.
Transpose
Peta pengindeksan untuk transposisi adalah permutasi dimensi input/output.
p0 = f32[3, 12288, 6, 128] parameter(0)
transpose = f32[3, 6, 128, 12288] transpose(p0), dimensions={0, 2, 3, 1}
Output ke peta input:
- output -> input: \((d_0, d_1, d_2, d_3) \mapsto (d_0, d_3, d_1, d_2)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\)
Peta input ke output:
- input -> output: \((d_0, d_1, d_2, d_3) \mapsto (d_0, d_2, d_3, d_1)\) untuk \(\boldsymbol{d} \in {\rm Dom}(input)\)
Reverse Dunk
Peta pengindeksan untuk terbalik mengubah dimensi yang dikembalikan menjadi $upper_bound(d_i) - d_i$:
p0 = f32[1, 17, 9, 9] parameter(0)
reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2}
Output ke peta input:
- output -> input: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)\( for \)\boldsymbol{d} \in {\rm Dom}(output)$
Peta input ke output:
- input -> output: $(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)\( for \)\boldsymbol{d} \in {\rm Dom}(input)$
(Variadic)Pengurangan
Pengurangan variadik memiliki beberapa input dan beberapa init. Peta dari output ke input menambahkan dimensi yang dikurangi. Jadi, dalam beberapa hal, ia berperilaku seperti kebalikan dari siaran.
p0 = f32[256,10] parameter(0)
p0_init = f32[] constant(-inf)
p1 = s32[256,10] parameter(1)
p1_init = s32[] constant(0)
reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init),
dimensions={0}, to_apply=min
Output ke peta input:
- output -> input_j: \((d_0) \mapsto (s_0, d_0)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)\( and \)\boldsymbol{s} \in [0, 9]$
- output -> init_j: \((d_0) \mapsto ()\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$
Peta input ke output:
- input_i -> output_j: \((d_0, d_1) \mapsto (d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(input)$
- init_i -> output_j: \(() \mapsto (s_0)\) untuk \(\boldsymbol{s} \in [0, 9]\)
untuk \(i, j = 0, \ldots, INPUT\\_COUNT\).
Slice
Pengindeksan dari output ke input untuk slice menghasilkan peta pengindeksan bertingkat yang valid untuk setiap elemen output. Pemetaan dari input ke output dibatasi pada rentang elemen bertingkat di input tersebut.
p0 = f32[10, 20, 50] parameter(0)
slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0),
slice={[5:10:1], [3:20:7], [0:50:2]}
Output ke peta input:
- output -> input: \((d_0, d_1, d_2) \mapsto (d_0 + 5, 7d_1 + 3, 2d_2)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\)
Peta input ke output:
- input -> output: \((d_0, d_1, d_2) \mapsto (d_0, d_1 / 7, d_2 / 2)\) untuk \(\boldsymbol{d} \in [5, 9] \times [3, 19] \times [0, 49]\) dengan langkah $[1, 7, 2]$.
TBD: pengindeksan input-ke-output
Membentuk ulang
{i>Reshape<i} hadir dalam berbagai rasa.
Ciutkan bentuk
Ini adalah bentuk ulang "linearisasi" dari N-D ke 1D.
p0 = f32[4,8] parameter(0)
reshape = f32[32] reshape(p0)
Output ke peta input:
- output -> input: \((d_0) \mapsto (d_0 / 8, d_0 \mod 8)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$
Peta input ke output:
- input -> output: \((d_0, d_1) \mapsto (8 d_0 + d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(input)$.
Luaskan bentuk
Ini adalah kebalikan dari operasi “ciutkan bentuk”, yang membentuk ulang masukan 1D menjadi keluaran N-D.
p0 = f32[32] parameter(0)
reshape = f32[4, 8] reshape(p0)
Output ke peta input:
- output -> input: \((d_0, d_1) \mapsto (8 d_0 + d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(output)$
Peta input ke output:
- input -> output: \((d_0) \mapsto (d_0 / 8, d_0 \mod 8)\) untuk $\boldsymbol{d} \in {\rm Dom}(input)$.
Pembentukan ulang generik
Ini adalah operasi bentuk ulang yang tidak dapat direpresentasikan sebagai bentuk luaskan atau ciutkan. Elemen ini hanya dapat direpresentasikan sebagai komposisi 2 atau lebih bentuk luaskan atau ciutkan.
Contoh 1: Linearisasi-delinearisasi.
p0 = f32[4,8] parameter(0)
reshape = f32[2, 4, 4] reshape(p0)
Bentuk ulang ini dapat direpresentasikan sebagai komposisi bentuk penciutan
tensor<4x8xf32>
ke tensor<32xf32>
, lalu perluasan bentuk ke
tensor<2x4x4xf32>
.
Output ke peta input:
- output -> input: $(d_0, d_1, d_2) \mapsto (2d_0 + (4d_1 + d_2) / 8, 4d_1 + d_2) \mod 8)$
untuk \(\boldsymbol{d} \in {\rm Dom}(output)\)
Peta input ke output:
- input -> output: $(d_0, d_1) \mapsto ((8d_0 + d_1) / 16, ((8d_0 + d_1) \mod 16) / 4, d_1 \mod 4)$
untuk \(\boldsymbol{d} \in {\rm Dom}(input)\).
Contoh 2: Subbentuk yang diluaskan dan diciutkan
p0 = f32[4, 8, 12] parameter(0)
reshape = f32[32, 3, 4] reshape(p0)
Pembentukan ulang ini dapat direpresentasikan sebagai komposisi dari dua bentuk ulang. Yang pertama
menciutkan dimensi terluar tensor<4x8x12xf32>
menjadi tensor<32x12xf32>
,
dan yang kedua memperluas dimensi terdalam tensor<32x12xf32>
menjadi
tensor<32x3x4xf32>
.
Output ke peta input:
- output -> input: \((d_0, d_1, d_2) \mapsto (d_0 / 8, d_0 \mod 8, 4d_1 + d_2)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\)
Peta input ke output:
- input -> output: \((d_0, d_1, d_2) \mapsto (8d_0 + d_1, d_2 / 4, d_2 \mod 4)\) untuk \(\boldsymbol{d} \in {\rm Dom}(input)\).
Bitcast
Operasi bitcast dapat direpresentasikan sebagai urutan transpose-reshape-transpose. Oleh karena itu, peta pengindeksannya hanyalah komposisi peta pengindeksan untuk urutan ini.
Gabungkan
Pemetaan output-to-input untuk concat ditentukan untuk semua input, tetapi dengan domain yang tidak tumpang-tindih, yaitu hanya satu input yang akan digunakan dalam satu waktu.
p0 = f32[3,50] parameter(0)
p1 = f32[3,30] parameter(1)
concat = f32[3,80] concatenate(f32[3,50] p0, f32[3,30] p1),
dimensions={1}
Output ke peta input:
- {i>output<i} -> input 1:
\((d_0, d_1) \mapsto (d_0, d_1)\) untuk \(\boldsymbol{d} \in [0, 2] \times [0, 49]\)
- {i>output<i} -> input 2:
\((d_0, d_1) \mapsto (d_0, d_1 - 50)\) untuk $\boldsymbol{d} \in [0, 2] \times [50, 79]$
Input untuk peta output:
- input 1 -> output: \((d_0, d_1) \mapsto (d_0, d_1)\) untuk $\boldsymbol{d} \in {\rm Dom}(input_1)$.
- input 2 -> output: \((d_0, d_1) \mapsto (d_0, d_1 + 50)\) untuk $\boldsymbol{d} \in {\rm Dom}(input_2)$.
Titik (output-ke-input diterapkan
Peta pengindeksan untuk titik sangat mirip dengan dengan reduksi.
p0 = f32[4, 128, 256] parameter(0)
p1 = f32[4, 256, 64] parameter(1)
dot = f32[4, 128, 64] dot(p0, p1),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={1}
Output ke input memetakan:
- output -> input_1: \((d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\) dan \(\boldsymbol{s} \in [0, 255]\)
- output -> input_2: \((d_0, d_1, d_2) \mapsto (d_0, s_0, d_2)\) untuk \(\boldsymbol{d} \in {\rm Dom}(output)\) dan \(\boldsymbol{s} \in [0, 255]\)
Input untuk output peta:
- input_1 -> output: \((d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)\) untuk \(\boldsymbol{d} \in {\rm Dom}(input_1)\) dan \(\boldsymbol{s} \in [0, 63]\)
- input_2 -> output: \((d_0, d_1, d_2) \mapsto (d_0, s_0, d_1)\) untuk \(\boldsymbol{d} \in {\rm Dom}(input_2)\) dan \(\boldsymbol{s} \in [0, 127]\)
Kurangi periode (TBD)
Bantalan (TBD)
Mengindeks Maps for Fusion
Peta pengindeksan untuk fusion op adalah komposisi peta pengindeksan untuk setiap operasi dalam cluster. Beberapa input dapat dibaca beberapa kali dengan pola akses yang berbeda.
Satu input, beberapa peta pengindeksan
Berikut contoh untuk \(p_0 + p_0^T\)
f {
p0 = f32[1000, 1000] parameter(0)
transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0}
ROOT a0 = f32[1000, 1000] add(p0, transpose_p0)
}
Peta pengindeksan output-ke-input untuk p0
adalah $(d_0, d_1) \mapsto (d_0,
d_1)\( and \)(d_0, d_1) \mapsto (d_1, d_0)$. Artinya, untuk menghitung satu elemen
output, kita mungkin perlu membaca parameter input dua kali.
Satu input, peta pengindeksan yang dihapus duplikatnya
Ada kasus saat peta pengindeksan sebenarnya sama, meskipun tidak langsung terlihat.
f {
p0 = f32[20, 10, 50] parameter(0)
lhs_transpose_1 = f32[10, 20, 50] transpose(p0), dimensions={1, 0, 2}
lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1)
lhs_transpose_2 = f32[10, 50, 20] transpose(lhs_e), dimensions={0, 2, 1}
rhs_transpose_1 = f32[50, 10, 20] transpose(p0), dimensions={2, 1, 0}
rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1)
rhs_transpose_2 = f32[10, 50, 20] transpose(rhs_log), dimensions={1, 0, 2}
ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2)
}
Peta pengindeksan output-ke-input untuk p0
dalam kasus ini hanyalah $(d_0, d_1, d_2)
\mapsto (d_2, d_0, d_1)$.
Softmax
Pemetaan pengindeksan output-ke-input untuk parameter 0
untuk softmax:
- \((d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)\)
- \((d_0, d_1, d_2)[s_0] \mapsto (d_0, d_1, s_0)\)
untuk \(\boldsymbol{d} \in {\rm Dom}(output)\) dan \(\boldsymbol{s} \in [0, 124]\) mengacu pada dimensi input terdalam.
Penyederhanaan Peta Pengindeksan
Penyederhana default untuk upstream mlir::AffineMap
tidak dapat membuat asumsi apa pun tentang rentang dimensi/simbol. Oleh karena itu, pengujian ini tidak dapat
menyederhanakan ekspresi dengan mod
dan div
secara efisien.
Kita dapat memanfaatkan pengetahuan tentang batas bawah dan atas dari sub-ekspresi di peta affine untuk lebih menyederhanakannya.
Penyederhana dapat menulis ulang ekspresi berikut.
- \((d_0, d_1) \mapsto (d_0 + d1 / 16, d1 \mod 16)\) untuk $\boldsymbol{d} \in [0, 6] \times [0, 14]\( becomes \)(d_0, d_1) \mapsto (d_0, d_1)$
- $(d_0, d_1, d_2) \mapsto ((100d_0 + 10d_1 + d_2) /100, ((100d_0 + 10d_1 + d_2) \mod 100) / 10, d_2 \mod 10)\( for \)d_0, d_2 \mod 10)\( for \)d_i\( becomes \)
- $(d_0, d_1, d_2) \mapsto ((16d_0 + 4d_1 + d_2) /8, (16d_0 + 4d_1 + d_2) \mod 8)\( for \)d_i \di [0, 9]\( becomes \)(d_0, d_1, d_2) (d_0, d_1, d_2) (d_0, d_1, d_2)
- \((d_0, d_1) \mapsto (-(-11d_0 - d_1 + 109) / 11 + 9)\) untuk $\boldsymbol{d} \in [0, 9] \times [0, 10]\( becomes \)(d_0, d_1) \mapsto (d_0)$.
Penyederhanaan peta pengindeksan memungkinkan kita memahami bahwa beberapa bentuk ulang berantai di HLO akan saling membatalkan.
p0 = f32[10, 10, 10] parameter(0)
reshape1 = f32[50, 20] reshape(p0)
reshape2 = f32[10, 10, 10] reshape(reshape1)
Setelah komposisi peta pengindeksan dan penyederhannya, kita akan mendapatkan
\((d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)\).