傳播

總覽

區塊傳播會使用使用者指定的區塊,推斷未指定的張量區塊 (或張量的特定維度)。它會在兩個方向上遍歷運算圖的資料流程 (使用定義鏈結),直到達到固定點為止,也就是說,如果不撤銷先前的區隔決策,區隔就無法再變更。

傳播作業可分解為多個步驟。每個步驟都會根據該運算的特性,查看特定運算並在張量 (運算元和結果) 之間傳播。以 matmul 為例,我們會在左側或右側的非收縮維度之間,傳播至結果的對應維度,或是在左側和右側的收縮維度之間。

運算的特性會決定輸入和輸出內容中相應維度的連結,並可抽象化為每個運算的分割規則

如果沒有衝突解析,傳播步驟會盡可能傳播,同時忽略衝突的軸線;我們稱之為 (最長) 相容的主要區隔軸。

縝密設計

衝突解決階層

我們在階層中組合多種衝突解決策略:

  1. 使用者定義的優先順序。在「分割表示法」一文中,我們說明瞭如何將優先順序附加至維度分割作業,以便逐步分割程式,例如執行批次並行處理 -> megatron -> ZeRO 分割作業。這可透過在迭代中套用傳播方式達成,在迭代 i 時,我們會傳播所有具有優先順序 <=i 的維度分割,並忽略所有其他維度分割。我們也確保傳播不會覆寫使用者定義的較低優先順序 (>i) 分割作業,即使在先前迭代期間忽略這些作業也一樣。
  2. 以作業為依據的優先順序。我們會根據作業類型傳播分割作業。「傳遞」作業 (例如元素作業和重塑) 的優先順序最高,而形狀轉換作業 (例如 dot 和 reduce) 的優先順序較低。
  3. 積極傳播。使用積極策略來傳播分割作業。基本策略只會傳播沒有衝突的切割,而積極策略則會解決衝突。更積極的做法雖然可以減少記憶體占用空間,但可能會影響潛在的通訊。
  4. 基本傳播方式。這是階層中最低層級的傳播策略,不會進行任何衝突解決,而是在所有運算元與結果之間傳播相容的軸。

傳播階層,從下到上顯示 4 個堆疊,標籤如下:基本傳播、積極傳播、作業優先順序傳播、使用者優先順序傳播。

這個階層可解讀為巢狀的 for 迴圈。例如,針對每個使用者優先順序,會套用完整的作業優先順序傳播。

作業分割規則

分割規則會引入每個運算的抽象概念,為實際傳播演算法提供所需資訊,以便從運算元傳播分割作業至結果,或跨運算元傳播,而無須考量特定運算類型及其屬性。這項操作本質上是將特定操作的邏輯分離,並為所有操作提供共用表示法 (資料結構),僅供傳播用途。在最簡單的形式中,它只提供以下函式:

GetOpShardingRule(Operation *) -> OpShardingRuleAttr

這項規則可讓我們以基於此資料結構 (OpShardingRule) 的一般方式編寫傳播演算法,而非在許多作業中複製類似的程式碼片段,大幅降低作業中發生錯誤或行為不一致的可能性。

讓我們回到 matmul 範例。

將在傳播期間所需的資訊 (也就是維度之間的關係) 封裝在編碼中,可以使用 einsum 符號的形式編寫:

(i, k), (k, j) -> (i, j)

在這個編碼中,每個維度都會對應至單一因素。

傳播如何使用此對應:如果運算元/結果的維度沿著軸分割,傳播作業會在這個對應中查詢該維度的因數,並以相同因數沿著各自的維度分割其他運算元/結果,並且 (依據先前討論的複製作業) 可能也會沿著該軸複製其他沒有該因數的運算元/結果。

複合因子:擴充重塑規則

在許多運算中 (例如 matmul),我們只需要將每個維度對應至單一因數。不過,這對重塑而言是不夠的。

以下重塑作業會將兩個維度合併為一個:

%out = stablehlo.reshape(%in) : (tensor<2x4x32xf32>) -> tensor<8x32xf32>

在此情況下,輸入內容的維度 0 和 1 都對應至輸出內容的維度 0。假設我們先為輸入內容提供因數:

(i,j,k) : i=2, j=4, k=32

您可以看到,如果我們想在輸出內容中使用相同的因素,就需要單一維度來參照多個因素:

(i,j,k) -> ((ij), k) : i=2, j=4, k=32

如果重塑是用來分割維度,也可以執行相同的操作:

%out = stablehlo.reshape(%in) : (tensor<8x32xf32>) -> tensor<2x4x32xf32>

請注意,

((ij), k) -> (i,j,k) : i=2, j=4, k=32

這裡的 8 大小維度基本上由 2 和 4 因數組成,因此我們將這些因數稱為 (i,j,k) 因數。

這些因素也適用於沒有任何完整維度對應至其中一個因素的情況:

%out = stablehlo.reshape(%in) : (tensor<8x4xf32>) -> tensor<2x16xf32>
// ((ij), k) -> (i,(jk)) : i=2, j=4, k=4

這個範例也強調為何我們需要儲存因數大小,因為我們無法輕易從對應的尺寸推斷出這些值。

核心傳播演算法

沿著因數傳播分割

Shardy 中包含張量、維度和因子的階層。代表不同層級的資料。因子是子維度。這是用於區塊傳播作業的內部階層。每個維度可能對應一或多個因素。維度和因子之間的對應是由 OpShardingRule 定義。

顯示 Shardy 傳播演算法的結構定義。

Shardy 會沿著因數 (而非維度) 傳播區塊切割軸。如要執行這項操作,我們有三個步驟,如下圖所示:

  1. DimSharding 專案至 FactorSharding
  2. FactorSharding 的空間中散發區塊處理軸
  3. 投射更新後的 FactorSharding,以取得更新後的 DimSharding

顯示在 FactorSharding 和 DimSharding 之間進行切割傳播的結構定義。

以圖表呈現沿著因素傳播的分割區

我們將使用下表來呈現分割區傳播問題和演算法。

F0 F1 F2 明確複製的軸
T0
T1
T2
  • 每個欄代表一個因素。F0 代表索引為 0 的因子。我們會沿著因數 (欄) 傳播分割作業。
  • 每列代表一個張量。T0 是索引為 0 的張量。張量是特定運算涉及的所有運算元和結果。資料列中的軸不得重疊。一個軸 (或子軸) 無法用於多次分割單一張張量。如果明確複製軸,我們就無法使用該軸分割張量。

因此,每個儲存格都代表一個因數分割。部分張量中可能會缺少因子。C = dot(A, B) 的資料表如下所示。含有 N 的儲存格表示因數不在張量中。例如,F2 位於 T1 和 T2,但不在 T0 中。

C = dot(A, B) F0 批次處理暗淡 F1 非契約式維度 F2 非收縮型區間 F3 合約 dim 明確複製的軸
T0 = A
T1 = B
T2 = C

收集及傳播區塊劃分軸

我們會使用下列簡單的範例,以視覺化方式呈現傳播情形。

F0 F1 F2 明確複製的軸
T0 「a」 「f」
T1 「a」和「b」 「c」和「d」 「g」
T2 「c」和「e」

步驟 1:找出沿著每個因素傳播的軸 (又稱為 (最長) 相容的主要區隔軸)。在本例中,我們會沿著 F0 傳播 ["a", "b"]、沿著 F1 傳播 ["c"],並沿著 F2 傳播空值。

步驟 2:展開因子切割,即可取得下列結果。

F0 F1 F2 明確複製的軸
T0 「a」, "b" "c" 「f」
T1 「a」和「b」 「c」和「d」 「g」
T2 "a", "b" 「c」和「e」

資料流作業

上述傳播步驟說明適用於大多數作業。不過,在某些情況下,分割規則可能不適用。在這些情況下,Shardy 會定義資料流作業。

某個 op X 的資料流邊緣會在一系列來源和一系列目標之間定義橋接,以便所有來源和目標都以相同方式分割。這類作業的例子包括 stablehlo::OptimizationBarrierOpstablehlo::WhileOpstablehlo::CaseOpsdy::ManualComputationOp。最終,任何實作 ShardableDataFlowOpInterface 的作業都會視為資料流作業。

一個運算可擁有多個彼此垂直的資料流邊緣。例如:

    y_0, ..., y_n = while (x_0, ..., x_n)
                    ((pred_arg_0,... , pred_arg_n) { ... })
                    ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                    })

這個 while 運算子有 n 資料流動邊緣:第 i 個資料流動邊緣位於來源 x_ireturn_value_i 和目標 y_ipred_arg_ibody_arg_i 之間。

Shardy 會在資料流程邊緣的所有來源和目標之間傳播分割作業,就好像是使用來源做為運算元,目標做為結果,以及身分 sdy.op_sharding_rule 的一般運算一樣。也就是說,正向傳播是從來源傳播到目標,而反向傳播則是從目標傳播到來源。

使用者必須實作多種方法,說明如何透過擁有者取得每個資料流邊緣的來源和目標,以及如何取得及設定邊緣擁有者的分割區。擁有者是 Shardy 傳播機制所使用的資料流邊緣的使用者指定目標。使用者可以任意選擇,但必須是靜態值。

舉例來說,假設 custom_op 的定義如下:

  y_1, ..., y_n = custom_op (x_1, ..., x_n)
                  ((body_arg_1,..., body_arg_n) {
                    ...
                    return return_value_1, ..., return_value_n
                  })

這個 custom_op 有兩種資料流邊緣類型:return_value_i (來源) 和 y_i (目標) 之間的 n 邊緣,以及 x_i (來源) 和 body_arg_i (目標) 之間的 n 邊緣。在這種情況下,邊緣擁有者與目標相同。