Dialek 'sdy'

Dialek Shardy (SDY) menentukan representasi sharding tensor berbasis sumbu dan komponen API tambahan untuk melampirkan sharding ke tensor.

Operasi

sdy.all_gather (sdy::AllGatherOp)

Melakukan komunikasi all-gather di sepanjang sumbu

Sintaksis:

operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Mengumpulkan potongan tensor di sepanjang sumbu yang ditentukan dalam gathering_axes.

gathering_axes adalah daftar daftar sumbu. Daftar luar berada di atas dimensi tensor. Setiap daftar dalam menentukan sumbu tempat pengumpulan terpisah harus dilakukan pada dimensi masing-masing. Hal ini akan diterapkan ke sharding operand (tensor) untuk mendapatkan sharding hasil (out_sharding).

Perhatikan bahwa out_sharding tidak digunakan untuk menentukan sharding hasil. Sebagai gantinya, sharding hasil ditentukan oleh sharding operand dan gathering_axes, dan out_sharding harus cocok dengan inferensi sharding ini.

Contoh:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>

Batasan:

  • Harus memenuhi batasan yang tercantum di Sdy_CollectiveOpInterface.
  • Elemen dalam gathering_axes harus memenuhi batasan yang tercantum dalam AxisRefListAttr.
  • Menerapkan gathering_axes ke sharding operand akan mendapatkan out_sharding.

Ciri: SameOperandsAndResultType

Antarmuka: InferTypeOpInterface, Sdy_CollectiveOpInterface

Atribut:

AtributJenis MLIRDeskripsi
gathering_axes::mlir::sdy::ListOfAxisRefListsAttrDaftar daftar referensi sumbu
out_sharding::mlir::sdy::TensorShardingAttrSharding tensor

Operand:

Operand Deskripsi
tensor tensor dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result tensor dari nilai jenis apa pun

sdy.all_reduce (sdy::AllReduceOp)

Melakukan komunikasi all-reduce di sepanjang sumbu

Sintaksis:

operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Mengurangi potongan tensor di sepanjang sumbu yang ditentukan dalam reduction_axes. Urutan reduction_axes tidak penting untuk hasilnya, tetapi dapat memengaruhi urutan grup replika yang sesuai.

Batasan:

  • Harus memenuhi batasan yang tercantum di Sdy_CollectiveOpInterface.
  • reduction_axes harus memenuhi batasan yang tercantum di AxisRefListAttr;
  • reduction_axes tidak boleh tumpang-tindih dengan sumbu sharding operand;

Ciri: SameOperandsAndResultType

Antarmuka: CollectiveOpInterface, InferTypeOpInterface

Atribut:

AtributJenis MLIRDeskripsi
reduction_axes::mlir::sdy::AxisRefListAttrDaftar referensi sumbu
out_sharding::mlir::sdy::TensorShardingAttrSharding tensor

Operand:

Operand Deskripsi
tensor tensor dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result tensor dari nilai jenis apa pun

sdy.all_slice (sdy::AllSliceOp)

Melakukan operasi slice dinamis di sepanjang sumbu

Sintaksis:

operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Memotong potongan tensor di sepanjang sumbu yang ditentukan dalam slicing_axes. Ada dualitas aljabar antara sdy.all_slice dan sdy.all_gather.

slicing_axes adalah daftar daftar sumbu. Daftar luar berada di atas dimensi tensor. Setiap daftar bagian dalam menentukan sumbu tempat slice harus dilakukan pada dimensi masing-masing. Ini akan diterapkan ke sharding operand (tensor) untuk mendapatkan sharding hasil (out_sharding).

Perhatikan bahwa out_sharding tidak digunakan untuk menentukan sharding hasil. Sebagai gantinya, sharding hasil ditentukan oleh sharding operand dan slicing_axes, dan out_sharding harus cocok dengan inferensi sharding ini.

Contoh:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>

Batasan:

  • Elemen dalam slicing_axes harus memenuhi batasan yang tercantum dalam AxisRefListAttr.
  • Harus memenuhi batasan yang tercantum di Sdy_CollectiveOpInterface.
  • Menerapkan slicing_axes ke sharding operand akan mendapatkan out_sharding.

Ciri: SameOperandsAndResultType

Antarmuka: CollectiveOpInterface, InferTypeOpInterface

Atribut:

AtributJenis MLIRDeskripsi
slicing_axes::mlir::sdy::ListOfAxisRefListsAttrDaftar daftar referensi sumbu
out_sharding::mlir::sdy::TensorShardingAttrSharding tensor

Operand:

Operand Deskripsi
tensor tensor dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result tensor dari nilai jenis apa pun

sdy.all_to_all (sdy::AllToAllOp)

