傳播

總覽

分割區傳播會使用使用者指定的分割區,推斷未指定的張量分割區 (或張量的特定維度)。它會在兩個方向中逐一檢查運算圖的資料流程 (使用定義鏈結),直到達到固定點為止,也就是說,如果不撤銷先前的區隔決策,區隔作業就無法再變更。

傳播可分解為多個步驟。每個步驟都會根據該運算的特性,查看特定運算並在張量 (運算元和結果) 之間傳播。以 matmul 為例,我們會在左側或右側的非收縮維度之間,傳播至結果的對應維度,或是左側和右側的收縮維度之間。

運算的特性會決定輸入和輸出中對應維度的連結,並可抽象化為每個運算的區塊處理規則

如果沒有衝突解析,傳播步驟會一律傳播,同時忽略衝突的軸線;我們稱之為 (最長) 相容的主要分割軸。

縝密設計

衝突解決階層

我們在階層中組合多種衝突解決策略:

  1. 使用者定義的優先順序。在「區隔表示法」一文中,我們說明瞭如何將優先順序附加至維度區隔,以便逐步分割程式,例如執行批次並行處理 -> megatron -> ZeRO 區隔。這可透過在迭代中套用傳播方式達成,在迭代 i 時,我們會傳播所有具有優先順序 <=i 的維度分割,並忽略所有其他維度分割。我們也會確保傳播作業不會覆寫使用者定義的較低優先順序 (>i) 分割作業,即使在先前的迭代期間忽略這些作業也一樣。
  2. 以作業為依據的優先順序。我們會根據作業類型傳播分割作業。「傳遞」作業 (例如元素作業和重塑) 的優先順序最高,而形狀轉換作業 (例如 dot 和 reduce) 的優先順序較低。
  3. 積極傳播。使用積極策略來傳播分割作業。基本策略只會傳播沒有衝突的切割,而積極策略則會解決衝突。設定較高的積極度雖然可減少記憶體占用空間,但可能會影響潛在的通訊。
  4. 基本傳播方式。這是階層中最低層級的傳播策略,不會進行任何衝突解決,而是在所有運算元與結果之間傳播相容的軸。

傳播階層,從下到上顯示 4 個堆疊,並附有以下標籤:基本傳播、積極傳播、作業優先順序傳播、使用者優先順序傳播。

這個階層可解讀為巢狀的 for 迴圈。例如,針對每個使用者優先順序,會套用完整的作業優先順序傳播。

作業分割規則

分割規則會引入每個運算的抽象概念,為實際傳播演算法提供所需資訊,以便從運算元傳播分割作業至結果,或跨運算元傳播,而無須推論特定運算類型及其屬性。這項功能基本上是將特定作業的邏輯分離,並為所有作業提供共用表示法 (資料結構),僅供傳播用。在最簡單的形式中,它只提供以下函式:

GetOpShardingRule(Operation *) -> OpShardingRuleAttr

這項規則可讓我們以通用方式 (以此資料結構為依據,也就是 OpShardingRule) 編寫傳播演算法,而不需要在多個作業中複製類似的程式碼,大幅降低作業中發生錯誤或行為不一致的可能性。

讓我們回到 matmul 範例。

封裝在傳播期間所需資訊 (也就是維度之間的關係) 的編碼,可以採用 einsum 符號的形式編寫:

(i, k), (k, j) -> (i, j)

在這種編碼中,每個維度都會對應至單一因子。

傳播如何使用這項對應:如果運算元的維度/結果是沿著軸分割,傳播作業會在這個對應中查詢該維度的因數,並以相同因數沿著各自的維度分割其他運算元/結果,並且 (依據先前討論的複製作業) 可能也會沿著該軸複製其他沒有該因數的運算元/結果。

複合因子:擴充重塑規則

在許多運算中 (例如 matmul),我們只需要將每個維度對應至單一因數。不過,這對於重塑而言是不夠的。

