Shardy (SDY) 方言會定義以軸為基礎的張量切割表示法,以及其他 API 元件,以便將切割結果附加到張量。
作業
sdy.constant
(sdy::ConstantOp)
常數運算
從常數 value
產生 output
張量。
請參閱:https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
範例:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
特徵:AlwaysSpeculatableImplTrait
介面:ConditionallySpeculatable
、InferTypeOpInterface
、NoMemoryEffect (MemoryEffectOpInterface)
效果:MemoryEffects::Effect{}
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
value | ::mlir::ElementsAttr | 常數向量/張量屬性 |
成果:
結果 | 說明 |
---|---|
output |
任意值類型的張量 |
sdy.data_flow_edge
(sdy::DataFlowEdgeOp)
資料流邊緣操作。
語法:
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
某些 op X 的資料流邊緣定義了一組來源 (每個來源為 X 的運算元或 X 的區塊結束器的運算元) 與一組目標 (每個都是 X 或 X 的區塊引數的結果) 之間的橋樑,因此所有來源和目標都應該以相同方式進行資料分割。
運算可能有多個資料流邊緣彼此直觀的邊緣。
例如:
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
這個 while 運算子有 n 個資料流程邊緣,第 i 個資料流程邊緣位於來源 x_i
、return_value_i
和目標 y_i
、pred_arg_i
、body_arg_i
之間。
sdy.data_flow_edge
會將邊緣的根目標做為輸入內容 (可以是任何目標,但最好是運算結果,而非區塊引數),且不應有任何其他用途。這個運算子並非純粹,因為它可以接受原本沒有任何用途的輸入內容。
sdy.data_flow_edge
也為邊緣的所有目標保留選用的分割作業,且應在傳播期間更新分割作業,而非目標的分割作業 (如果可附加)。當操作有許多邊緣時,這項功能就非常實用,因為:
- 分別透過每個邊緣傳播。
- 分別更新各個邊緣的分割作業,而非一次更新所有目標 (例如,一個作業有一個不可變動的
TensorShardingPerValueAttr
,用於結果分割)。 - 當來源的區塊化方式變更時,請分別將每個邊緣新增至工作清單。
傳播作業會在 sdy.data_flow_edge
的所有來源和目標之間傳播分割作業,就好像是使用來源做為運算元,目標做為結果,以及身分 sdy.op_sharding_rule
的一般運算作業一樣。也就是說,前向傳播是從來源傳播至目標,而反向傳播則是從目標傳播至來源。
我們不允許 SdyDialect
運算子定義 sdy.data_flow_edge
的輸入內容,因此可以假設該輸入內容是由具有未註冊 sdy.sharding
屬性的運算子定義。
特徵:SameOperandsAndResultType
介面:InferTypeOpInterface
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | 張量資料分割 |
運算元:
運算元 | 說明 |
---|---|
input |
任何類型值的形狀 |
成果:
結果 | 說明 |
---|---|
result |
任何類型值的形狀 |
sdy.manual_computation
(sdy::ManualComputationOp)
使用手動集合運算的多裝置平行運算作業
語法:
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
跳至以裝置本機程式碼和明確集合方式編寫的區域,其中邏輯形狀會與裝置本機的物理緩衝區形狀相符,集合則會與物理跨裝置通訊完全對應。
本體是相對於 manual_axes 的本機。系統會在任何免費軸的主體上進行傳播,這些軸不在 manual_axes 清單中。
特徵:IsolatedFromAbove
、RecursiveMemoryEffects
、SingleBlockImplicitTerminator<ReturnOp>
、SingleBlock
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 每個運算元/運算結果的張量資料分割 |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 根據運算子的運算元/結果進行張量切割 |
manual_axes | ::mlir::sdy::ManualAxesAttr |
運算元:
運算元 | 說明 |
---|---|
tensors |
任何類型值的排名張量變數 |
成果:
結果 | 說明 |
---|---|
results |
任何類型值的排名張量變數 |
sdy.mesh
(sdy::MeshOp)
命名網格
語法:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
定義新的命名網格。模組中的所有網格都必須有相同數量的裝置 (只有單一 device_id 的網格除外)。網格是 Symbol
作業,會顯示在模組的 SymbolTable
中,並可由其 name
參照。
特徵:HasParent<ModuleOp>
介面:Symbol
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
sym_name | ::mlir::StringAttr | 字串屬性 |
mesh | ::mlir::sdy::MeshAttr | 軸線網格和裝置清單 |
sdy.named_computation
(sdy::NamedComputationOp)
已命名運算作業
語法:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
將運算 (也就是一組運算) 分組,並為其命名。就像所有元素一樣,傳播會在區域內流動。
這可用於處理對其他函式發出的呼叫指示。任何 Shardy 使用者都應編寫匯入/匯出傳遞,將其呼叫作業轉換為 sdy.named_computation
作業,並將呼叫函式的主體複製/複製到 named_computation
的主體中。
區塊中每個引數和傳回值的類型,必須與運算元和運算結果的類型相同。
範例:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
特徵:IsolatedFromAbove
、RecursiveMemoryEffects
、RecursivelySpeculatableImplTrait
、SingleBlockImplicitTerminator<ReturnOp>
、SingleBlock
介面:ConditionallySpeculatable
、ShardableDataFlowOpInterface
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
name | ::mlir::StringAttr | 字串屬性 |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 每個運算元/運算結果的張量資料分割 |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | 根據運算子的運算元/結果進行張量切割 |
運算元:
運算元 | 說明 |
---|---|
operands |
任何類型的變數參數 |
成果:
結果 | 說明 |
---|---|
«unnamed» | 任何類型的變數參數 |
sdy.propagation_barrier
(sdy::PropagationBarrierOp)
傳播阻隔操作
語法:
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
這個運算就像是恆等運算,會輸出與輸入相同的值。但就傳播而言,這只會讓傳播以特定方向流過。
這可避免在使用分隔操作和運算元件的結果時,發生分割傳播的情形。
FORWARD
表示分割作業只能從運算元到結果。BACKWARD
表示資料分割只能從結果流向運算元。NONE
表示沒有任何資料分割可經由此運算傳播。- 無法指定
BOTH
,因為這個運算會重複。
特徵:AlwaysSpeculatableImplTrait
、Elementwise
、SameOperandsAndResultType
介面:ConditionallySpeculatable
、InferTypeOpInterface
、NoMemoryEffect (MemoryEffectOpInterface)
效果:MemoryEffects::Effect{}
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | 傳播方向列舉 |
運算元:
運算元 | 說明 |
---|---|
input |
任何類型值的排名張量 |
成果:
結果 | 說明 |
---|---|
result |
任何類型值的排名張量 |
sdy.reshard
(sdy::ReshardOp)
將張量重新分割至其他分割區
語法:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
使用指定的分割方式重新分割輸入張量,這與輸入張量的現有分割方式不同。
ShardingConstraintOp 和 ReshardOp 都會將切割作業附加至張量。生命週期為:
- 在區塊處理傳播之前,使用者會新增 ShardingConstraintOp。
- 資料分割傳播會使用 ShardingConstraintOp。分割區傳播結果中沒有 ShardingConstraintOp。但可以視需要新增 ReshardOp。
- 分區器會將 ReshardOp 轉換為集體運算 (或身分運算)。分區器的結果中不應出現 ReshardOp。
// TODO(b/331680067). 新增標準化模式,移除多餘的 // reshard 作業。
特徵:AlwaysSpeculatableImplTrait
、Elementwise
、SameOperandsAndResultType
介面:ConditionallySpeculatable
、InferTypeOpInterface
、NoMemoryEffect (MemoryEffectOpInterface)
效果:MemoryEffects::Effect{}
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | 張量資料分割 |
運算元:
運算元 | 說明 |
---|---|
input |
任意值類型的張量 |
成果:
結果 | 說明 |
---|---|
result |
任意值類型的張量 |
sdy.return
(sdy::ReturnOp)
sdy.return
作業會終止連接至 sdy
個區域性作業的區域,以及其他任何以資料分割區域為基礎的作業。這屬於變異性:它會將類型的值做為引數做為引數,但類型可以是任一種,例如 AnyTensor
,因此可以在 Shardy IR 堆疊的不同層級重複使用。
語法:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
特徵:AlwaysSpeculatableImplTrait
、Terminator
介面:ConditionallySpeculatable
、NoMemoryEffect (MemoryEffectOpInterface)
效果:MemoryEffects::Effect{}
運算元:
運算元 | 說明 |
---|---|
results |
任何類型的變數參數 |
sdy.sharding_constraint
(sdy::ShardingConstraintOp)
將張量限制在指定的切割區
語法:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
將切割附加至中介張量 (例如 matmul 的結果),以表示應如何切割該張量或其用途的子集。
如果資料分割具有開放維度且無限制軸,表示可以進一步根據開放維度分割張量。
這項運算可以:
- 沒有用途 (停滯),這表示連接的資料分割是輸入張量本身的資料分割方式。
- 有用途 - 這表示附加的分割作業是分割操作的分割方式,而輸入張量的其他用途可能會有不同的分割作業 (如果輸入張量沒有其他用途,則行為與沒有用途的情況相同)。
特徵:Elementwise
、SameOperandsAndResultType
介面:InferTypeOpInterface
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | 張量資料分割 |
運算元:
運算元 | 說明 |
---|---|
input |
任意值類型的張量 |
成果:
結果 | 說明 |
---|---|
result |
任意值類型的張量 |
sdy.sharding_group
(sdy::ShardingGroupOp)
分割群組作業
語法:
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
這個運算子提供介面,可將張量指派給切割群組 (會強制執行相同切割作業的張量群組)。在傳播期間,只要一個群組元素分割,所有其他成員都會以完全相同的方式分割。這項作業會採用引數群組 ID,但不會傳回結果,而是修改內部切片群組表示法,將輸入張量新增至具有指定 ID 的群組。
屬性:
屬性 | 機器學習 IR 類型 | 說明 |
---|---|---|
group_id | ::mlir::IntegerAttr | 64 位元無正負號整數屬性 |
運算元:
運算元 | 說明 |
---|---|
input |
任何類型值的排序張量 |
屬性
AxisRefAttr
參照完整軸或分割子軸
語法:
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
名稱 | ::llvm::StringRef |
名稱 |
sub_axis_info | SubAxisInfoAttr |
DimMappingAttr
維度的因子索引清單
所有因子索引都必須位於 [0, num_factors) 範圍內,空白清單則表示這是空值對應 (會使用 *
剖析/列印),也就是說,維度未對應至任何因子。
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
factor_indices | ::llvm::ArrayRef<int64_t> |
DimensionShardingAttr
維度區塊
要將張量維度分割的軸名稱清單 (從主要到次要)、布林值 (指出是否可進一步分割維度),以及可選的整數 (表示此維度分割的優先順序),這會在分割傳播期間受到尊重。優先順序來自使用者切割註解,值越低,優先順序越高。如果註解中未提供優先順序,系統會假設為最高優先順序。
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
軸 | ::llvm::ArrayRef<AxisRefAttr> |
軸參照清單 |
is_closed | bool |
|
優先順序 | std::optional<int64_t> |
ManualAxesAttr
語法:
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
值 | ::llvm::ArrayRef<StringAttr> |
MeshAttr
軸線網格和裝置清單
語法:
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
網格是軸的清單,以及指定裝置順序的選用裝置 ID 清單。
如果軸清單為空白,則網格會包含大小為 1 的隱含未命名軸。在這種情況下,如果未提供裝置 ID 清單,隱含的裝置 ID 清單為 [0];如果提供裝置 ID 清單,則該清單必須包含任何非負值的單一整數。我們稱之為最大分割案例。
針對所有非資料分割的情況,如果指定了裝置 ID 清單,軸大小的產品應與裝置數量相符。如果未指定裝置 ID 清單,隱含的裝置 ID 清單為 iota(product(axes))。為簡化操作,我們也禁止指定與 iota(product(axes)) 相同的裝置 ID 清單;在這種情況下,請勿指定裝置 ID 清單。
以下是幾個網格範例:
- 空白網格代表預留位置網格,可在傳播期間取代:<[]>
- 具有未命名軸和明確裝置 ID 的網格,通常用於代表最大資料分割:<[], device_ids=[3]>
- 具有兩個軸線和隱含裝置 ID 的網格 iota(6):<["a"=2, "b"=3]>
- 具有兩個軸線和明確裝置 ID 的網格,用於指定裝置順序:<["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
軸 | ::llvm::ArrayRef<MeshAxisAttr> |
|
device_ids | ::llvm::ArrayRef<int64_t> |
MeshAxisAttr
網格中的命名軸
語法:
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
名稱 | ::llvm::StringRef |
名稱 |
大小 | int64_t |
OpShardingRuleAttr
指定如何分割作業。
語法:
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
bool # is_custom_rule
>
資料分割規則會指定如何根據運算上的各種屬性 (任何屬性、運算元形狀、結果形狀等) 為作業進行分區。例如:
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
請注意,即使因無法分割而無法使用,我們仍允許大小為 1 的因子,這主要是為了完整性,因為許多運算 (例如點運算) 都有大小為 1 的維度,這些維度會在運算元和結果之間對應。
is_custom_rule
會說明這是使用者為 stablehlo.custom_call
作業定義的規則。分割器不知道如何分割這些作業,因此使用者必須告知分割器如何分割。如果是自訂規則,則系統一律會保留規則,永遠不會移除。is_custom_rule
只能針對 stablehlo.custom_call
運算設為 true。
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
factor_sizes | ::llvm::ArrayRef<int64_t> |
|
operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
|
result_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
|
is_custom_rule | bool |
SubAxisInfoAttr
這個子軸是如何從完整軸衍生而來
語法:
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
將完整軸分割成 n 個子軸時,軸會轉換為 [k_1,...,k_n],第 i 個子軸可由其左側所有軸大小的乘積 m=prod(k_1,...,k_(i-1))
(又稱為預先大小) 和大小 k_i 表示。因此,子軸資訊屬性會保留這兩個數字,並以以下方式表示:(m)k
代表預先大小 m 和大小 k。
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
pre_size | int64_t |
|
大小 | int64_t |
TensorMappingAttr
張量的每個維度因式對應。
語法:
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
TensorShardingAttr
張量區塊
語法:
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
張量切割會繫結至特定網格,且只能參照該網格中的軸名稱。維度分割作業會告訴我們,張量每個維度沿著哪些軸 (或子軸) 從主要分割到次要分割。其他所有不會分割維度的軸,都以隱含或明確的方式 (如果出現在複製軸的清單中) 複製。
此區塊劃分所繫結的網格,可以使用符號名稱、參照對應的 MeshOp
符號,或內嵌的 MeshAttr
來指定。
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
mesh_or_ref | ::mlir::Attribute |
網格屬性或平面網格符號參照屬性 |
dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> |
|
replicated_axes | ::llvm::ArrayRef<AxisRefAttr> |
軸參照清單 |
TensorShardingPerValueAttr
根據運算子的運算元/結果進行張量切割
語法:
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
參數:
參數 | C++ 類型 | 說明 |
---|---|---|
分割 | ::llvm::ArrayRef<TensorShardingAttr> |
列舉
PropagationDirection
傳播方向列舉
案件:
符號 | 值 | 字串 |
---|---|---|
無 | 0 |
無 |
FORWARD | 1 |
FORWARD |
BACKWARD | 2 |
BACKWARD |
雙方 | 3 |
雙方 |