Melakukan komunikasi semua-ke-semua di sepanjang sumbu

Sintaksis:

operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Untuk setiap tuple (axes, src_dim, tgt_dim) dalam daftar parameter, operasi ini memotong potongan tensor di sepanjang dimensi tgt_dim dan sumbu yang ditentukan di axes, menyebarkan potongan tersebut di sepanjang sumbu, dan menggabungkannya di sepanjang dimensi src_dim.

Operasi ini pada dasarnya adalah kombinasi dari all-gather di sepanjang src_dim dan axes, diikuti dengan all-slice di sepanjang tgt_dim dan axes, yaitu, akhiran dimensi sharding sumbu src_dim pada tensor input ditambahkan ke dimensi sharding sumbu tgt_dim pada tensor output.

All-to-all akan diterapkan ke sharding operand (tensor) untuk mendapatkan sharding hasil (out_sharding).

Perhatikan bahwa out_sharding tidak digunakan untuk menentukan sharding hasil. Sebagai gantinya, sharding hasil ditentukan oleh sharding operand, src_dim, tgt_dim, dan axes, dan out_sharding harus cocok dengan sharding yang disimpulkan ini.

Contoh:

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>

Batasan:

  • Harus memenuhi batasan yang tercantum di Sdy_CollectiveOpInterface.
  • Daftar parameter tidak boleh kosong.
  • Untuk setiap parameter di params:
    • Elemen di axes harus memenuhi batasan AxisRefAttr.
    • src_dim dan tgt_dim harus berupa dimensi yang valid (tidak negatif dan kurang dari pangkat tensor).
    • Setiap src_dim atau tgt_dim harus unik di semua parameter.
    • src_dim harus diurutkan dalam urutan menaik di semua parameter.
  • Memindahkan axes dari src_dim ke tgt_dim dalam sharding operand akan mendapatkan out_sharding.

Ciri: SameOperandsAndResultType

Antarmuka: InferTypeOpInterface, Sdy_CollectiveOpInterface

Atribut:

AtributJenis MLIRDeskripsi
params::mlir::sdy::AlltoAllParamListAttrDaftar parameter all-to-all
out_sharding::mlir::sdy::TensorShardingAttrSharding tensor

Operand:

Operand Deskripsi
tensor tensor dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result tensor dari nilai jenis apa pun

sdy.collective_permute (sdy::CollectivePermuteOp)

Melakukan komunikasi permutasi kolektif untuk mengganti sumbu

Sintaksis:

operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Mengirim bagian tensor input dari setiap perangkat ke perangkat lain untuk mengurutkan ulang/mengganti sumbu yang mengelompokkan tensor.

Permutasi kolektif dapat mengubah sharding input sehingga setiap dimensi harus di-sharding seperti sebelumnya, yaitu harus di-sharding di sepanjang sumbu yang produk ukurannya cocok dengan sumbu yang sebelumnya me-sharding tensor.

Hal ini berguna untuk mengurutkan ulang sumbu dalam satu dimensi atau di seluruh dimensi yang berbeda, dan menukar sumbu yang di-shard dengan sumbu yang direplikasi.

Pada contoh di bawah, ukuran tensor yang di-shard adalah tensor<1x4x2xf32>, dan dipertahankan oleh permutasi kolektif.

Contoh:

sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>

Batasan:

  • Harus memenuhi batasan yang tercantum di Sdy_CollectiveOpInterface.
  • Jika sharding input dan output memiliki mesh yang berbeda, mesh tersebut harus memiliki sumbu yang sama persis dan urutan ID perangkat yang berbeda.
  • Untuk setiap dimensi, produk ukuran sumbu sharding di out_sharding harus cocok dengan sharding dimensi operand yang sesuai.

Ciri: SameOperandsAndResultType

Antarmuka: CollectiveOpInterface, InferTypeOpInterface

Atribut:

AtributJenis MLIRDeskripsi
out_sharding::mlir::sdy::TensorShardingAttrSharding tensor

Operand:

Operand Deskripsi
tensor tensor dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result tensor dari nilai jenis apa pun

sdy.constant (sdy::ConstantOp)

Operasi konstan

Menghasilkan tensor output dari value konstan.

Lihat: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant

Contoh:

%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>

Ciri: AlwaysSpeculatableImplTrait

Antarmuka: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Efek: MemoryEffects::Effect{}

Atribut:

AtributJenis MLIRDeskripsi
value::mlir::ElementsAttratribut vektor/tensor konstan

Hasil:

Hasil Deskripsi
output tensor berbentuk statis dari nilai jenis apa pun

sdy.data_flow_edge (sdy::DataFlowEdgeOp)

Operasi tepi aliran data.

Sintaksis:

operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)

