總覽
分割區傳播會使用使用者指定的分割區,推斷未指定的張量分割區 (或張量的特定維度)。它會在兩個方向中逐一檢查運算圖的資料流程 (使用定義鏈結),直到達到固定點為止,也就是說,如果不撤銷先前的區隔決策,區隔作業就無法再變更。
傳播可分解為多個步驟。每個步驟都會根據該運算的特性,查看特定運算並在張量 (運算元和結果) 之間傳播。以 matmul 為例,我們會在左側或右側的非收縮維度之間,傳播至結果的對應維度,或是左側和右側的收縮維度之間。
運算的特性會決定輸入和輸出中對應維度的連結,並可抽象化為每個運算的區塊處理規則。
如果沒有衝突解析,傳播步驟會一律傳播,同時忽略衝突的軸線;我們稱之為 (最長) 相容的主要分割軸。
縝密設計
衝突解決階層
我們在階層中組合多種衝突解決策略:
- 使用者定義的優先順序。在「區隔表示法」一文中,我們說明瞭如何將優先順序附加至維度區隔,以便逐步分割程式,例如執行批次並行處理 -> megatron -> ZeRO 區隔。這可透過在迭代中套用傳播方式達成,在迭代
i
時,我們會傳播所有具有優先順序<=i
的維度分割,並忽略所有其他維度分割。我們也會確保傳播作業不會覆寫使用者定義的較低優先順序 (>i
) 分割作業,即使在先前的迭代期間忽略這些作業也一樣。 - 以作業為依據的優先順序。我們會根據作業類型傳播分割作業。「傳遞」作業 (例如元素作業和重塑) 的優先順序最高,而形狀轉換作業 (例如 dot 和 reduce) 的優先順序較低。
- 積極傳播。使用積極策略來傳播分割作業。基本策略只會傳播沒有衝突的切割,而積極策略則會解決衝突。設定較高的積極度雖然可減少記憶體占用空間,但可能會影響潛在的通訊。
- 基本傳播方式。這是階層中最低層級的傳播策略,不會進行任何衝突解決,而是在所有運算元與結果之間傳播相容的軸。
這個階層可解讀為巢狀的 for 迴圈。例如,針對每個使用者優先順序,會套用完整的作業優先順序傳播。
作業分割規則
分割規則會引入每個運算的抽象概念,為實際傳播演算法提供所需資訊,以便從運算元傳播分割作業至結果,或跨運算元傳播,而無須推論特定運算類型及其屬性。這項功能基本上是將特定作業的邏輯分離,並為所有作業提供共用表示法 (資料結構),僅供傳播用。在最簡單的形式中,它只提供以下函式:
GetOpShardingRule(Operation *) -> OpShardingRuleAttr
這項規則可讓我們以通用方式 (以此資料結構為依據,也就是 OpShardingRule) 編寫傳播演算法,而不需要在多個作業中複製類似的程式碼,大幅降低作業中發生錯誤或行為不一致的可能性。
讓我們回到 matmul 範例。
封裝在傳播期間所需資訊 (也就是維度之間的關係) 的編碼,可以採用 einsum 符號的形式編寫:
(i, k), (k, j) -> (i, j)
在這種編碼中,每個維度都會對應至單一因子。
傳播如何使用這項對應:如果運算元的維度/結果是沿著軸分割,傳播作業會在這個對應中查詢該維度的因數,並以相同因數沿著各自的維度分割其他運算元/結果,並且 (依據先前討論的複製作業) 可能也會沿著該軸複製其他沒有該因數的運算元/結果。
複合因子:擴充重塑規則
在許多運算中 (例如 matmul),我們只需要將每個維度對應至單一因數。不過,這對於重塑而言是不夠的。
以下重塑作業會將兩個維度合併為一個:
%out = mhlo.reshape(%in) : (tensor<2x4x32xf32>) -> tensor<8x32xf32>
在此輸入的維度 0 和 1 都對應至輸出的維度 0。假設我們先為輸入內容提供因數:
(i,j,k) : i=2, j=4, k=32
您可以看到,如果我們想在輸出內容中使用相同的因素,就需要單一維度來參照多個因素:
(i,j,k) -> ((ij), k) : i=2, j=4, k=32
如果重塑是用來分割維度,也可以執行相同的操作:
%out = mhlo.reshape(%in) : (tensor<8x32xf32>) -> tensor<2x4x32xf32> ((ij), k) -> (i,j,k) : i=2, j=4, k=32
這裡的 8 大小維度基本上由 2 和 4 因數組成,因此我們稱之為 (i,j,k) 因數。
這些因素也適用於沒有任何完整維度對應其中一個因素的情況:
%out = mhlo.reshape(%in) : (tensor<8x4xf32>) -> tensor<2x16xf32> ((ij), k) -> (i,(jk)) : i=2, j=4, k=4
這個範例也強調了為何我們需要儲存因數大小,因為我們無法輕易從對應的尺寸推斷出這些值。
核心傳播演算法
沿著因素傳播分割作業
Shardy 包含張量、維度和因子的階層。代表不同層級的資料。因子是子維度。這是用於區塊傳播作業的內部階層。每個維度可能對應一或多個因素。維度和因子之間的對應關係由 OpShardingRule 定義。
Shardy 會沿著因數 (而非維度) 傳播區塊切割軸。如要執行這項操作,我們有三個步驟,如下圖所示
- 將專案 DimSharding 改為 FactorSharding
- 在 FactorSharding 的空間中傳播分割軸
- 投射更新後的 FactorSharding,以取得更新後的 DimSharding
以圖表呈現分割傳播的因素
我們會使用下表來呈現區塊化傳播問題和演算法。
F0 | F1 | F2 | 明確複製的軸 | |
---|---|---|---|---|
T0 | ||||
T1 | ||||
T2 |
- 每個資料欄都代表一個因素。F0 代表索引為 0 的因子。我們會沿著因數 (欄) 傳播分割作業。
- 每列代表一個張量。T0 是索引為 0 的張量。張量是特定運算涉及的所有運算元與結果。資料列中的軸不得重疊。一個軸 (或子軸) 無法用於多次分割一個張量。如果明確複製軸,我們就無法使用該軸分割張量。
因此,每個儲存格都代表一個因數分割。部分張量中可能會缺少因子。C = dot(A, B)
的資料表如下所示。含有 N
的儲存格表示因數不在張量中。例如,F2 位於 T1 和 T2,但不在 T0 中。
C = dot(A, B) |
F0 批次處理暗淡 | F1 非契約式維度 | F2 非收縮 dim | F3 合約 dim | 明確複製的軸 |
---|---|---|---|---|---|
T0 = A | 否 | ||||
T1 = B | 否 | ||||
T2 = C | 否 |
收集及傳播分割軸
我們會使用下列簡單的範例,以視覺化方式呈現傳播情形。
F0 | F1 | F2 | 明確複製的軸 | |
---|---|---|---|---|
T0 | 「a」 | 「f」 | ||
T1 | 「a」和「b」 | 「c」和「d」 | 「g」 | |
T2 | 「c」和「e」 |
步驟 1:找出沿著每個因素傳播的軸 (又稱為 (最長) 相容的主要區隔軸)。在本例中,我們會沿著 F0 傳播 ["a", "b"]
、沿著 F1 傳播 ["c"]
,並沿著 F2 傳播空值。
步驟 2:展開因子分割作業,即可取得下列結果。
F0 | F1 | F2 | 明確複製的軸 | |
---|---|---|---|---|
T0 | 「a」, "b" | "c" | 「f」 | |
T1 | 「a」和「b」 | 「c」和「d」 | 「g」 | |
T2 | "a", "b" | 「c」和「e」 |