以下重塑作業會將兩個維度合併為一個:

%out = mhlo.reshape(%in) : (tensor<2x4x32xf32>) -> tensor<8x32xf32>

在此輸入的維度 0 和 1 都對應至輸出的維度 0。假設我們先為輸入內容提供因數:

(i,j,k) : i=2, j=4, k=32

您可以看到,如果我們想在輸出內容中使用相同的因素,就需要單一維度來參照多個因素:

(i,j,k) -> ((ij), k) : i=2, j=4, k=32

如果重塑是用來分割維度,也可以執行相同的操作:

%out = mhlo.reshape(%in) : (tensor<8x32xf32>) -> tensor<2x4x32xf32> ((ij), k) -> (i,j,k) : i=2, j=4, k=32

這裡的 8 大小維度基本上由 2 和 4 因數組成,因此我們稱之為 (i,j,k) 因數。

這些因素也適用於沒有任何完整維度對應其中一個因素的情況:

%out = mhlo.reshape(%in) : (tensor<8x4xf32>) -> tensor<2x16xf32> ((ij), k) -> (i,(jk)) : i=2, j=4, k=4

這個範例也強調了為何我們需要儲存因數大小,因為我們無法輕易從對應的尺寸推斷出這些值。

核心傳播演算法

沿著因素傳播分割作業

Shardy 包含張量、維度和因子的階層。代表不同層級的資料。因子是子維度。這是用於區塊傳播作業的內部階層。每個維度可能對應一或多個因素。維度和因子之間的對應關係由 OpShardingRule 定義。

顯示 Shardy 傳播演算法的結構定義。

Shardy 會沿著因數 (而非維度) 傳播區塊切割軸。如要執行這項操作,我們有三個步驟,如下圖所示

  1. 將專案 DimSharding 改為 FactorSharding
  2. 在 FactorSharding 的空間中傳播分割軸
  3. 投射更新後的 FactorSharding,以取得更新後的 DimSharding

顯示在 FactorSharding 和 DimSharding 之間進行分割區傳播的結構定義。

以圖表呈現分割傳播的因素

我們會使用下表來呈現區塊化傳播問題和演算法。

F0 F1 F2 明確複製的軸
T0
T1
T2
  • 每個資料欄都代表一個因素。F0 代表索引為 0 的因子。我們會沿著因數 (欄) 傳播分割作業。
  • 每列代表一個張量。T0 是索引為 0 的張量。張量是特定運算涉及的所有運算元與結果。資料列中的軸不得重疊。一個軸 (或子軸) 無法用於多次分割一個張量。如果明確複製軸,我們就無法使用該軸分割張量。

因此,每個儲存格都代表一個因數分割。部分張量中可能會缺少因子。C = dot(A, B) 的資料表如下所示。含有 N 的儲存格表示因數不在張量中。例如,F2 位於 T1 和 T2,但不在 T0 中。

C = dot(A, B) F0 批次處理暗淡 F1 非契約式維度 F2 非收縮 dim F3 合約 dim 明確複製的軸
T0 = A
T1 = B
T2 = C

收集及傳播分割軸

我們會使用下列簡單的範例,以視覺化方式呈現傳播情形。

F0 F1 F2 明確複製的軸
T0 「a」 「f」
T1 「a」和「b」 「c」和「d」 「g」
T2 「c」和「e」

步驟 1:找出沿著每個因素傳播的軸 (又稱為 (最長) 相容的主要區隔軸)。在本例中,我們會沿著 F0 傳播 ["a", "b"]、沿著 F1 傳播 ["c"],並沿著 F2 傳播空值。

步驟 2:展開因子分割作業,即可取得下列結果。

F0 F1 F2 明確複製的軸
T0 「a」, "b" "c" 「f」
T1 「a」和「b」 「c」和「d」 「g」
T2 "a", "b" 「c」和「e」