Tepi aliran data dari beberapa op X menentukan jembatan antara kumpulan sumber (masing-masing adalah operand X atau operand terminator blok X) dan kumpulan target (masing-masing adalah hasil X atau argumen blok X), sehingga semua sumber dan target harus di-shard dengan cara yang sama.

Operasi dapat memiliki beberapa tepi aliran data yang saling ortogonal.

Contoh:

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

Operasi while ini memiliki n tepi aliran data, tepi aliran data ke-i berada di antara sumber x_i, return_value_i, dan target y_i, pred_arg_i, body_arg_i.

sdy.data_flow_edge menggunakan pemilik tepi sebagai input (dapat berupa target mana pun, tetapi sebaiknya hasil op, bukan argumen blok), yang tidak boleh memiliki penggunaan lain. Operasi ini tidak murni karena dapat mengambil input yang awalnya tidak memiliki kegunaan apa pun.

sdy.data_flow_edge juga menyimpan sharding opsional untuk semua target tepi, dan sharding tersebut harus diperbarui, bukan sharding target (jika dapat dilampirkan) selama penyebaran. Hal ini berguna saat op memiliki banyak tepi, karena jauh lebih efisien untuk:

  • ditransmisikan melalui setiap edge secara terpisah.
  • memperbarui sharding setiap edge secara terpisah, bukan semua target sekaligus (misalnya, op memiliki satu TensorShardingPerValueAttr yang tidak dapat diubah untuk sharding hasil).
  • tambahkan setiap tepi ke daftar tugas secara terpisah saat sharding sumber telah berubah.

Propagasi akan menyebarkan sharding di antara semua sumber dan target sdy.data_flow_edge seolah-olah itu adalah operasi reguler dengan sumber sebagai operand dan target sebagai hasil, serta identitas sdy.op_sharding_rule. Artinya, propagasi maju adalah dari sumber ke target dan propagasi mundur adalah dari target ke sumber.

Kami tidak mengizinkan input sdy.data_flow_edge ditentukan oleh op SdyDialect, sehingga kita dapat mengasumsikan bahwa input tersebut ditentukan oleh op yang memiliki atribut sdy.sharding yang tidak terdaftar.

Ciri: SameOperandsAndResultType

Antarmuka: InferTypeOpInterface

Atribut:

AtributJenis MLIRDeskripsi
sharding::mlir::sdy::TensorShardingAttrSharding tensor

Operand:

Operand Deskripsi
input dibentuk dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result dibentuk dari nilai jenis apa pun

sdy.manual_computation (sdy::ManualComputationOp)

Operasi paralelisme multi-perangkat dengan kolektif manual

Sintaksis:

