本文件說明 HLO 索引分析,可讓您以符號方式計算 HLO 運算的索引編製地圖。索引對應是一種函式,可將一個張量的索引對應至另一個張量的索引,例如 HLO 指示輸出的索引和 HLO 指令輸入的索引,反之亦然。
範例
從 tensor<20xf32>
轉到 tensor<10x20x30xf32>
的廣播訊息
p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}
輸出至輸入的索引地圖為 \((i, j, k) \mapsto (j)\) $i \in [0, 10]\(, \)j \in [0, 20]\( and \)k \in [0, 30]$。
動機
XLA GPU 採用多種自訂解決方案,以推斷協調、運算元使用率和並排配置 (詳情請見下文)。索引分析的目標是提供可重複使用的元件,以供這類用途使用。索引分析是以 MLIR 的 Affine Map 基礎架構為基礎,並加入 HLO 語意。
煤炭
假如我們知道輸入的輸入內容元素/配量以便計算輸出元素,那麼記憶體融合的原因就能用於一般情況。
運算元使用率
XLA 的運算元使用率,代表假設每項指令的輸出都完全使用,而對每個輸入內容的用量。目前也未針對一般案例計算使用率。索引分析可準確計算使用率
圖塊
圖塊/配量是經偏移、大小和跨距參數化的張量參數,圖塊傳播是一種方法,可以使用運算本身的傾斜參數計算運算生產端/消費者的圖塊參數。已有用於 softmax 和 dot 的程式庫。如果圖塊的傳播可以透過索引地圖表示,可以採用更通用且簡潔的資訊方塊。
功能與領域
索引對應是一種函式 \(\boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\),可將一個 \(\boldsymbol{d}\) Tensor \(A\) 的多重索引對應至張量元素/範圍 \(B\)。參數 \(\boldsymbol{s}\) 是指張量 \(B\)(而非 Tensor 裡) 中顯示的維度索引範圍。 \(A\).
例如,如果我們從 tensor<2x4x8x16xf32>
縮減為 tensor<4x8xf32>
,則從 2D 輸出到 4D 輸入的索引對應為\((d_0, d_1) \mapsto (s_0, d_0, d_1, s_1)\),其中 \(d_i\) 是與輸出張量指數對應的維度參數。參數 \(s_j\)會對多個值進行編碼,例如,如要計算輸出的 \((d_0, d_1)\) 元素,我們需要輸入 \((s_0, d_0, d_1, s_1)\) 輸入內容元素,其中 \(s_0 \in [0, 2)\) 和\(s_1 \in [0, 16)\)。
這類對應可透過 HLO 指令的屬性或未融合指令的對應,可撰寫,用來建立融合作業的索引。對應也有一個網域,其會指定存在的 Tensor 元素。
\[ \begin{eqnarray} \boldsymbol{f}(\boldsymbol{d}, \boldsymbol{s})\; &s.t.& \\ \boldsymbol{lb}_d &\leq& \boldsymbol{d} \leq \boldsymbol{ub}_d \\ \boldsymbol{lb}_s &\leq& \boldsymbol{s} \leq \boldsymbol{ub}_s \\ \boldsymbol{lb}_g &\leq& \boldsymbol{g}(\boldsymbol{d}, \boldsymbol{s}) \leq \boldsymbol{ub}_g \end{eqnarray} \]
由於我們希望盡量減少重新計算,因此需要一個程式庫來進行符號運算。XLA 已依附於 MLIR,因此我們使用 mlir::AffineMap 而不是編寫符號式算術程式庫。
典型的AffineMap
外觀
(d0)[s0, s1] -> (s0 + 5, d0 * 2, s1 * 3 + 50)
AffineMap
非常方便地提供兩種參數:維度和符號,我們可以分別用於 \(\boldsymbol d\) 和 \(\boldsymbol s\) 。AffineMap
不含維度範圍的任何中繼資料,因此我們必須自行提供這項資料。
struct Range {
int64_t lower_bound;
int64_t upper_bound;
};
struct IndexingMap {
mlir::AffineMap affine_map;
std::vector<Range> dimension_ranges;
std::vector<Range> symbol_ranges;
llvm::DenseMap<mlir::AffineExpr, Range> expr_ranges;
};
dim_ranges
會對索引對應的維度參數 \(\boldsymbol{d}\) 「內含」方塊限制進行編碼,此形狀與運算的輸出張量通常與運算的輸出張量 (例如轉置、減少、元素相關點和圓點) 一致,但仍有一些例外狀況,例如 HloConcatenateInstruction。
symbol_ranges
會編碼 \(\boldsymbol {s}\) 參數可使用的值。
現在讓我們逐一進行研究,瞭解上述各項資料的真正意義。
為未融合的作業建立地圖索引
元素
針對元素作業,索引對應是一種身分。
p0 = f32[10, 20] parameter(0)
p1 = f32[10, 20] parameter(1)
add = f32[10, 20] add(p0, p1)
輸入對應的輸出內容:
- 輸出 -> input_0: \((d_0, d_1) \mapsto (d_0, d_1)\) 適用於 $\boldsymbol{d} \in [0,9] \times [0, 19]\(, i.e. \)\boldsymbol{d} \in {\rm Dom}(output)$
- 輸出 -> input_1: \((d_0, d_1) \mapsto (d_0, d_1)\) 針對 $\boldsymbol{d} \in {\rm Dom} (output)$
輸出對應的輸入內容
- input_i -> 輸出: \((d_0, d_1) \mapsto (d_0, d_1)\) 針對 $\boldsymbol{d} \in {\rm Dom}(output)$
廣播
廣播是指在將輸出內容對應至輸入時,系統會移除部分維度,並在將輸入內容對應至輸出時新增。
p0 = f32[20] parameter(0)
bc0 = f32[10, 20, 30] broadcast(p0), dimensions={1}
輸入對應的輸出內容:
- 輸出 -> input: \((d_0, d_1, d_2) \mapsto (d_1)\) $\boldsymbol{d} \in {\rm Dom}(output)$
輸出對應的輸入內容
- input -> output: \((d_0) \mapsto (s_0, d_1, s_1)\) $\boldsymbol{d} \in {\rm Dom}(output)\( and \)\boldsymbol{s} \in [0, 9] \times [0, 29]$。
請注意,現在 \(\boldsymbol s\) 的輸入至輸出對應為右側。這些是代表值範圍的符號。舉例來說,在這個特定情況下,凡是含有索引 \(d_0\) 的輸入元素,都會對應至輸出的 10x1x30 配量。
常數和 Iota
相對來說,這些 API 沒有任何輸入參數,因此沒有需要計算索引的內容。
轉置
轉置地圖索引是輸入/輸出維度的排列。
p0 = f32[3, 12288, 6, 128] parameter(0)
transpose = f32[3, 6, 128, 12288] transpose(p0), dimensions={0, 2, 3, 1}
輸入對應的輸出內容:
- output -> input: \((d_0, d_1, d_2, d_3) \mapsto (d_0, d_3, d_1, d_2)\) for \(\boldsymbol{d} \in {\rm Dom}(output)\)
輸出對應的輸入內容:
- input -> output: \((d_0, d_1, d_2, d_3) \mapsto (d_0, d_2, d_3, d_1)\) for \(\boldsymbol{d} \in {\rm Dom}(input)\)
反手灌籃
建立地圖索引以進行反向變更,將還原後的維度變更為 $upper_bound(d_i) - d_i$:
p0 = f32[1, 17, 9, 9] parameter(0)
reverse = f32[1, 17, 9, 9] reverse(p0), dimensions={1, 2}
輸入對應的輸出內容:
- 輸出 -> 輸入:$(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)\( for \)\boldsymbol{d} \in {\rm Dom}(輸出)$
輸出對應的輸入內容:
- input -> 輸出:$(d_0, d_1, d_2, d_3) \mapsto (d_0, -d_1 + 16, -d_2 + 8, d_3)\( for \)\boldsymbol{d} \in {\rm Dom}(輸入)$
(嚴重) 遏止
減少誤報具有幾種輸入和多個 int,從輸出到輸入的對應增加了維度的減少量。因此運作方式與廣播類似
p0 = f32[256,10] parameter(0)
p0_init = f32[] constant(-inf)
p1 = s32[256,10] parameter(1)
p1_init = s32[] constant(0)
reduce = (f32[10], s32[10]) reduce(p0, p1, p0_init, p1_init),
dimensions={0}, to_apply=min
輸入對應的輸出內容:
- 輸出 -> input_j: \((d_0) \mapsto (s_0, d_0)\) 針對 $\boldsymbol{d} \in {\rm Dom}(輸出)\( and \)\boldsymbol{s} \in [0, 9]$
- 輸出 -> init_j: \((d_0) \mapsto ()\) $\boldsymbol{d} \in {\rm Dom}(output)$
輸出對應的輸入內容:
- input_i -> output_j: \((d_0, d_1) \mapsto (d_1)\) $\boldsymbol{d} \in {\rm Dom}(input)$
- init_i -> output_j: \(() \mapsto (s_0)\) for \(\boldsymbol{s} \in [0, 9]\)
\(i, j = 0, \ldots, INPUT\\_COUNT\)。
配量
將輸出內容從輸出到配量輸入編入索引時,產生的索引圖中就會有這個對應檔,對於輸出的每個元素都有效。將輸入與輸出的對應會限制在輸入中元素的橫跨範圍。
p0 = f32[10, 20, 50] parameter(0)
slice = f32[5, 3, 25] slice(f32[10, 20, 50] p0),
slice={[5:10:1], [3:20:7], [0:50:2]}
輸入對應的輸出內容:
- output -> input: \((d_0, d_1, d_2) \mapsto (d_0 + 5, 7d_1 + 3, 2d_2)\) for \(\boldsymbol{d} \in {\rm Dom}(output)\)
輸出對應的輸入內容:
- input -> output: \((d_0, d_1, d_2) \mapsto (d_0, d_1 / 7, d_2 / 2)\) 如果是\(\boldsymbol{d} \in [5, 9] \times [3, 19] \times [0, 49]\) ,步距 $[1, 7, 2]$。
待定:輸入至輸出索引
重塑
重新形狀有多種不同口味。
收合圖案
這是從 N-D 到 1D 的「線性化」重塑。
p0 = f32[4,8] parameter(0)
reshape = f32[32] reshape(p0)
輸入對應的輸出內容:
- 輸出 -> 輸入: \((d_0) \mapsto (d_0 / 8, d_0 \mod 8)\) $\boldsymbol{d} \in {\rm Dom}(output)$
輸出對應的輸入內容:
- input -> output: \((d_0, d_1) \mapsto (8 d_0 + d_1)\) $\boldsymbol{d} \in {\rm Dom}(input)$。
展開形狀
這是反向「收合形狀」運算,因此會將 1D 輸入重新轉換成 N-D 輸出。
p0 = f32[32] parameter(0)
reshape = f32[4, 8] reshape(p0)
輸入對應的輸出內容:
- 輸出 -> input: \((d_0, d_1) \mapsto (8 d_0 + d_1)\) $\boldsymbol{d} \in {\rm Dom}(output)$
輸出對應的輸入內容:
- input -> output: \((d_0) \mapsto (d_0 / 8, d_0 \mod 8)\) $\boldsymbol{d} \in {\rm Dom}(input)$。
一般重塑
這些重新形狀運算無法以單一展開或收合形狀表示。只能以 2 個以上的展開或收合形狀表示。
範例 1:線性去線性化。
p0 = f32[4,8] parameter(0)
reshape = f32[2, 4, 4] reshape(p0)
此重新形狀可以表示為 tensor<4x8xf32>
的收合形狀的組合,然後為 tensor<32xf32>
再將形狀展開為 tensor<2x4x4xf32>
。
輸入對應的輸出內容:
- 輸出 -> 輸入:$(d_0, d_1, d_2) \mapsto (2d_0 + (4d_1 + d_2)) / 8、4d_1 + d_2) \mod 8)$
適用於 \(\boldsymbol{d} \in {\rm Dom}(output)\)
輸出對應的輸入內容:
- input -> 輸出:$(d_0, d_1) \mapsto ((8d_0 + d_1) / 16, ((8d_0 + d_1) \mod 16) / 4, d_1 \mod 4)$
\(\boldsymbol{d} \in {\rm Dom}(input)\)。
範例 2:展開及收合的子形狀
p0 = f32[4, 8, 12] parameter(0)
reshape = f32[32, 3, 4] reshape(p0)
這個重新形狀可以代表兩個重新形狀的組合。第一個函式將最外層的 tensor<4x8x12xf32>
收合為 tensor<32x12xf32>
,第二個會展開最內層的維度 tensor<32x12xf32>
為 tensor<32x3x4xf32>
。
輸入對應的輸出內容:
- output -> input: \((d_0, d_1, d_2) \mapsto (d_0 / 8, d_0 \mod 8, 4d_1 + d_2)\) for \(\boldsymbol{d} \in {\rm Dom}(output)\)
輸出對應的輸入內容:
- input -> output: \((d_0, d_1, d_2) \mapsto (8d_0 + d_1, d_2 / 4, d_2 \mod 4)\) for \(\boldsymbol{d} \in {\rm Dom}(input)\)。
Bitcast
位元投放運算能以轉置 - 變形作業的序列表示。因此,對於這個順序的索引地圖,其索引對應只是構成索引的成分。
串連
系統會為所有輸入定義 concat 的輸出至輸入對應關係,但網域不同,也就是說一次只能使用其中一個輸入內容。
p0 = f32[3,50] parameter(0)
p1 = f32[3,30] parameter(1)
concat = f32[3,80] concatenate(f32[3,50] p0, f32[3,30] p1),
dimensions={1}
輸入對應的輸出內容:
- 輸出 -> 輸入 1:
\((d_0, d_1) \mapsto (d_0, d_1)\) \(\boldsymbol{d} \in [0, 2] \times [0, 49]\)
- 輸出 -> 輸入 2:
\((d_0, d_1) \mapsto (d_0, d_1 - 50)\) 的 $\boldsymbol{d} \in [0, 2] \times [50, 79]$
要輸出對應的輸入內容:
- input 1 -> output: \((d_0, d_1) \mapsto (d_0, d_1)\) $\boldsymbol{d} \in {\rm Dom}(input_1)$。
- input 2 -> 輸出: \((d_0, d_1) \mapsto (d_0, d_1 + 50)\) 針對 $\boldsymbol{d} \in {\rm Dom}(input_2)$。
點 (已實作輸出至輸入)
建立點地圖索引與減少點地圖的方式十分相似。
p0 = f32[4, 128, 256] parameter(0)
p1 = f32[4, 256, 64] parameter(1)
dot = f32[4, 128, 64] dot(p0, p1),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2}, rhs_contracting_dims={1}
輸入對應的輸出內容:
- 輸出 -> input_1: \((d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)\) 針對\(\boldsymbol{d} \in {\rm Dom}(output)\) 和 \(\boldsymbol{s} \in [0, 255]\)
- 輸出 -> input_2: \((d_0, d_1, d_2) \mapsto (d_0, s_0, d_2)\) 針對\(\boldsymbol{d} \in {\rm Dom}(output)\) 和 \(\boldsymbol{s} \in [0, 255]\)
輸出對應的輸入內容:
- input_1 -> 輸出: \((d_0, d_1, d_2) \mapsto (d_0, d_1, s_0)\) 適用於\(\boldsymbol{d} \in {\rm Dom}(input_1)\) 和 \(\boldsymbol{s} \in [0, 63]\)
- input_2 -> 輸出: \((d_0, d_1, d_2) \mapsto (d_0, s_0, d_1)\) 適用於\(\boldsymbol{d} \in {\rm Dom}(input_2)\) 和 \(\boldsymbol{s} \in [0, 127]\)
縮短期間 (待定)
墊片 (未定)
使用 Fusion 建立地圖索引
「Fusion op 索引地圖索引建立」是叢集中所有運算的索引對應組合。您可能會以不同的存取模式多次讀取部分輸入內容。
一種輸入內容,數個索引對應圖
以下是 \(p_0 + p_0^T\)的範例
f {
p0 = f32[1000, 1000] parameter(0)
transpose_p0 = f32[1000, 1000]{0, 1} transpose(p0), dimensions={1, 0}
ROOT a0 = f32[1000, 1000] add(p0, transpose_p0)
}
p0
的輸出至輸入索引對應會是 $(d_0, d_1) \mapsto (d_0,
d_1)\( and \)(d_0, d_1) \mapsto (d_1, d_0)$。這表示如要計算輸出結果的一個元素,我們可能需要讀取輸入參數兩次。
一個輸入內容,簡化後的索引建立對應
在某些情況下,雖然索引地圖無法立刻看出,但索引地圖實際上完全相同。
f {
p0 = f32[20, 10, 50] parameter(0)
lhs_transpose_1 = f32[10, 20, 50] transpose(p0), dimensions={1, 0, 2}
lhs_e = f32[10, 20, 50] exponential(lhs_transpose_1)
lhs_transpose_2 = f32[10, 50, 20] transpose(lhs_e), dimensions={0, 2, 1}
rhs_transpose_1 = f32[50, 10, 20] transpose(p0), dimensions={2, 1, 0}
rhs_log = f32[50, 10, 20] exponential(rhs_transpose_1)
rhs_transpose_2 = f32[10, 50, 20] transpose(rhs_log), dimensions={1, 0, 2}
ROOT add = f32[10, 50, 20] add(lhs_transpose_2, rhs_transpose_2)
}
在此情況下,p0
的輸出至輸入索引對應只有 $(d_0, d_1, d_2)\mapsto (d_2, d_0, d_1)$。
Softmax
softmax 的 parameter 0
輸出至輸入索引對應:
- \((d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)\)
- \((d_0, d_1, d_2)[s_0] \mapsto (d_0, d_1, s_0)\)
\(\boldsymbol{d} \in {\rm Dom}(output)\) 和 \(\boldsymbol{s} \in [0, 124]\)都是指輸入的最內層維度。
索引地圖簡化工具
mlir::AffineMap
上游的預設簡化器無法針對維度/符號範圍做出任何假設。因此無法有效使用 mod
和 div
簡化運算式。
我們可以利用關聯圖中子運算式的下限和上限知識來簡化這些結構。
這個簡化程式可以重寫下列運算式。
- \((d_0, d_1) \mapsto (d_0 + d1 / 16, d1 \mod 16)\) 的 $\boldsymbol{d} \in [0, 6] \times [0, 14]\( becomes \)(d_0, d_1) \mapsto (d_0, d_1)$
- $(d_0, d_1, d_2) \mapsto ((100d_0 + 10d_1 + d_2) /100, ((100d_0 + 10d_1 + d_2) \mod 100) / 10, d_2 \mod 9d_10\( for \)\( becomes \)
- $(d_0, d_1, d_2) \mapsto ((16d_0 + 4d_1 + d_2) /8, (16d_0 + 4d_1 + d_2) \mod 8)\( for \)d_i \in [0, 9]\( becomes \)(d_0, d_1, d_2) d_1
- \((d_0, d_1) \mapsto (-(-11d_0 - d_1 + 109) / 11 + 9)\) 適用於 $\boldsymbol{d} \in [0, 9] \times [0, 10]\( becomes \)(d_0, d_1) \mapsto (d_0)$。
索引地圖簡化器可讓我們瞭解 HLO 中的部分鏈結重新形狀會相互取消。
p0 = f32[10, 10, 10] parameter(0)
reshape1 = f32[50, 20] reshape(p0)
reshape2 = f32[10, 10, 10] reshape(reshape1)
完成索引地圖的組成並簡化之後,
\((d_0, d_1, d_2) \mapsto (d_0, d_1, d_2)\).