背景
我們假設讀者至少熟悉切割表示法的基本概念,說明如何在 Shardy 中表示張量的切割。本文說明如何在程式中使用分割表示法,例如將分割附加至程式的特定張量。
分割區傳播是指在程式中,針對部分張量的分割區限制,決定每個張量的分割區的程序。Shardy 的編譯器 API 提供多種方法,可影響/控制分割區傳播。此外,使用者還可以將手動分割的運算插入程式中。
目標
本文將說明 Shardy 中這類 API 元件的設計,並說明其行為和不變量。請注意,雖然這個 API 用於控制區塊傳播作業,但本文件「不會」討論任何關於傳播行為的內容,也不會討論傳播作業的設計方式。
總覽
輸入/輸出分割 - 將分割附加至主函式的輸入或輸出,以指出在函式傳回/傳送時,輸入/輸出張量的分割方式。
分割限制:將分割附加至中間張量 (例如 matmul 的結果),以表示應如何分割該張量或其用途的子集。
分割群組:依 ID 將多個張量分組,以表示應以相同方式分割。
手動運算 - 包含使用網格軸子集手動分割的子運算,其中沿著這些手動軸指定所有輸入和輸出的分割,而在子運算內,張量類型相對於這些分割為本機。
縝密設計
輸入/輸出分割
允許使用者為主要函式的輸入和輸出指定區塊劃分。
在 MLIR 中,屬性可附加至函式引數和結果,因此使用者可以透過這種方式將分割屬性附加至函式。
例如:
@mesh_xy = <["x"=2, "y"=2]>
// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
{sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
%arg1: tensor<8x16xf32>)
-> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
...
}
分割限制
允許使用者在程式中將切割附加至中介張量,藉此告知分割器應如何切割該張量或其用途的子集。
這是一個 MLIR 運算,會將張量做為輸入,並附加分割屬性。這項作業可以:
- 沒有用途 (懸而未決),表示附加的切片是張量本身應如何切片。
- 有用途 - 這表示附加的分割作業是分割操作的分割方式,而輸入張量的其他用途可能會有不同的分割作業 (如果輸入張量沒有其他用途,則行為與沒有用途的情況相同)。傳播作業會決定張量本身的分割方式,並視需要重新分割。
它可以有開放維度分割,這表示運算子可進一步沿著可用的軸進行分割。
@mesh_xy = <["x"=2, "y"=2]>
%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>
分割群組
如果兩個或多個張量之間沒有資料依附性或強力資料依附性,但使用者知道這些張量應以相同或類似的方式劃分,Shardy API 就會提供一種方式來指定這項關係。這樣一來,使用者就能自由明確指定張量應分割為彼此。
為實現這項功能,我們引入了分片群組的概念,其中每個群組都包含任意數量的指示,這些指示與相同的分片群組 ID 相關聯。分割群組會強制執行同一群組內的分割作業。
舉例來說,在以下示例的假設使用者程式中,我們希望將程式的輸出內容切割成與程式輸入內容完全相同的部分,但兩者之間沒有資料相依性。
如果我們執行這個程式,分割傳播功能就無法推斷張量 %1
和 %2
的分割方式,最後會複製這些張量。不過,只要附加 shard_group
屬性,指出輸入 %0
和輸出 %2
位於同一個 shard_group
中,我們就能讓分割 @mesh_xy,
[{"x"},{"y"}]>
從輸入 %0
傳播至輸出 %2
,然後再傳播至圖表的其餘部分,也就是這裡的常數 %1
。我們可以使用 sdy.sharding_group
運算,為群組指派值。
@mesh_xy = <["x"=2, "y"=2]>
module @"jit_zeros_like" {
func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
%0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
%1 = stablehlo.constant dense<0> : tensor<8x2xi64>
%2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
return %2 : tensor<8x2xi64>
}
}
在上述簡單範例中,我們可以改為明確指定輸出端與輸入端相同的分割方式,這樣就能達到相同的效果,因為我們已事先知道要將哪個分割區指派給輸入端,但在更實際的情況下,我們會使用分割區來保持多個張量分割作業的同步,而不需要事先知道任何一個張量的分割方式,Shardy 會負責處理其餘部分,並找出最適合指派給這些張量的分割方式。
手動計算
使用者可能會想要明確控制計算作業的部分如何劃分,以及使用哪些集合。舉例來說,有些使用者希望手動套用集體 matmul (從前端 API),而不是延後至編譯器。我們提供手動運算 API,讓他們可以這麼做。
這是 MLIR 作業,其中包含單一區域,用於手動子運算。使用者會使用網格軸的子集 (可能包括所有),為此子運算指定輸入/輸出分割作業。子運算會以指定的網格軸 (亦即手動軸) 為準,進行本機/手動運算,並以未指定的軸 (亦即自由軸) 為準,進行全域/未分割運算。在傳播期間,子運算可沿著自由軸進一步分割,這與在該作業之外運算的方式相同。
例如:
@mesh_name = <["data"=2, "model"=2]>
%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
out_shardings=[<@mesh_name, [{"data"}, {?}]>]
manual_axes={"data"}
(%arg1: tensor<8x32xf32>) {
// body
return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
不變量
所有
in_shardings
、out_shardings
和manual_axes
都必須參照相同的網格。manual_axes
會依網格排序。manual_axes
必須明確用於所有入/出切割,也就是說,對於每個切割,所有手動軸都必須切割維度,或明確複製。如果某個入/出切割中存在自由軸 (任何不在
manual_axes
中的網格軸),則該軸必須是同一個維度切割中的任何手動軸的次要軸 (在上述範例中,維度切割{"model", "data"}
將無效)。運算的區域/主體是本機運算 (例如,包含使用者指定的集合)。必須是沿著手動軸進行的入/出分割作業的本機 (請參閱上方註解)。
巢狀手動運算
只要每個手動計算作業都使用各自獨特的一組手動軸,您就可以在其中巢狀多個手動計算作業。