operation ::= `sdy.manual_computation` `(`operands`)`
              `in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
              `out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
              `manual_axes````=```$manual_axes
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:`
              functional-type(operands, results)

Masuk ke wilayah yang ditulis dalam kode lokal per perangkat dengan kolektif eksplisit, dengan bentuk logis yang cocok dengan bentuk buffer fisik per perangkat lokal dan kolektif yang sama persis dengan komunikasi lintas perangkat fisik.

Isi adalah lokal terkait manual_axes. Propagasi akan terjadi melalui body pada sumbu bebas apa pun - yang tidak ada dalam daftar manual_axes.

Batasan:

  • Elemen di in_shardings dan out_shardings harus memenuhi batasan yang tercantum di TensorShardingAttr.
  • Jumlah input/output tensor global dan lokal dari region op harus cocok.
  • Sumbu manual harus muncul sebelum sumbu bebas di setiap sharding dimensi.
  • Sumbu manual tidak dapat menambahkan padding. Yaitu, ukuran dimensi harus dapat dibagi dengan ukuran sumbu manual yang sesuai.
  • Bentuk global dan lokal dari argumen/hasil region op harus cocok.
  • Tidak ada sumbu manual yang dipisahkan.

Sifat: IsolatedFromAbove, RecursiveMemoryEffects, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock

Antarmuka: ShardableDataFlowOpInterface

Atribut:

AtributJenis MLIRDeskripsi
in_shardings::mlir::sdy::TensorShardingPerValueAttrSharding tensor per operand/hasil operasi
out_shardings::mlir::sdy::TensorShardingPerValueAttrSharding tensor per operand/hasil operasi
manual_axes::mlir::sdy::ManualAxesAttrDaftar sumbu yang ManualComputationOp-nya bersifat manual

Operand:

Operand Deskripsi
tensors variabel dari tensor berperingkat dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
results variabel dari tensor berperingkat dari nilai jenis apa pun

sdy.mesh (sdy::MeshOp)

Mesh bernama

Sintaksis:

operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict

Menentukan mesh baru yang diberi nama. Semua mesh dalam modul harus memiliki jumlah perangkat yang sama (kecuali untuk mesh dengan satu device_id). Mesh adalah operasi Symbol yang muncul di SymbolTable modul dan dapat direferensikan oleh name-nya.

Ciri: HasParent<ModuleOp>

Antarmuka: Symbol

Atribut:

AtributJenis MLIRDeskripsi
sym_name::mlir::StringAttratribut string
mesh::mlir::sdy::MeshAttrMesh sumbu dan daftar perangkat

sdy.named_computation (sdy::NamedComputationOp)

Operasi komputasi bernama

Sintaksis:

operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
              (`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
              (`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:` functional-type($operands, results)

Mengelompokkan komputasi, yaitu blok operasi, dan memberinya nama. Propagasi akan mengalir masuk/keluar dari region seolah-olah semuanya telah disisipkan.

Hal ini dapat digunakan untuk menangani penyebaran melalui petunjuk panggilan ke fungsi lain. Setiap pengguna Shardy harus menulis kartu impor/ekspor yang mengonversi operasi panggilan mereka menjadi operasi sdy.named_computation, menduplikasi/menyalin isi fungsi yang dipanggil ke dalam isi named_computation.

Jenis setiap argumen blok dan nilai yang ditampilkan di region harus sama dengan jenis operand dan jenis hasil op.

Contoh:

%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

Sifat: IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock

Antarmuka: ConditionallySpeculatable, InferTypeOpInterface, ShardableDataFlowOpInterface

Atribut:

AtributJenis MLIRDeskripsi
name::mlir::StringAttratribut string
in_shardings::mlir::sdy::TensorShardingPerValueAttrSharding tensor per operand/hasil operasi
out_shardings::mlir::sdy::TensorShardingPerValueAttrSharding tensor per operand/hasil operasi

Operand:

Operand Deskripsi
operands variadik dari jenis apa pun

Hasil:

Hasil Deskripsi
«unnamed» variadik dari jenis apa pun

sdy.propagation_barrier (sdy::PropagationBarrierOp)

Operasi penghalang propagasi

Sintaksis:

operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)

Operasi ini beroperasi seperti operasi identitas, yang menghasilkan nilai yang sama dengan yang diambil sebagai input. Namun, dalam hal propagasi, hal ini hanya akan memungkinkan propagasi mengalir melaluinya dalam arah tertentu.

Hal ini mencegah sharding disebarkan di antara penggunaan hasil operasi penghalang dan operandnya.

  • FORWARD berarti sharding hanya dapat mengalir dari operand ke hasil.
  • BACKWARD berarti sharding hanya dapat mengalir dari hasil ke operand.
  • NONE berarti tidak ada sharding yang dapat di-propagasi melalui operasi ini.
  • Tidak dapat menentukan BOTH, karena operasi ini akan menjadi redundan.

Sifat: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Antarmuka: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Efek: MemoryEffects::Effect{}

Atribut:

AtributJenis MLIRDeskripsi
allowed_direction::mlir::sdy::PropagationDirectionAttrenum arah penyebaran

Operand:

Operand Deskripsi
input tensor berperingkat dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result tensor berperingkat dari nilai jenis apa pun

sdy.reshard (sdy::ReshardOp)

Melakukan resharding tensor ke sharding yang berbeda

Sintaksis:

operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)

Melakukan resharding pada tensor input dengan sharding yang ditentukan, yang berbeda dari sharding tensor input yang ada.

ShardingConstraintOp dan ReshardOp melampirkan sharding ke tensor. Masa aktifnya adalah:

  1. Sebelum penyebaran sharding, ShardingConstraintOp ditambahkan oleh pengguna.
  2. Penyebaran sharding menggunakan ShardingConstraintOp. Tidak ada ShardingConstraintOp dalam hasil propagasi sharding. Sebagai gantinya, ReshardOp dapat ditambahkan jika diperlukan.
  3. Partisioner mengonversi ReshardOp menjadi operasi kolektif (atau operasi identitas). Tidak boleh ada ReshardOp dalam hasil partisi.

// TODO(b/331680067). Tambahkan pola kanonisasi untuk menghapus operasi // reshard yang redundan.

Sifat: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Antarmuka: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Efek: MemoryEffects::Effect{}

Atribut:

AtributJenis MLIRDeskripsi
sharding::mlir::sdy::TensorShardingAttrSharding tensor

Operand:

Operand Deskripsi
input tensor dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result tensor dari nilai jenis apa pun

sdy.return (sdy::ReturnOp)

