Dialek Shardy (SDY) menentukan representasi sharding tensor berbasis sumbu dan komponen API tambahan untuk melampirkan sharding ke tensor.
Operasi
sdy.all_gather
(sdy::AllGatherOp)
Melakukan komunikasi all-gather di sepanjang sumbu
Sintaksis:
operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Mengumpulkan potongan tensor di sepanjang sumbu yang ditentukan dalam gathering_axes
.
gathering_axes
adalah daftar daftar sumbu. Daftar luar berada di atas
dimensi tensor. Setiap daftar dalam menentukan sumbu tempat
pengumpulan terpisah harus dilakukan pada dimensi masing-masing. Hal ini akan diterapkan ke sharding operand (tensor
) untuk mendapatkan sharding hasil (out_sharding
).
Perhatikan bahwa out_sharding
tidak digunakan untuk menentukan sharding
hasil. Sebagai gantinya, sharding hasil ditentukan oleh sharding
operand dan gathering_axes
, dan out_sharding
harus cocok dengan
inferensi sharding ini.
Contoh:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>
Batasan:
- Harus memenuhi batasan yang tercantum di
Sdy_CollectiveOpInterface
. - Elemen dalam
gathering_axes
harus memenuhi batasan yang tercantum dalamAxisRefListAttr
. - Menerapkan
gathering_axes
ke sharding operand akan mendapatkanout_sharding
.
Ciri: SameOperandsAndResultType
Antarmuka: InferTypeOpInterface
, Sdy_CollectiveOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
gathering_axes | ::mlir::sdy::ListOfAxisRefListsAttr | Daftar daftar referensi sumbu |
out_sharding | ::mlir::sdy::TensorShardingAttr | Sharding tensor |
Operand:
Operand | Deskripsi |
---|---|
tensor |
tensor dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
tensor dari nilai jenis apa pun |
sdy.all_reduce
(sdy::AllReduceOp)
Melakukan komunikasi all-reduce di sepanjang sumbu
Sintaksis:
operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Mengurangi potongan tensor di sepanjang sumbu yang ditentukan dalam reduction_axes
.
Urutan reduction_axes
tidak penting untuk hasilnya, tetapi dapat
memengaruhi urutan grup replika yang sesuai.
Batasan:
- Harus memenuhi batasan yang tercantum di
Sdy_CollectiveOpInterface
. reduction_axes
harus memenuhi batasan yang tercantum diAxisRefListAttr
;reduction_axes
tidak boleh tumpang-tindih dengan sumbu sharding operand;
Ciri: SameOperandsAndResultType
Antarmuka: CollectiveOpInterface
, InferTypeOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
reduction_axes | ::mlir::sdy::AxisRefListAttr | Daftar referensi sumbu |
out_sharding | ::mlir::sdy::TensorShardingAttr | Sharding tensor |
Operand:
Operand | Deskripsi |
---|---|
tensor |
tensor dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
tensor dari nilai jenis apa pun |
sdy.all_slice
(sdy::AllSliceOp)
Melakukan operasi slice dinamis di sepanjang sumbu
Sintaksis:
operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Memotong potongan tensor di sepanjang sumbu yang ditentukan dalam slicing_axes
. Ada
dualitas aljabar antara sdy.all_slice
dan sdy.all_gather
.
slicing_axes
adalah daftar daftar sumbu. Daftar luar berada di atas
dimensi tensor. Setiap daftar bagian dalam menentukan sumbu tempat
slice harus dilakukan pada dimensi masing-masing. Ini akan diterapkan ke sharding operand (tensor
) untuk mendapatkan sharding hasil (out_sharding
).
Perhatikan bahwa out_sharding
tidak digunakan untuk menentukan sharding
hasil. Sebagai gantinya, sharding hasil ditentukan oleh sharding
operand dan slicing_axes
, dan out_sharding
harus cocok dengan
inferensi sharding ini.
Contoh:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>
Batasan:
- Elemen dalam
slicing_axes
harus memenuhi batasan yang tercantum dalamAxisRefListAttr
. - Harus memenuhi batasan yang tercantum di
Sdy_CollectiveOpInterface
. - Menerapkan
slicing_axes
ke sharding operand akan mendapatkanout_sharding
.
Ciri: SameOperandsAndResultType
Antarmuka: CollectiveOpInterface
, InferTypeOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
slicing_axes | ::mlir::sdy::ListOfAxisRefListsAttr | Daftar daftar referensi sumbu |
out_sharding | ::mlir::sdy::TensorShardingAttr | Sharding tensor |
Operand:
Operand | Deskripsi |
---|---|
tensor |
tensor dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
tensor dari nilai jenis apa pun |
sdy.all_to_all
(sdy::AllToAllOp)
Melakukan komunikasi semua-ke-semua di sepanjang sumbu
Sintaksis:
operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Untuk setiap tuple (axes, src_dim, tgt_dim) dalam daftar parameter, operasi ini memotong potongan tensor di sepanjang dimensi tgt_dim
dan sumbu yang ditentukan di axes
, menyebarkan potongan tersebut di sepanjang sumbu, dan menggabungkannya di sepanjang dimensi src_dim
.
Operasi ini pada dasarnya adalah kombinasi dari all-gather di sepanjang src_dim
dan axes
, diikuti dengan all-slice di sepanjang tgt_dim
dan axes
, yaitu,
akhiran dimensi sharding sumbu src_dim
pada tensor input
ditambahkan ke dimensi sharding sumbu tgt_dim
pada tensor output.
All-to-all akan diterapkan ke sharding operand (tensor
) untuk
mendapatkan sharding hasil (out_sharding
).
Perhatikan bahwa out_sharding
tidak digunakan untuk menentukan sharding
hasil. Sebagai gantinya, sharding hasil ditentukan oleh sharding
operand, src_dim
, tgt_dim
, dan axes
, dan out_sharding
harus cocok dengan
sharding yang disimpulkan ini.
Contoh:
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>
Batasan:
- Harus memenuhi batasan yang tercantum di
Sdy_CollectiveOpInterface
. - Daftar parameter tidak boleh kosong.
- Untuk setiap parameter di
params
:- Elemen di
axes
harus memenuhi batasanAxisRefAttr
. src_dim
dantgt_dim
harus berupa dimensi yang valid (tidak negatif dan kurang dari pangkat tensor).- Setiap
src_dim
atautgt_dim
harus unik di semua parameter. src_dim
harus diurutkan dalam urutan menaik di semua parameter.
- Elemen di
- Memindahkan
axes
darisrc_dim
ketgt_dim
dalam sharding operand akan mendapatkanout_sharding
.
Ciri: SameOperandsAndResultType
Antarmuka: InferTypeOpInterface
, Sdy_CollectiveOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
params | ::mlir::sdy::AlltoAllParamListAttr | Daftar parameter all-to-all |
out_sharding | ::mlir::sdy::TensorShardingAttr | Sharding tensor |
Operand:
Operand | Deskripsi |
---|---|
tensor |
tensor dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
tensor dari nilai jenis apa pun |
sdy.collective_permute
(sdy::CollectivePermuteOp)
Melakukan komunikasi permutasi kolektif untuk mengganti sumbu
Sintaksis:
operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Mengirim bagian tensor input dari setiap perangkat ke perangkat lain untuk mengurutkan ulang/mengganti sumbu yang mengelompokkan tensor.
Permutasi kolektif dapat mengubah sharding input sehingga setiap dimensi harus di-sharding seperti sebelumnya, yaitu harus di-sharding di sepanjang sumbu yang produk ukurannya cocok dengan sumbu yang sebelumnya me-sharding tensor.
Hal ini berguna untuk mengurutkan ulang sumbu dalam satu dimensi atau di seluruh dimensi yang berbeda, dan menukar sumbu yang di-shard dengan sumbu yang direplikasi.
Pada contoh di bawah, ukuran tensor yang di-shard adalah tensor<1x4x2xf32>
, dan
dipertahankan oleh permutasi kolektif.
Contoh:
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>
Batasan:
- Harus memenuhi batasan yang tercantum di
Sdy_CollectiveOpInterface
. - Jika sharding input dan output memiliki mesh yang berbeda, mesh tersebut harus memiliki sumbu yang sama persis dan urutan ID perangkat yang berbeda.
- Untuk setiap dimensi, produk ukuran sumbu sharding di
out_sharding
harus cocok dengan sharding dimensi operand yang sesuai.
Ciri: SameOperandsAndResultType
Antarmuka: CollectiveOpInterface
, InferTypeOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
out_sharding | ::mlir::sdy::TensorShardingAttr | Sharding tensor |
Operand:
Operand | Deskripsi |
---|---|
tensor |
tensor dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
tensor dari nilai jenis apa pun |
sdy.constant
(sdy::ConstantOp)
Operasi konstan
Menghasilkan tensor output
dari value
konstan.
Lihat: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
Contoh:
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
Ciri: AlwaysSpeculatableImplTrait
Antarmuka: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Efek: MemoryEffects::Effect{}
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
value | ::mlir::ElementsAttr | atribut vektor/tensor konstan |
Hasil:
Hasil | Deskripsi |
---|---|
output |
tensor berbentuk statis dari nilai jenis apa pun |
sdy.data_flow_edge
(sdy::DataFlowEdgeOp)
Operasi tepi aliran data.
Sintaksis:
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
Tepi aliran data dari beberapa op X menentukan jembatan antara kumpulan sumber (masing-masing adalah operand X atau operand terminator blok X) dan kumpulan target (masing-masing adalah hasil X atau argumen blok X), sehingga semua sumber dan target harus di-shard dengan cara yang sama.
Operasi dapat memiliki beberapa tepi aliran data yang saling ortogonal.
Contoh:
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
Operasi while ini memiliki n tepi aliran data, tepi aliran data ke-i berada di antara
sumber x_i
, return_value_i
, dan target y_i
, pred_arg_i
,
body_arg_i
.
sdy.data_flow_edge
menggunakan pemilik tepi sebagai input (dapat berupa
target mana pun, tetapi sebaiknya hasil op, bukan argumen
blok), yang tidak boleh memiliki penggunaan lain. Operasi ini tidak murni karena
dapat mengambil input yang awalnya tidak memiliki kegunaan apa pun.
sdy.data_flow_edge
juga menyimpan sharding opsional untuk semua target tepi, dan sharding tersebut harus diperbarui, bukan sharding target (jika dapat dilampirkan) selama penyebaran. Hal ini berguna saat op
memiliki banyak tepi, karena jauh lebih efisien untuk:
- ditransmisikan melalui setiap edge secara terpisah.
- memperbarui sharding setiap edge secara terpisah, bukan semua target sekaligus
(misalnya, op memiliki satu
TensorShardingPerValueAttr
yang tidak dapat diubah untuk sharding hasil). - tambahkan setiap tepi ke daftar tugas secara terpisah saat sharding sumber telah berubah.
Propagasi akan menyebarkan sharding di antara semua sumber dan target
sdy.data_flow_edge
seolah-olah itu adalah operasi reguler dengan sumber sebagai operand
dan target sebagai hasil, serta identitas sdy.op_sharding_rule
. Artinya,
propagasi maju adalah dari sumber ke target dan propagasi
mundur adalah dari target ke sumber.
Kami tidak mengizinkan input sdy.data_flow_edge
ditentukan oleh
op SdyDialect
, sehingga kita dapat mengasumsikan bahwa input tersebut ditentukan oleh op yang memiliki
atribut sdy.sharding
yang tidak terdaftar.
Ciri: SameOperandsAndResultType
Antarmuka: InferTypeOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Sharding tensor |
Operand:
Operand | Deskripsi |
---|---|
input |
dibentuk dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
dibentuk dari nilai jenis apa pun |
sdy.manual_computation
(sdy::ManualComputationOp)
Operasi paralelisme multi-perangkat dengan kolektif manual
Sintaksis:
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
Masuk ke wilayah yang ditulis dalam kode lokal per perangkat dengan kolektif eksplisit, dengan bentuk logis yang cocok dengan bentuk buffer fisik per perangkat lokal dan kolektif yang sama persis dengan komunikasi lintas perangkat fisik.
Isi adalah lokal terkait manual_axes. Propagasi akan terjadi melalui body pada sumbu bebas apa pun - yang tidak ada dalam daftar manual_axes.
Batasan:
- Elemen di
in_shardings
danout_shardings
harus memenuhi batasan yang tercantum diTensorShardingAttr
. - Jumlah input/output tensor global dan lokal dari region op harus cocok.
- Sumbu manual harus muncul sebelum sumbu bebas di setiap sharding dimensi.
- Sumbu manual tidak dapat menambahkan padding. Yaitu, ukuran dimensi harus dapat dibagi dengan ukuran sumbu manual yang sesuai.
- Bentuk global dan lokal dari argumen/hasil region op harus cocok.
- Tidak ada sumbu manual yang dipisahkan.
Sifat: IsolatedFromAbove
, RecursiveMemoryEffects
, SingleBlockImplicitTerminator<ReturnOp>
, SingleBlock
Antarmuka: ShardableDataFlowOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Sharding tensor per operand/hasil operasi |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Sharding tensor per operand/hasil operasi |
manual_axes | ::mlir::sdy::ManualAxesAttr | Daftar sumbu yang ManualComputationOp-nya bersifat manual |
Operand:
Operand | Deskripsi |
---|---|
tensors |
variabel dari tensor berperingkat dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
results |
variabel dari tensor berperingkat dari nilai jenis apa pun |
sdy.mesh
(sdy::MeshOp)
Mesh bernama
Sintaksis:
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
Menentukan mesh baru yang diberi nama. Semua mesh dalam modul harus memiliki jumlah perangkat yang sama (kecuali untuk mesh dengan satu device_id).
Mesh adalah operasi Symbol
yang muncul di
SymbolTable
modul dan dapat direferensikan oleh name
-nya.
Ciri: HasParent<ModuleOp>
Antarmuka: Symbol
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
sym_name | ::mlir::StringAttr | atribut string |
mesh | ::mlir::sdy::MeshAttr | Mesh sumbu dan daftar perangkat |
sdy.named_computation
(sdy::NamedComputationOp)
Operasi komputasi bernama
Sintaksis:
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
Mengelompokkan komputasi, yaitu blok operasi, dan memberinya nama. Propagasi akan mengalir masuk/keluar dari region seolah-olah semuanya telah disisipkan.
Hal ini dapat digunakan untuk menangani penyebaran melalui petunjuk panggilan ke fungsi
lain. Setiap pengguna Shardy harus menulis kartu impor/ekspor yang
mengonversi operasi panggilan mereka menjadi operasi sdy.named_computation
, menduplikasi/menyalin
isi fungsi yang dipanggil ke dalam isi named_computation
.
Jenis setiap argumen blok dan nilai yang ditampilkan di region harus sama dengan jenis operand dan jenis hasil op.
Contoh:
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
Sifat: IsolatedFromAbove
, RecursiveMemoryEffects
, RecursivelySpeculatableImplTrait
, SingleBlockImplicitTerminator<ReturnOp>
, SingleBlock
Antarmuka: ConditionallySpeculatable
, InferTypeOpInterface
, ShardableDataFlowOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
name | ::mlir::StringAttr | atribut string |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Sharding tensor per operand/hasil operasi |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Sharding tensor per operand/hasil operasi |
Operand:
Operand | Deskripsi |
---|---|
operands |
variadik dari jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
«unnamed» | variadik dari jenis apa pun |
sdy.propagation_barrier
(sdy::PropagationBarrierOp)
Operasi penghalang propagasi
Sintaksis:
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
Operasi ini beroperasi seperti operasi identitas, yang menghasilkan nilai yang sama dengan yang diambil sebagai input. Namun, dalam hal propagasi, hal ini hanya akan memungkinkan propagasi mengalir melaluinya dalam arah tertentu.
Hal ini mencegah sharding disebarkan di antara penggunaan hasil operasi penghalang dan operandnya.
FORWARD
berarti sharding hanya dapat mengalir dari operand ke hasil.BACKWARD
berarti sharding hanya dapat mengalir dari hasil ke operand.NONE
berarti tidak ada sharding yang dapat di-propagasi melalui operasi ini.- Tidak dapat menentukan
BOTH
, karena operasi ini akan menjadi redundan.
Sifat: AlwaysSpeculatableImplTrait
, SameOperandsAndResultType
Antarmuka: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Efek: MemoryEffects::Effect{}
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | enum arah penyebaran |
Operand:
Operand | Deskripsi |
---|---|
input |
tensor berperingkat dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
tensor berperingkat dari nilai jenis apa pun |
sdy.reshard
(sdy::ReshardOp)
Melakukan resharding tensor ke sharding yang berbeda
Sintaksis:
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
Melakukan resharding pada tensor input dengan sharding yang ditentukan, yang berbeda dari sharding tensor input yang ada.
ShardingConstraintOp dan ReshardOp melampirkan sharding ke tensor. Masa aktifnya adalah:
- Sebelum penyebaran sharding, ShardingConstraintOp ditambahkan oleh pengguna.
- Penyebaran sharding menggunakan ShardingConstraintOp. Tidak ada ShardingConstraintOp dalam hasil propagasi sharding. Sebagai gantinya, ReshardOp dapat ditambahkan jika diperlukan.
- Partisioner mengonversi ReshardOp menjadi operasi kolektif (atau operasi identitas). Tidak boleh ada ReshardOp dalam hasil partisi.
// TODO(b/331680067). Tambahkan pola kanonisasi untuk menghapus operasi // reshard yang redundan.
Sifat: AlwaysSpeculatableImplTrait
, SameOperandsAndResultType
Antarmuka: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Efek: MemoryEffects::Effect{}
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Sharding tensor |
Operand:
Operand | Deskripsi |
---|---|
input |
tensor dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
tensor dari nilai jenis apa pun |
sdy.return
(sdy::ReturnOp)
Operasi sdy.return
menghentikan region yang dilampirkan ke operasi berbasis region sdy
dan operasi berbasis region Shardy lainnya. Fungsi ini bersifat variadik: fungsi ini menggunakan daftar nilai sebagai argumen yang jenisnya dapat berupa apa saja (tetapi
dari jenis yang sama, misalnya AnyTensor
) sehingga dapat digunakan kembali di berbagai
tingkat stack IR Shardy.
Sintaksis:
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
Sifat: AlwaysSpeculatableImplTrait
, Terminator
Antarmuka: ConditionallySpeculatable
, NoMemoryEffect (MemoryEffectOpInterface)
Efek: MemoryEffects::Effect{}
Operand:
Operand | Deskripsi |
---|---|
results |
variadik dari jenis apa pun |
sdy.sharding_constraint
(sdy::ShardingConstraintOp)
Membatasi tensor ke sharding yang ditentukan
Sintaksis:
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
Melampirkan sharding ke tensor perantara (misalnya, hasil matmul) untuk menunjukkan bahwa ini adalah cara tensor tersebut, atau sebagian penggunaannya, harus di-sharding.
Jika sharding memiliki dimensi terbuka dan sumbu yang tidak dibatasi, berarti tensor dapat di-sharding lebih lanjut di sepanjang dimensi terbuka.
Operasi ini dapat:
- Tidak memiliki penggunaan (tergantung) - yang berarti sharding yang dilampirkan adalah cara tensor input itu sendiri harus di-sharding.
- Memiliki penggunaan - yang berarti sharding yang dilampirkan adalah cara penggunaan op batasan sharding harus di-sharding, sedangkan penggunaan lain dari tensor input mungkin memiliki sharding yang berbeda (jika tensor input tidak memiliki penggunaan lain, perilakunya sama dengan kasus tanpa penggunaan).
Ciri: SameOperandsAndResultType
Antarmuka: InferTypeOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Sharding tensor |
Operand:
Operand | Deskripsi |
---|---|
input |
tensor dari nilai jenis apa pun |
Hasil:
Hasil | Deskripsi |
---|---|
result |
tensor dari nilai jenis apa pun |
sdy.sharding_group
(sdy::ShardingGroupOp)
Membatasi tensor dalam grup agar memiliki sharding yang sama.
Sintaksis:
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
Operasi ini menyediakan antarmuka untuk menetapkan tensor ke grup sharding ( grup tensor yang akan diberlakukan untuk memiliki sharding yang identik). Selama penyebaran, segera setelah satu elemen grup di-shard, semua anggota lainnya akan di-shard dengan cara yang sama persis. Operasi ini mengambil ID grup argumen dan tidak menampilkan hasil, tetapi mengubah representasi grup sharding internal untuk menambahkan tensor input ke grup dengan ID yang diberikan.
Antarmuka: InferTypeOpInterface
Atribut:
Atribut | Jenis MLIR | Deskripsi |
---|---|---|
group_id | ::mlir::IntegerAttr | Atribut bilangan bulat tanpa tanda 64-bit |
Operand:
Operand | Deskripsi |
---|---|
input |
tensor berperingkat dari nilai jenis apa pun |
Atribut
AllToAllParamAttr
Parameter all-to-all
Sintaksis:
#sdy.all_to_all_param<
::llvm::ArrayRef<AxisRefAttr>, # axes
int64_t, # src_dim
int64_t # tgt_dim
>
Tupla yang berisi sumbu dan dimensi sumber/target untuk melakukan semua-ke-semua.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
sumbu | ::llvm::ArrayRef<AxisRefAttr> |
sumbu untuk melakukan semua-ke-semua |
src_dim | int64_t |
indeks dimensi sumber |
tgt_dim | int64_t |
indeks dimensi target |
AlltoAllParamListAttr
Daftar parameter all-to-all
Sintaksis:
#sdy.all_to_all_param_list<
::llvm::ArrayRef<AllToAllParamAttr> # value
>
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
nilai | ::llvm::ArrayRef<AllToAllParamAttr> |
AxisRefAttr
Referensi ke sumbu penuh atau sub-sumbu terpisah
Sintaksis:
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
Batasan:
name
harus ada dalamMeshAttr
yang terikat.- Jika ada,
sub_axis_info
harus memenuhi batasanSubAxisInfoAttr
.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
nama | ::llvm::StringRef |
nama sumbu ini |
sub_axis_info | SubAxisInfoAttr |
info tambahan jika ini adalah sub-sumbu |
AxisRefListAttr
Daftar referensi sumbu
Sintaksis:
#sdy.axis_ref_list<
::llvm::ArrayRef<AxisRefAttr> # value
>
Batasan:
- Elemen di
value
harus memenuhi batasanAxisRefAttr
. - Tidak ada referensi sumbu atau sub-sumbu duplikat yang tumpang-tindih satu sama lain.
- Tidak ada dua referensi sumbu yang berdekatan yang merupakan sub-sumbu berturut-turut dari sumbu penuh yang sama, yaitu, keduanya dapat digabungkan menjadi satu sub-sumbu atau sumbu penuh.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
nilai | ::llvm::ArrayRef<AxisRefAttr> |
DimMappingAttr
Daftar indeks faktor untuk dimensi
Daftar kosong menunjukkan bahwa ini adalah pemetaan null (ini diuraikan/dicetak
dengan *
), yaitu dimensi tidak dipetakan ke faktor apa pun.
Batasan:
- Ada minimal satu indeks faktor.
- Indeks faktor harus dalam rentang [0,
$factor_sizes
). - Jika ada beberapa faktor, tidak ada satu pun yang dapat memiliki ukuran 1.
- Tidak ada indeks faktor duplikat.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
factor_indices | ::llvm::ArrayRef<int64_t> |
faktor yang dipetakan ke dimensi ini |
DimensionShardingAttr
Sharding dimensi
Daftar nama sumbu untuk membuat shard dimensi tensor dari utama ke minor, Boolean yang menunjukkan apakah dimensi dapat di-shard lebih lanjut, dan bilangan bulat opsional yang menunjukkan prioritas sharding dimensi ini, yang akan dipatuhi selama penyebaran sharding. Prioritas berasal dari anotasi sharding pengguna dan nilai yang lebih rendah menunjukkan prioritas yang lebih tinggi. Prioritas tertinggi diasumsikan jika prioritas tidak ada dalam anotasi.
Batasan:
- Elemen di
axes
harus memenuhi batasan yang tercantum diAxisRefListAttr
. - Jika sharding dimensi memiliki prioritas:
- Prioritas lebih besar dari atau sama dengan 0.
- Dimensi memiliki minimal satu sumbu jika ditutup.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
sumbu | ::llvm::ArrayRef<AxisRefAttr> |
referensi sumbu |
is_closed | bool |
apakah dimensi ini tidak dapat di-sharding lebih lanjut |
prioritas | std::optional<int64_t> |
prioritas yang digunakan selama propagasi berbasis prioritas pengguna |
ListOfAxisRefListsAttr
Daftar daftar referensi sumbu
Sintaksis:
#sdy.list_of_axis_ref_lists<
::llvm::ArrayRef<AxisRefListAttr> # value
>
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
nilai | ::llvm::ArrayRef<AxisRefListAttr> |
ManualAxesAttr
Daftar sumbu yang ManualComputationOp-nya manual
Sintaksis:
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
nilai | ::llvm::ArrayRef<StringAttr> |
MeshAttr
Mesh sumbu dan daftar perangkat
Sintaksis:
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
Mesh adalah daftar sumbu dan daftar ID perangkat opsional yang menentukan urutan perangkat.
Jika daftar sumbu kosong, mesh memiliki sumbu implisit tanpa nama dengan ukuran 1. Dalam hal ini, jika daftar ID perangkat tidak diberikan, daftar ID perangkat implisit adalah [0]; jika daftar ID perangkat diberikan, daftar tersebut harus berisi satu bilangan bulat dari nilai non-negatif. Kami menyebutnya kasus sharding maksimal.
Untuk semua kasus sharding non-maksimum, jika daftar ID perangkat ditentukan, produk ukuran sumbu harus cocok dengan jumlah perangkat. Jika daftar ID perangkat tidak ditentukan, daftar ID perangkat implisit adalah iota(product(axes)). Untuk memudahkan, kami juga tidak mengizinkan penentuan daftar ID perangkat yang sama dengan iota(product(axes)); dalam hal ini, daftar ID perangkat tidak boleh ditentukan.
Berikut beberapa contoh mesh:
- Mesh kosong mewakili mesh placeholder yang dapat diganti selama penyebaran: <[]>
- Mesh dengan sumbu tanpa nama dan ID perangkat eksplisit, yang biasanya digunakan untuk merepresentasikan sharding maksimum: <[], device_ids=[3]>
- Mesh dengan dua sumbu dan ID perangkat implisit iota(6): <["a"=2, "b"=3]>
- Mesh dengan dua sumbu dan ID perangkat eksplisit yang menentukan urutan perangkat: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
Batasan:
- Elemen di
axes
tidak boleh memiliki nama duplikat. - Jika
device_ids
ditentukan:- Hasil ukuran sumbu harus sesuai dengan jumlah perangkat.
- Semua elemennya tidak boleh negatif.
device_ids
tidak boleh sama denganiota(product(axis_sizes))
.device_ids
yang diurutkan harusiota(product(axis_sizes))
.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
sumbu | ::llvm::ArrayRef<MeshAxisAttr> |
sumbu mesh |
device_ids | ::llvm::ArrayRef<int64_t> |
pengurutan perangkat eksplisit atau ID perangkat maksimum |
MeshAxisAttr
Sumbu bernama dalam mesh
Sintaksis:
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
nama | ::llvm::StringRef |
nama |
ukuran | int64_t |
ukuran sumbu ini |
OpShardingRuleAttr
Menentukan cara operasi dapat dipartisi.
Sintaksis:
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
::llvm::ArrayRef<int64_t>, # reduction_factors
::llvm::ArrayRef<int64_t>, # need_replication_factors
::llvm::ArrayRef<int64_t>, # permutation_factors
::llvm::ArrayRef<int64_t>, # blocked_propagation_factors
bool # is_custom_rule
>
Aturan sharding menentukan cara operasi dapat dipartisi sesuai dengan berbagai properti pada op - atribut apa pun, bentuk operand, bentuk hasil, dll. Misalnya:
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
Perhatikan bahwa kami mengizinkan faktor dengan ukuran 1 meskipun tidak dapat di-shard, hal ini terutama untuk kelengkapan karena banyak operasi seperti operasi pointwise memiliki dimensi ukuran satu yang sesuai di seluruh operand dan hasil.
Jenis faktor:
reduction_factors
berisi indeks faktor yang memerlukan pengurangan, seperti dimensi kontraksi dalam operasi titik.need_replication_factors
berisi indeks faktor yang memerlukan replika penuh, seperti dimensi yang diurutkan dalam operasi pengurutan.permutation_factors
berisi indeks faktor yang memerlukan permutasi kolektif jika di-shard, seperti dimensi padding dalam operasi pad.- Semua faktor lainnya dianggap sebagai faktor pass-through, yaitu faktor yang tidak memerlukan komunikasi apa pun jika di-shard dengan cara yang sama di semua tensor yang dipetakan ke faktor tersebut.
blocked_propagation_factors
berisi faktor yang tidak mengizinkan penyebaran sharding. Ini ortogonal dengan jenis faktor. Yaitu,
faktor penyebaran yang diblokir dapat berupa jenis faktor apa pun.
is_custom_rule
menjelaskan apakah ini adalah aturan yang ditentukan oleh pengguna. Pengguna dapat menentukan aturan sharding untuk panggilan kustom mereka atau menimpa aturan sharding standar yang telah ditentukan sebelumnya untuk operasi standar. Aturan kustom
selalu dipertahankan/tidak pernah dihapus.
Batasan:
- Jumlah pemetaan operand/hasil harus cocok dengan jumlah operand/hasil operasi.
- Ada minimal satu pemetaan (tidak dapat memiliki aturan untuk operasi tanpa operand/hasil).
- Peringkat setiap
TensorMappingAttr
cocok dengan peringkat jenis tensor yang sesuai. - Untuk setiap grup faktor (
reduction_factors
,need_replication_factors
,permutation_factors
):- Elemen harus berada dalam rentang [0,
$factor_sizes
]. - Tidak ada indeks faktor duplikat dalam setiap grup dan di seluruh grup.
- Elemen harus berada dalam rentang [0,
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
factor_sizes | ::llvm::ArrayRef<int64_t> |
ukuran semua faktor dalam aturan ini |
operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
pemetaan operand |
result_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
pemetaan hasil |
reduction_factors | ::llvm::ArrayRef<int64_t> |
faktor yang memerlukan pengurangan |
need_replication_factors | ::llvm::ArrayRef<int64_t> |
faktor yang memerlukan replikasi penuh |
permutation_factors | ::llvm::ArrayRef<int64_t> |
faktor yang memerlukan permutasi kolektif |
blocked_propagation_factors | ::llvm::ArrayRef<int64_t> |
faktor yang tidak menyebarkan sharding |
is_custom_rule | bool |
apakah aturan tersebut untuk stablehlo.custom_call |
SubAxisInfoAttr
Info tentang cara sub-sumbu ini berasal dari sumbu penuh
Sintaksis:
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
Saat membagi sumbu penuh menjadi n sub-sumbu, sumbu akan dibentuk ulang menjadi
[k_1,...,k_n], dan sub-sumbu ke-i dapat dinyatakan dengan produk semua
ukuran sumbu di sebelah kirinya m=prod(k_1,...,k_(i-1))
(alias pra-ukuran) dan ukuran
k_i. Oleh karena itu, atribut sub-axis-info menyimpan dua angka tersebut dan
ditunjukkan sebagai berikut: (m)k
untuk ukuran awal m dan ukuran k.
Batasan:
pre-size
minimal 1.size
lebih besar dari 1.pre-size
harus membagi ukuran sumbu penuh, yaitupre-size
dansize
membagi ukuran sumbu penuh, dan sub-sumbu tidak melebihi sumbu penuh.- Ukuran sub-sumbu tidak sama dengan ukuran sumbu penuh yang sesuai, dalam hal ini sumbu penuh harus digunakan.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
pre_size | int64_t |
produk ukuran sub-sumbu di sebelah kiri sub-sumbu ini |
ukuran | int64_t |
ukuran sub-sumbu ini |
TensorMappingAttr
Pemetaan faktor untuk setiap dimensi tensor.
Sintaksis:
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
Batasan:
- Elemen di
dim_mappings
harus memenuhi batasan diDimMappingAttr
. - Tidak ada indeks faktor duplikat di seluruh dimensi.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
pemetaan dimensi |
TensorShardingAttr
Sharding tensor
Sintaksis:
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
Sharding tensor terikat ke mesh tertentu, dan hanya dapat mereferensikan nama sumbu dari mesh tersebut. Sharding dimensi memberi tahu kita untuk setiap dimensi tensor, di sepanjang sumbu (atau sub-sumbu) yang di-sharding dari utama ke minor. Semua sumbu lain yang tidak melakukan shard dimensi direplikasi secara implisit atau secara eksplisit (jika muncul dalam daftar sumbu yang direplikasi).
Mesh yang menjadi tempat sharding ini dapat ditentukan oleh nama
simbol, yang mereferensikan simbol MeshOp
yang sesuai, atau MeshAttr
yang disisipkan.
Batasan:
- Elemen di
dim_shardings
harus memenuhi batasan yang tercantum diDimensionShardingAttr
. - Elemen di
replicated_axes
harus memenuhi batasan yang tercantum diAxisRefListAttr
. - Jika jenis tensor yang sesuai bukan
ShapedType
, sharding harus memiliki peringkat 0 dan tidak ada sumbu yang direplikasi. - Tensor harus memiliki peringkat.
- Jumlah sharding dimensi sama dengan pangkat tensor.
- Dimensi berukuran 0 tidak di-shard.
- Item di
replicated_axes
diurutkan berdasarkanmesh_or_ref
(lihatAxisRefAttr::getMeshComparator
).
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
mesh_or_ref | ::mlir::Attribute |
atribut mesh atau atribut referensi simbol mesh datar |
dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> |
sharding dimensi |
replicated_axes | ::llvm::ArrayRef<AxisRefAttr> |
referensi sumbu |
TensorShardingPerValueAttr
Sharding tensor per operand/hasil operasi
Sintaksis:
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
Daftar TensorShardingAttr
, satu untuk setiap operand/hasil operasi.
Batasan:
- Elemen di
shardings
harus memenuhi batasanTensorShardingAttr
.
Parameter:
Parameter | Jenis C++ | Deskripsi |
---|---|---|
sharding | ::llvm::ArrayRef<TensorShardingAttr> |
sharding per nilai |
Enum
PropagationDirection
Enum arah penyebaran
Kasus:
Simbol | Nilai | String |
---|---|---|
TIDAK ADA | 0 |
TIDAK ADA |
MAJU | 1 |
MAJU |
KEMBALI | 2 |
KEMBALI |
KEDUANYA | 3 |
KEDUANYA |