形狀和版面配置

XLA 運算元的結構

請參考以下 HLO 範例:

add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)}
          add(exponential.183, broadcast.3115)

這包括下列元件:

  • Op Name: add.936
    • 這是作業的專屬名稱。
  • 形狀:bf16[8,1,1280,16384]
    • 這是 Op 的輸出形狀。這裡的 dtype 為 bf16,形狀為 [8,1,1280,16384]
  • 版面配置 (含並排):3,2,0,1:T(8,128)(2,1)
    • 這說明陣列在記憶體中的儲存方式。3,2,0,1 表示記憶體中的軸順序 (即以欄為主、以列為主等),T(8,128)(2,1) 則表示使用的分塊和填補。
    • 版面配置為選用項目,如未指定,系統就不會進行分格,並假設維度是從最主要到最次要排序。
  • 作業:add
    • 正在執行的作業。這裡為「Add」,Op 名稱中也會提及。
  • 引數:exponential.183broadcast.3115
    • 這項作業會採用兩個引數,並以專屬名稱指定。

我們來看另一個例子,也就是融合作業:

%fusion.3 = bf16[32,32,4096]{2,1,0:T(8,128)(2,1)S(1)}
            fusion(bf16[32,32,8192]{2,1,0:T(8,128)(2,1)S(1)} %fusion.32),
            kind=kCustom, calls=%all-reduce-scatter.3

除了先前說明的元件,這項服務還包含:

  • 屬性:kindcalls
    • 這些項目提供更多關於所執行作業的資訊,在本例中為融合。
  • 記憶體位置 (記憶體空間 ID):S(1)
    • 這表示陣列的儲存記憶體空間/位置。S(1) 表示這個陣列位於 VMEM (在 TPU 上)。
  • 輸入引數 %fusion.32 的形狀和版面配置詳細資料

以下各節說明形狀、版面配置記憶體空間 ID。如要進一步瞭解並排顯示功能,請參閱這篇文章

形狀

XLA ShapeProto proto (xla_data.proto) 說明 N 維陣列 (簡稱陣列) 的維度數量、大小和資料類型。

術語、符號和慣例

  • 陣列的實際維度數量是指大小大於 1 的維度數量。

  • 維度的編號從 0 開始,最多可達 N 維陣列的 N-1。 維度大小為非負整數。特別是大小為 0 的情況。維度編號是為方便起見而任意標示的標籤,這些維度數字的順序並不代表形狀版面配置中的特定次要/主要順序。版面配置是由 LayoutProto Proto 決定。

  • 按照慣例,維度會依維度編號遞增順序列出。舉例來說,如果大小為 [A x B x C] 的 3 維陣列,維度 0 的大小為 A,維度 1 的大小為 B,維度 2 的大小為 C

    XLA 中的部分公用程式也支援類似 Python 的負數索引:維度 -1 是最後一個維度 (相當於 N 維度陣列的 N-1)。舉例來說,在上述 3 維陣列中,維度 -1 的大小為 C,維度 -2 的大小為 B,依此類推。

  • 二維、三維和四維陣列通常會與維度相關聯的特定字母。舉例來說,如果是 2D 陣列:

    • 維度 0:y
    • 維度 1:x

    如果是 3D 陣列:

    • 維度 0:z
    • 維度 1:y
    • 維度 2:x

    4D 陣列:

    • 維度 0:p
    • 維度 1:z
    • 維度 2:y
    • 維度 3:x
  • XLA API 中採用維度的函式會依維度編號遞增排序。這與將維度做為 initializer_list 傳遞時所用的順序相符,例如:

    ShapeUtil::MakeShape(F32, {A, B, C, D})

    會建立形狀,其維度大小陣列由 [A, B, C, D] 序列組成。

版面配置

LayoutProto proto 說明陣列在記憶體中的表示方式。當中包含下列欄位:

message LayoutProto {
  repeated int64 minor_to_major;
  int64 tail_padding_alignment_in_elements;
  ...
}

次要到主要維度的排序

唯一的必填欄位是 minor_to_major。這個欄位說明形狀內維度的次要到主要排序方式。minor_to_major 中的值是陣列的維度排序 (N 維度陣列為 0N-1),第一個值是最小維度,最後一個值是最大維度。最次要的維度是當逐步處理以線性記憶體配置的陣列元素時,變化最快的維度。

舉例來說,請考量以下大小為 [2 x 3] 的 2D 陣列:

a b c
d e f

這裡的維度 0 大小為 2,維度 1 大小為 3。如果版面配置中的 minor_to_major 欄位為 [0, 1],則維度 0 是最次要的維度,而維度 1 是最主要的維度。這對應於線性記憶體中的下列版面配置:

a d b e c f

這個從次要到主要維度的 0N-1 順序,類似於以資料欄為主 (適用於 2 維)。假設維度是單調遞增排序,我們在程式碼中也可以簡單地將這個版面配置稱為「維度 0 是次要維度」。

另一方面,如果版面配置中的 minor_to_major 欄位是 [1, 0],則線性記憶體中的版面配置為:

a b c d e f

N 維度陣列的次要到主要維度順序 (從 N-10),類似於以列為主 (適用於 2 維)。假設維度排序為單調遞增,我們在程式碼中也可以簡單地將這個版面配置稱為「維度 0 是主要維度」。

預設次要至主要排序方式

新建立的「形狀」預設版面配置為「維度順序為從主要到次要」(即 [N-1, ..., 0])。

邊框間距

tail_padding_alignment_in_elements 欄位會根據元素數量定義分塊陣列的對齊方式。套用平鋪後,系統會在版面配置結尾新增填補元素,直到元素總數是這個值的倍數為止。

將陣列編入索引

index_util.h 中的 IndexUtil 類別提供公用程式,用於在形狀和版面配置之間轉換多維度索引和線性索引。多維度索引包含每個維度的 int64 索引。線性索引是單一 int64 值,可索引至保存陣列的緩衝區。如要簡化形狀和版面配置的建立及操控程序,請參閱同一目錄中的 shape_util.hlayout_util.h

記憶體空間 ID

在 HLO 中,每個陣列都可以使用記憶體空間 ID 註解,寫成 S(n)。

  • S(0) (通常會省略) 表示裝置高頻寬記憶體 (HBM)。
  • S(1) 代表裝置上的虛擬記憶體 (VMEM)。
  • S(2)S(3) 等對應於額外的裝置專屬記憶體空間。
  • S(5) 表示主機記憶體。