Operasi sdy.return menghentikan region yang dilampirkan ke operasi berbasis region sdy dan operasi berbasis region Shardy lainnya. Fungsi ini bersifat variadik: fungsi ini menggunakan daftar nilai sebagai argumen yang jenisnya dapat berupa apa saja (tetapi dari jenis yang sama, misalnya AnyTensor) sehingga dapat digunakan kembali di berbagai tingkat stack IR Shardy.

Sintaksis:

operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?

Sifat: AlwaysSpeculatableImplTrait, Terminator

Antarmuka: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Efek: MemoryEffects::Effect{}

Operand:

Operand Deskripsi
results variadik dari jenis apa pun

sdy.sharding_constraint (sdy::ShardingConstraintOp)

Membatasi tensor ke sharding yang ditentukan

Sintaksis:

operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)

Melampirkan sharding ke tensor perantara (misalnya, hasil matmul) untuk menunjukkan bahwa ini adalah cara tensor tersebut, atau sebagian penggunaannya, harus di-sharding.

Jika sharding memiliki dimensi terbuka dan sumbu yang tidak dibatasi, berarti tensor dapat di-sharding lebih lanjut di sepanjang dimensi terbuka.

Operasi ini dapat:

  • Tidak memiliki penggunaan (tergantung) - yang berarti sharding yang dilampirkan adalah cara tensor input itu sendiri harus di-sharding.
  • Memiliki penggunaan - yang berarti sharding yang dilampirkan adalah cara penggunaan op batasan sharding harus di-sharding, sedangkan penggunaan lain dari tensor input mungkin memiliki sharding yang berbeda (jika tensor input tidak memiliki penggunaan lain, perilakunya sama dengan kasus tanpa penggunaan).

Ciri: SameOperandsAndResultType

Antarmuka: InferTypeOpInterface

Atribut:

AtributJenis MLIRDeskripsi
sharding::mlir::sdy::TensorShardingAttrSharding tensor

Operand:

Operand Deskripsi
input tensor dari nilai jenis apa pun

Hasil:

Hasil Deskripsi
result tensor dari nilai jenis apa pun

sdy.sharding_group (sdy::ShardingGroupOp)

Membatasi tensor dalam grup agar memiliki sharding yang sama.

Sintaksis:

operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)

Operasi ini menyediakan antarmuka untuk menetapkan tensor ke grup sharding ( grup tensor yang akan diberlakukan untuk memiliki sharding yang identik). Selama penyebaran, segera setelah satu elemen grup di-shard, semua anggota lainnya akan di-shard dengan cara yang sama persis. Operasi ini mengambil ID grup argumen dan tidak menampilkan hasil, tetapi mengubah representasi grup sharding internal untuk menambahkan tensor input ke grup dengan ID yang diberikan.

Antarmuka: InferTypeOpInterface

Atribut:

AtributJenis MLIRDeskripsi
group_id::mlir::IntegerAttrAtribut bilangan bulat tanpa tanda 64-bit

Operand:

Operand Deskripsi
input tensor berperingkat dari nilai jenis apa pun

Atribut

AllToAllParamAttr

Parameter all-to-all

Sintaksis:

#sdy.all_to_all_param<
  ::llvm::ArrayRef<AxisRefAttr>,   # axes
  int64_t,   # src_dim
  int64_t   # tgt_dim
>

Tupla yang berisi sumbu dan dimensi sumber/target untuk melakukan semua-ke-semua.

Parameter:

Parameter Jenis C++ Deskripsi
sumbu ::llvm::ArrayRef<AxisRefAttr> sumbu untuk melakukan semua-ke-semua
src_dim int64_t indeks dimensi sumber
tgt_dim int64_t indeks dimensi target

AlltoAllParamListAttr

Daftar parameter all-to-all

Sintaksis:

#sdy.all_to_all_param_list<
  ::llvm::ArrayRef<AllToAllParamAttr>   # value
>

Parameter:

Parameter Jenis C++ Deskripsi
nilai ::llvm::ArrayRef<AllToAllParamAttr>

AxisRefAttr

Referensi ke sumbu penuh atau sub-sumbu terpisah

Sintaksis:

#sdy.axis_ref<
  ::llvm::StringRef,   # name
  SubAxisInfoAttr   # sub_axis_info
>

Batasan:

  • name harus ada dalam MeshAttr yang terikat.
  • Jika ada, sub_axis_info harus memenuhi batasan SubAxisInfoAttr.

Parameter:

Parameter Jenis C++ Deskripsi
nama ::llvm::StringRef nama sumbu ini
sub_axis_info SubAxisInfoAttr info tambahan jika ini adalah sub-sumbu

AxisRefListAttr

Daftar referensi sumbu

Sintaksis:

#sdy.axis_ref_list<
  ::llvm::ArrayRef<AxisRefAttr>   # value
>

