-sdy-close-shardings

關閉張量分片,並捨棄複製的軸。

-sdy-constant-or-scalar-merger

合併相同的常數和純量擴展,並比對分片。

對分片相同的常數執行輕量 CSE。

匯入管道會分割及複製常數和純量擴充功能,因此分片不會在常數子運算的各種用途之間傳播。如果常數在傳播後具有相同的分片,這個階段會合併常數,以節省編譯時間。詳情請參閱 -sdy-constant-or-scalar-splitter。

-sdy-convert-global-to-local

將 SDY 程式從全域形狀轉換為本機形狀。

根據分片屬性分割邏輯維度,將 SDY 程式從全域形狀轉換為本機形狀。

這個傳遞會運用型別轉換器,將全域邏輯形狀的 RankedTensorType 對應至裝置本機的實體形狀。

-sdy-drop-sharding-rules

所有已註冊營運商的開賣活動 OpShardingRuleAttr

-sdy-insert-explicit-reshards

插入明確的重新分片,使所有作業都有相容的分片。

相容的分片基本上是指作業可以接受分片的運算元,並產生分片的結果,不需要任何重新分片通訊 (請注意,作業可能仍需要通訊,例如全縮減或光環交換)。

傳播後,部分作業可能仍有不相容的分片。

請注意,當軸 (或子軸) 用於跨多個張量分片非對應維度 (例如 matmul 中的非收縮維度),或當軸在一個張量中分片維度,但在另一個張量中未分片對應維度時,表示作業有分片衝突。因此,在這次傳遞後,作業就會變成無衝突。

這個傳遞會明確插入資料重新分割作業,以便針對每個作業,在所有運算元和結果中,以相同方式分割對應的維度,且每個軸 (或子軸) 只能用於分割單一維度類型。

範例:

輸入:

mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>

輸出內容:

sdy.mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>}
%0 = sdy.reshard %rhs <@mesh, \[{"y"}, {}\]> : tensor<32x16xf32>
stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>

在上例中,lhsrhs 都在軸「x」上,以非收縮維度分片,因此不相容。傳遞會在點運算之前,在 rhs 上插入明確的重新分片,以便點運算具有相容的分片。

選項

-enable-full-version                  : Enable full version.
-avoid-reshards-on-named-computations : Avoid explicit reshards/collectives on named computations.

-sdy-remove-all-gather-reduce-scatter-for-cmv1

_Removes sdy.all_gather and sdy.reducescatter for CMV1.

移除模式 all-gather + 點中的 all-gather。移除模式點 + reduce-scatter 中的 reduce-scatter。這個傳遞是為了與集合矩陣乘法 V1 (CMV1) 相容。這是 b/432019089 的暫時解決方案。

-sdy-remove-propagation-debug-info

匯出時移除傳播偵錯資訊 (傳播邊緣和來源分片)。

-sdy-remove-sharding-groups

在傳播後移除 ShardingGroupOps。

-sdy-remove-sub-axes-in-input-output-shardings

移除輸入/輸出分片中的子軸。

部分 Shardy 使用者希望函式輸入/輸出內容具有分片,但沒有子軸。這個傳遞會從輸入/輸出分片中移除子軸和後方軸。這個階段通常會在 sdy-update-non-divisible-input-output-shardings 之後,確保移除子軸不會導致任何無法分割的分片。

-sdy-reshard-to-collectives

將 ReshardOp 轉換為各種 Shardy 集體作業。

比對重新分片作業,並將其重新編寫為各種 Shardy 集合作業。通過此階段後,模組中就不會再有任何重新分片作業。

如果 keepRedundantReshards 為 true,則只會保留多餘的重新分片作業。根據預設,系統會假設已插入明確的重新分片 (sdy-insert-explicit-reshards),且不會保留多餘的重新分片。如果可能尚未插入明確的重新分片,則應保留多餘的重新分片。

範例:

輸入:

mesh = <"x"=2, "y"=2, "z"=2>
%0 : tensor<16x2xf32> {sdy.sharding<@mesh, \[{"x", "y", "z"}, {}\]>
%1 = sdy.reshard %arg0 <@mesh, \[{"x"}, {}\]> : tensor<16x2xf32>

輸出內容:

mesh = <"x"=2, "y"=2, "z"=2>
%0 : tensor<16x2xf32> {sdy.sharding<@mesh, \[{"x", "y", "z"}, {}\]>
%1 = sdy.all_gather \[{"y", "z"}, {}\] %arg0 out_sharding=<@mesh, \[{"x"}, {}\]> : tensor<16x2xf32>

在上述範例中,張量 %0 : tensor<16x2xf32> 會分片為 \[{"x", "y", "z"}, {}\]。接著,reshard op 會將其重新分片為 \[{"x"}, {}\]。在第一個軸上,由於後續重新分片後會移除後置字元 {"y", "z"},因此我們推斷已全數收集 {"y", "z"}。第二個維度不會變更。

選項

-keep-redundant-reshards : Whether it keeps redundant reshards or removes.

-sdy-sharding-constraint-to-reshard

將 ShardingConstraintOp 轉換為 ReshardOp。

-sdy-sink-data-flow-edges

將所有 DataFlowEdgeOp 匯入輸入內容。

將每個 DataFlowEdgeOp 的分片移至其輸入內容 (邊緣的根目標),並以輸入內容取代作業。

選項

-sink-debug-sharding-origins          : Whether to sink the debug sharding origins info. See `debug-sharding-origins` option in propagation for more info.
-sink-debug-propagation-edge-sharding : Whether to sink the debug propagation edge sharding info. See `debug-propagation-edge-sharding` option in propagation for more info.

-sdy-update-non-divisible-input-output-shardings

平均分配 FuncOp 輸入/輸出內容,因此不必因無法整除的分片而填補。

Shardy 使用者希望函式輸入/輸出可平均分割/分片,避免需要填補張量。傳播可能會導致輸入/輸出具有不可分割的切分,因此這個傳遞會將其更新為原始切分的最大維度切分前置字串,該前置字串會平均切分。