本文說明 XLA 的廣播語意。
什麼是廣播?
廣播是指讓形狀不同的陣列具有相容的形狀,以便進行算術運算。這個術語是借用自 NumPy 廣播。
如果要在不同等級的多維度陣列之間,或在形狀不同但相容的多維度陣列之間執行運算,可能需要廣播。以加法 X+v 為例,其中 X 是矩陣 (具有 2 個維度的陣列),而 v 是向量 (具有 1 個維度的陣列)。如要執行元素加法,XLA 必須「廣播」向量 v,使維度數量與矩陣 X 相同,也就是將 v 複製特定次數。向量的長度必須至少符合矩陣的其中一個維度。
例如:
|1 2 3| + |7 8 9|
|4 5 6|
矩陣的維度為 (2,3),向量的維度為 (3)。向量會透過在資料列上複製的方式廣播,以取得下列內容:
|1 2 3| + |7 8 9| = |8 10 12|
|4 5 6| |7 8 9| |11 13 15|
原則
XLA 語言盡可能嚴格且明確,避免隱含的「神奇」功能。這類功能可能會稍微簡化某些運算定義,但代價是使用者程式碼中會納入更多假設,長期而言難以變更。如有需要,可以在用戶端層級包裝函式中加入隱含的魔法功能。
就廣播而言,XLA 需要在不同等級的陣列間運算時,明確指定廣播。這與 NumPy 不同,後者會在可能的情況下推斷規格。
將低維度陣列廣播至高維度陣列
純量一律可透過陣列播送,無須明確指定播送維度。純量和陣列之間的元素二元運算,是指對陣列中的每個元素套用純量運算。舉例來說,將純量加到矩陣中,表示要產生一個矩陣,其中每個元素都是純量和輸入矩陣中對應元素的總和。
|1 2 3| + 7 = |8 9 10|
|4 5 6| |11 12 13|
大多數廣播需求都可以透過二元運算中的維度元組擷取。當作業的輸入內容具有不同等級時,這個廣播元組會指定要將高維度陣列中的哪些維度與低維度陣列相符。
請參考上一個範例。請將維度為 (3) 的向量加到維度為 (2,3) 的矩陣,而不是將純量加到 (2,3) 矩陣。如未指定廣播,這項作業無效。如要正確要求矩陣向量加法,請將廣播維度指定為 (1),也就是將向量的維度與矩陣的維度 1 相符。在 2D 中,如果維度 0 代表資料列,維度 1 代表資料欄,這表示向量的每個元素都會成為資料欄,大小與矩陣中的資料列數量相符:
|7 8 9| ==> |7 8 9|
|7 8 9|
以更複雜的例子來說,請考慮將 3 元素向量 (維度 (3)) 新增至 3x3 矩陣 (維度 (3,3))。在這個範例中,有兩種方式可以進行廣播:
(1) 可使用廣播維度 1。每個向量元素都會成為資料欄,且向量會針對矩陣中的每個資料列重複。
|7 8 9| ==> |7 8 9|
|7 8 9|
|7 8 9|
(2) 可使用廣播維度 0。每個向量元素都會成為資料列,且向量會針對矩陣中的每個資料欄重複。
|7| ==> |7 7 7|
|8| |8 8 8|
|9| |9 9 9|
廣播維度可以是元組,用來描述如何將較小維度的形狀廣播至較大維度的形狀。舉例來說,假設有 2x3x4 的長方體和 3x4 的矩陣,廣播元組 (1,2) 表示矩陣與長方體的維度 1 和 2 相符。
如果提供 broadcast_dimensions 引數,則 XlaBuilder 中的二進位運算會使用這類廣播。例如,請參閱 XlaBuilder::Add。在 XLA 原始碼中,這類廣播有時稱為「InDim」廣播。
正式定義
廣播屬性可指定要比對高維度陣列的哪些維度,藉此將低維度陣列與高維度陣列比對。舉例來說,對於維度為 MxNxPxQ 的陣列,維度為 T 的向量可以比對如下:
MxNxPxQ
dim 3: T
dim 2: T
dim 1: T
dim 0: T
在上述兩種情況中,T 都必須等於高維度陣列的相符維度。然後,系統會將向量的值從相符的維度,播送至所有其他維度。
如要將 TxV 矩陣對應至 MxNxPxQ 陣列,請使用一對廣播維度:
MxNxPxQ
dim 2,3: T V
dim 1,2: T V
dim 0,3: T V
etc...
廣播元組中的維度順序,必須與低維度陣列的維度預期與高維度陣列維度相符的順序相同。元組中的第一個元素會指定高維度陣列中的哪個維度必須與低維度陣列中的維度 0 相符。元組中的第二個元素會指定高維度陣列中的哪個維度必須與低維度陣列中的維度 1 相符,依此類推。廣播維度的順序必須嚴格遞增。舉例來說,在先前的範例中,將 V 對應至 N 和 T 對應至 P 都是無效的,將 V 對應至 P 和 N 也是無效的。
使用退化維度廣播類似維度的陣列
相關問題是廣播兩個維度數量相同,但維度大小不同的陣列。與 NumPy 相同,只有在陣列相容時,才能執行這項操作。如果兩個陣列的所有維度都相容,這兩個陣列就相容。如果符合下列條件,兩個維度即為相容:
- 兩者相等,或
- 其中一個是 1 (「退化」維度)
如果遇到兩個相容的陣列,結果形狀在每個維度索引中,都會是兩個輸入的最大值。
範例:
- (2,1) 和 (2,3) 會向 (2,3) 廣播。
- (1,2,5) 和 (7,2,5) 會向 (7,2,5) 廣播。
- (7,2,5) 和 (7,1,5) 會向 (7,2,5) 廣播。
- (7,2,5) 和 (7,2,6) 不相容,因此無法廣播。
如果每個輸入陣列在不同索引處都有退化維度,就會出現特殊情況,但這也受到支援。在本例中,結果是「外部運算」:(2,1) 和 (1,3) 廣播至 (2,3)。如需更多範例,請參閱 NumPy 廣播說明文件。
廣播組合
將低維度陣列廣播至高維度陣列和使用退化維度廣播,都可以在同一個二元運算中執行。舉例來說,大小為 4 的向量和大小為 1x2 的矩陣可使用值為 (0) 的廣播維度加總:
|1 2 3 4| + [5 6] // [5 6] is a 1x2 matrix, not a vector.
首先,向量會使用廣播維度廣播至 2 維 (矩陣)。廣播維度中的單一值 (0) 表示向量的維度零與矩陣的維度零相符。這會產生大小為 4xM 的矩陣,其中 M 的值會與 1x2 陣列中的對應維度大小相符。因此會產生 4x2 矩陣:
|1 1| + [5 6]
|2 2|
|3 3|
|4 4|
然後「退化維度廣播」會廣播 1x2 矩陣的維度零,以符合右側的對應維度大小:
|1 1| + |5 6| |6 7|
|2 2| + |5 6| = |7 8|
|3 3| + |5 6| |8 9|
|4 4| + |5 6| |9 10|
更複雜的例子是將大小為 1x2 的矩陣加入大小為 4x3x1 的陣列,並使用 (1, 2) 的廣播維度。首先,1x2 矩陣會使用廣播維度,廣播至最多 3 個維度,產生中介 Mx1x2 陣列,其中維度大小 M 取決於較大運算元 (4x3x1 陣列) 的大小,產生 4x1x2 中介陣列。M 位於維度 0 (最左側的維度),因為維度 1 和 2 會對應至原始 1x2 矩陣的維度,而廣播維度為 (1, 2)。使用退化維度的廣播,即可將這個中繼陣列新增至 4x3x1 矩陣,產生 4x3x2 陣列結果。