Batasan:

  • Elemen di value harus memenuhi batasan AxisRefAttr.
  • Tidak ada referensi sumbu atau sub-sumbu duplikat yang tumpang-tindih satu sama lain.
  • Tidak ada dua referensi sumbu yang berdekatan yang merupakan sub-sumbu berturut-turut dari sumbu penuh yang sama, yaitu, keduanya dapat digabungkan menjadi satu sub-sumbu atau sumbu penuh.

Parameter:

Parameter Jenis C++ Deskripsi
nilai ::llvm::ArrayRef<AxisRefAttr>

DimMappingAttr

Daftar indeks faktor untuk dimensi

Daftar kosong menunjukkan bahwa ini adalah pemetaan null (ini diuraikan/dicetak dengan *), yaitu dimensi tidak dipetakan ke faktor apa pun.

Batasan:

  • Ada minimal satu indeks faktor.
  • Indeks faktor harus dalam rentang [0, $factor_sizes).
  • Jika ada beberapa faktor, tidak ada satu pun yang dapat memiliki ukuran 1.
  • Tidak ada indeks faktor duplikat.

Parameter:

Parameter Jenis C++ Deskripsi
factor_indices ::llvm::ArrayRef<int64_t> faktor yang dipetakan ke dimensi ini

DimensionShardingAttr

Sharding dimensi

Daftar nama sumbu untuk membuat shard dimensi tensor dari utama ke minor, Boolean yang menunjukkan apakah dimensi dapat di-shard lebih lanjut, dan bilangan bulat opsional yang menunjukkan prioritas sharding dimensi ini, yang akan dipatuhi selama penyebaran sharding. Prioritas berasal dari anotasi sharding pengguna dan nilai yang lebih rendah menunjukkan prioritas yang lebih tinggi. Prioritas tertinggi diasumsikan jika prioritas tidak ada dalam anotasi.

Batasan:

  • Elemen di axes harus memenuhi batasan yang tercantum di AxisRefListAttr.
  • Jika sharding dimensi memiliki prioritas:
    • Prioritas lebih besar dari atau sama dengan 0.
    • Dimensi memiliki minimal satu sumbu jika ditutup.

Parameter:

Parameter Jenis C++ Deskripsi
sumbu ::llvm::ArrayRef<AxisRefAttr> referensi sumbu
is_closed bool apakah dimensi ini tidak dapat di-sharding lebih lanjut
prioritas std::optional<int64_t> prioritas yang digunakan selama propagasi berbasis prioritas pengguna

ListOfAxisRefListsAttr

Daftar daftar referensi sumbu

Sintaksis:

#sdy.list_of_axis_ref_lists<
  ::llvm::ArrayRef<AxisRefListAttr>   # value
>

Parameter:

Parameter Jenis C++ Deskripsi
nilai ::llvm::ArrayRef<AxisRefListAttr>

ManualAxesAttr

Daftar sumbu yang ManualComputationOp-nya manual

Sintaksis:

#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr>   # value
>

Parameter:

Parameter Jenis C++ Deskripsi
nilai ::llvm::ArrayRef<StringAttr>

MeshAttr

Mesh sumbu dan daftar perangkat

Sintaksis:

#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>,   # axes
  ::llvm::ArrayRef<int64_t>   # device_ids
>

Mesh adalah daftar sumbu dan daftar ID perangkat opsional yang menentukan urutan perangkat.

Jika daftar sumbu kosong, mesh memiliki sumbu implisit tanpa nama dengan ukuran 1. Dalam hal ini, jika daftar ID perangkat tidak diberikan, daftar ID perangkat implisit adalah [0]; jika daftar ID perangkat diberikan, daftar tersebut harus berisi satu bilangan bulat dari nilai non-negatif. Kami menyebutnya kasus sharding maksimal.

Untuk semua kasus sharding non-maksimum, jika daftar ID perangkat ditentukan, produk ukuran sumbu harus cocok dengan jumlah perangkat. Jika daftar ID perangkat tidak ditentukan, daftar ID perangkat implisit adalah iota(product(axes)). Untuk memudahkan, kami juga tidak mengizinkan penentuan daftar ID perangkat yang sama dengan iota(product(axes)); dalam hal ini, daftar ID perangkat tidak boleh ditentukan.

Berikut beberapa contoh mesh:

  • Mesh kosong mewakili mesh placeholder yang dapat diganti selama penyebaran: <[]>
  • Mesh dengan sumbu tanpa nama dan ID perangkat eksplisit, yang biasanya digunakan untuk merepresentasikan sharding maksimum: <[], device_ids=[3]>
  • Mesh dengan dua sumbu dan ID perangkat implisit iota(6): <["a"=2, "b"=3]>
  • Mesh dengan dua sumbu dan ID perangkat eksplisit yang menentukan urutan perangkat: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>

Batasan:

  • Elemen di axes tidak boleh memiliki nama duplikat.
  • Jika device_ids ditentukan:
    • Hasil ukuran sumbu harus sesuai dengan jumlah perangkat.
    • Semua elemennya tidak boleh negatif.
    • device_ids tidak boleh sama dengan iota(product(axis_sizes)).
    • device_ids yang diurutkan harus iota(product(axis_sizes)).

Parameter:

Parameter Jenis C++ Deskripsi
sumbu ::llvm::ArrayRef<MeshAxisAttr> sumbu mesh
device_ids ::llvm::ArrayRef<int64_t> pengurutan perangkat eksplisit atau ID perangkat maksimum

MeshAxisAttr

Sumbu bernama dalam mesh

Sintaksis:

#sdy.mesh_axis<
  ::llvm::StringRef,   # name
  int64_t   # size
>

Parameter:

Parameter Jenis C++ Deskripsi
nama ::llvm::StringRef nama
ukuran int64_t ukuran sumbu ini

OpShardingRuleAttr

Menentukan cara operasi dapat dipartisi.

Sintaksis:

#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>,   # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>,   # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>,   # result_mappings
  ::llvm::ArrayRef<int64_t>,   # reduction_factors
  ::llvm::ArrayRef<int64_t>,   # need_replication_factors
  ::llvm::ArrayRef<int64_t>,   # permutation_factors
  ::llvm::ArrayRef<int64_t>,   # blocked_propagation_factors
  bool   # is_custom_rule
>

Aturan sharding menentukan cara operasi dapat dipartisi sesuai dengan berbagai properti pada op - atribut apa pun, bentuk operand, bentuk hasil, dll. Misalnya:

%0 = stablehlo.add %arg0, %arg1 {
    sdy.sharding_rule = #sdy.op_sharding_rule<
        ([i, j],[i, j])->([i, j])
        {i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>

Perhatikan bahwa kami mengizinkan faktor dengan ukuran 1 meskipun tidak dapat di-shard, hal ini terutama untuk kelengkapan karena banyak operasi seperti operasi pointwise memiliki dimensi ukuran satu yang sesuai di seluruh operand dan hasil.

Jenis faktor:

  • reduction_factors berisi indeks faktor yang memerlukan pengurangan, seperti dimensi kontraksi dalam operasi titik.
  • need_replication_factors berisi indeks faktor yang memerlukan replika penuh, seperti dimensi yang diurutkan dalam operasi pengurutan.
  • permutation_factors berisi indeks faktor yang memerlukan permutasi kolektif jika di-shard, seperti dimensi padding dalam operasi pad.
  • Semua faktor lainnya dianggap sebagai faktor pass-through, yaitu faktor yang tidak memerlukan komunikasi apa pun jika di-shard dengan cara yang sama di semua tensor yang dipetakan ke faktor tersebut.

blocked_propagation_factors berisi faktor yang tidak mengizinkan penyebaran sharding. Ini ortogonal dengan jenis faktor. Yaitu, faktor penyebaran yang diblokir dapat berupa jenis faktor apa pun.

is_custom_rule menjelaskan apakah ini adalah aturan yang ditentukan oleh pengguna. Pengguna dapat menentukan aturan sharding untuk panggilan kustom mereka atau menimpa aturan sharding standar yang telah ditentukan sebelumnya untuk operasi standar. Aturan kustom selalu dipertahankan/tidak pernah dihapus.

Batasan:

  • Jumlah pemetaan operand/hasil harus cocok dengan jumlah operand/hasil operasi.
  • Ada minimal satu pemetaan (tidak dapat memiliki aturan untuk operasi tanpa operand/hasil).
  • Peringkat setiap TensorMappingAttr cocok dengan peringkat jenis tensor yang sesuai.
  • Untuk setiap grup faktor (reduction_factors, need_replication_factors, permutation_factors):
    • Elemen harus berada dalam rentang [0, $factor_sizes].
    • Tidak ada indeks faktor duplikat dalam setiap grup dan di seluruh grup.

Parameter:

Parameter Jenis C++ Deskripsi
factor_sizes ::llvm::ArrayRef<int64_t> ukuran semua faktor dalam aturan ini
operand_mappings ::llvm::ArrayRef<TensorMappingAttr> pemetaan operand
result_mappings ::llvm::ArrayRef<TensorMappingAttr> pemetaan hasil
reduction_factors ::llvm::ArrayRef<int64_t> faktor yang memerlukan pengurangan
need_replication_factors ::llvm::ArrayRef<int64_t> faktor yang memerlukan replikasi penuh
permutation_factors ::llvm::ArrayRef<int64_t> faktor yang memerlukan permutasi kolektif
blocked_propagation_factors ::llvm::ArrayRef<int64_t> faktor yang tidak menyebarkan sharding
is_custom_rule bool apakah aturan tersebut untuk stablehlo.custom_call

SubAxisInfoAttr

Info tentang cara sub-sumbu ini berasal dari sumbu penuh

Sintaksis:

#sdy.sub_axis_info<
  int64_t,   # pre_size
  int64_t   # size
>

Saat membagi sumbu penuh menjadi n sub-sumbu, sumbu akan dibentuk ulang menjadi [k_1,...,k_n], dan sub-sumbu ke-i dapat dinyatakan dengan produk semua ukuran sumbu di sebelah kirinya m=prod(k_1,...,k_(i-1)) (alias pra-ukuran) dan ukuran k_i. Oleh karena itu, atribut sub-axis-info menyimpan dua angka tersebut dan ditunjukkan sebagai berikut: (m)k untuk ukuran awal m dan ukuran k.

Batasan:

  • pre-size minimal 1.
  • size lebih besar dari 1.
  • pre-size harus membagi ukuran sumbu penuh, yaitu pre-size dan size membagi ukuran sumbu penuh, dan sub-sumbu tidak melebihi sumbu penuh.
  • Ukuran sub-sumbu tidak sama dengan ukuran sumbu penuh yang sesuai, dalam hal ini sumbu penuh harus digunakan.

Parameter:

Parameter Jenis C++ Deskripsi
pre_size int64_t produk ukuran sub-sumbu di sebelah kiri sub-sumbu ini
ukuran int64_t ukuran sub-sumbu ini

TensorMappingAttr

Pemetaan faktor untuk setiap dimensi tensor.

Sintaksis:

#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr>   # dim_mappings
>

Batasan:

  • Elemen di dim_mappings harus memenuhi batasan di DimMappingAttr.
  • Tidak ada indeks faktor duplikat di seluruh dimensi.

Parameter:

Parameter Jenis C++ Deskripsi
dim_mappings ::llvm::ArrayRef<DimMappingAttr> pemetaan dimensi

TensorShardingAttr

Sharding tensor

Sintaksis:

#sdy.sharding<
  ::mlir::Attribute,   # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>,   # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr>   # replicated_axes
>

Sharding tensor terikat ke mesh tertentu, dan hanya dapat mereferensikan nama sumbu dari mesh tersebut. Sharding dimensi memberi tahu kita untuk setiap dimensi tensor, di sepanjang sumbu (atau sub-sumbu) yang di-sharding dari utama ke minor. Semua sumbu lain yang tidak melakukan shard dimensi direplikasi secara implisit atau secara eksplisit (jika muncul dalam daftar sumbu yang direplikasi).

Mesh yang menjadi tempat sharding ini dapat ditentukan oleh nama simbol, yang mereferensikan simbol MeshOp yang sesuai, atau MeshAttr yang disisipkan.

Batasan:

  • Elemen di dim_shardings harus memenuhi batasan yang tercantum di DimensionShardingAttr.
  • Elemen di replicated_axes harus memenuhi batasan yang tercantum di AxisRefListAttr.
  • Jika jenis tensor yang sesuai bukan ShapedType, sharding harus memiliki peringkat 0 dan tidak ada sumbu yang direplikasi.
  • Tensor harus memiliki peringkat.
  • Jumlah sharding dimensi sama dengan pangkat tensor.
  • Dimensi berukuran 0 tidak di-shard.
  • Item di replicated_axes diurutkan berdasarkan mesh_or_ref (lihat AxisRefAttr::getMeshComparator).

Parameter:

Parameter Jenis C++ Deskripsi
mesh_or_ref ::mlir::Attribute atribut mesh atau atribut referensi simbol mesh datar
dim_shardings ::llvm::ArrayRef<DimensionShardingAttr> sharding dimensi
replicated_axes ::llvm::ArrayRef<AxisRefAttr> referensi sumbu

TensorShardingPerValueAttr

Sharding tensor per operand/hasil operasi

Sintaksis:

#sdy.sharding_per_value<
  ::llvm::ArrayRef<TensorShardingAttr>   # shardings
>

Daftar TensorShardingAttr, satu untuk setiap operand/hasil operasi.

Batasan:

  • Elemen di shardings harus memenuhi batasan TensorShardingAttr.

Parameter:

Parameter Jenis C++ Deskripsi
sharding ::llvm::ArrayRef<TensorShardingAttr> sharding per nilai

Enum

PropagationDirection

Enum arah penyebaran

Kasus:

Simbol Nilai String
TIDAK ADA 0 TIDAK ADA
MAJU 1 MAJU
KEMBALI 2 KEMBALI
KEDUANYA 3 KEDUANYA