Latar belakang
Kami mengasumsikan pembaca sudah memahami setidaknya dasar-dasar representasi sharding, yang menjelaskan cara sharding tensor dapat dinyatakan dalam Shardy. Dokumen ini menunjukkan cara representasi sharding dapat digunakan dalam program, misalnya untuk melampirkan sharding ke tensor tertentu dari program.
Penyebaran sharding adalah proses menentukan sharding untuk setiap tensor dalam program yang diberikan batasan sharding untuk sebagian tensor. API compiler Shardy mengekspos beberapa cara untuk memengaruhi/mengontrol penyebaran sharding. Selain itu, fitur ini memungkinkan pengguna menyisipkan komputasi yang di-shard secara manual ke dalam program mereka.
Tujuan
Dokumen ini menjelaskan desain komponen API tersebut di Shardy dan menjelaskan perilaku dan invariannya. Perhatikan bahwa meskipun API ini digunakan untuk mengontrol penyebaran sharding, dokumen ini TIDAK akan membahas apa pun tentang perilaku penyebaran atau cara mendesainnya.
Ringkasan
Sharding input/output - lampirkan sharding ke input atau output fungsi utama, untuk menunjukkan bahwa ini adalah cara tensor input/output harus di-sharding saat diberikan ke/ditampilkan dari fungsi.
Batasan Sharding - lampirkan sharding ke tensor perantara (misalnya, hasil matmul) untuk menunjukkan bahwa ini adalah cara tensor tersebut, atau sebagian penggunaannya, harus di-sharding.
Grup Sharding - mengelompokkan beberapa tensor berdasarkan ID untuk menunjukkan bahwa tensor tersebut harus di-sharding dengan cara yang sama.
Penghitungan Manual - mencakup sub-komputasi yang dipartisi secara manual menggunakan subset sumbu mesh, dengan sharding di sepanjang sumbu manual tersebut ditentukan untuk semua input dan output, dan di dalam sub-komputasi, jenis tensor bersifat lokal sehubungan dengan sharding tersebut.
Desain Terperinci
Sharding input/output
Mengizinkan pengguna menentukan sharding untuk input dan output fungsi utama.
Di MLIR, atribut dapat dilampirkan ke argumen dan hasil fungsi, sehingga pengguna dapat melampirkan atribut sharding ke fungsi dengan cara ini.
Contoh:
@mesh_xy = <["x"=2, "y"=2]>
// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
{sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
%arg1: tensor<8x16xf32>)
-> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
...
}
Batasan Sharding
Memungkinkan pengguna melampirkan sharding ke tensor perantara dalam program mereka, yang memberi tahu partisi bahwa ini adalah cara tensor tersebut, atau sebagian dari penggunaannya, harus di-sharding.
Ini adalah operasi MLIR yang menggunakan tensor sebagai input, dan memiliki atribut sharding yang terpasang padanya. Operasi dapat:
- Tidak memiliki penggunaan (tergantung) - yang berarti sharding yang dilampirkan adalah cara tensor itu sendiri harus di-sharding.
- Memiliki penggunaan - yang berarti sharding yang dilampirkan adalah cara penggunaan op batasan sharding harus di-sharding, sedangkan penggunaan lain dari tensor input mungkin memiliki sharding yang berbeda (jika tensor input tidak memiliki penggunaan lain, perilakunya sama dengan kasus tidak ada penggunaan). Propagasi akan menentukan sharding tensor itu sendiri dan melakukan sharding ulang jika diperlukan.
Tabel ini dapat memiliki sharding dimensi terbuka, yang berarti operand dapat di-sharding lebih lanjut di sepanjang sumbu yang tersedia.
@mesh_xy = <["x"=2, "y"=2]>
%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>
Grup Sharding
Jika tidak ada dependensi data atau tidak ada dependensi data yang kuat antara dua tensor atau lebih, meskipun pengguna memiliki pengetahuan bahwa tensor tersebut harus dipartisi dengan cara yang sama atau serupa, Shardy API menawarkan cara untuk menentukan hubungan ini. Hal ini memberi pengguna kebebasan untuk menentukan secara eksplisit bahwa tensor harus dipartisi satu sama lain.
Untuk mencapai hal ini, kami memperkenalkan konsep grup shard, dengan setiap grup berisi sejumlah petunjuk yang dikaitkan dengan ID grup shard yang sama. Grup sharding menerapkan sharding dalam grup yang sama agar sama.
Misalnya, dalam program pengguna hipotetis seperti yang ditunjukkan di bawah, kita ingin membagi output program persis sama dengan input program saat tidak ada dependensi data di antara keduanya.
Jika kita menjalankan program ini, propagasi sharding tidak akan dapat menyimpulkan
sharding tensor %1
dan %2
, dan keduanya akan direplikasi.
Namun, dengan melampirkan atribut shard_group
yang menyatakan bahwa input %0
dan output %2
berada dalam shard_group
yang sama, kita mengizinkan sharding
@mesh_xy,
[{"x"},{"y"}]>
untuk disebarkan dari input %0
ke output
%2
, dan pada gilirannya ke seluruh grafik, yang disiarkan konstan %1
di sini. Kita dapat menetapkan nilai ke grup dengan
operasi sdy.sharding_group
.
@mesh_xy = <["x"=2, "y"=2]>
module @"jit_zeros_like" {
func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
%0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
%1 = stablehlo.constant dense<0> : tensor<8x2xi64>
%2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
return %2 : tensor<8x2xi64>
}
}
Dalam contoh sederhana di atas, sebagai alternatif, kita dapat secara eksplisit menentukan sharding yang sama pada output sebagai input, yang akan mencapai efek yang sama, karena kita sudah mengetahui shard yang ingin ditetapkan ke input terlebih dahulu tetapi dalam kasus yang lebih realistis, kita menggunakan shard agar sharding beberapa tensor tetap sinkron tanpa harus mengetahui sharding untuk salah satu darinya, sementara Shardy akan menangani sisanya dan menemukan sharding terbaik untuk ditetapkan ke mereka.
Komputasi Manual
Pengguna mungkin menginginkan kontrol eksplisit atas cara bagian komputasi mereka dipartisi, dan kolektif apa yang digunakan. Misalnya, beberapa pengguna ingin menerapkan matmul kolektif secara manual (dari API frontend) daripada menunda ke compiler. Kami menyediakan Manual Computation API yang memungkinkan mereka melakukannya.
Ini adalah operasi MLIR dengan satu region untuk subkomputasi manual. Pengguna akan menentukan sharding input/output ke subkomputasi ini menggunakan subkumpulan (termasuk mungkin semua) sumbu mesh. Sub-komputasi akan bersifat lokal/manual sehubungan dengan sumbu mesh yang ditentukan (alias sumbu manual), dan global/tidak dipartisi sehubungan dengan sumbu yang tidak ditentukan (alias sumbu bebas). Sub-komputasi dapat di-shard lebih lanjut di sepanjang sumbu bebas selama propagasi dengan cara yang sama seperti komputasi di luar operasi ini.
Contoh:
@mesh_name = <["data"=2, "model"=2]>
%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
out_shardings=[<@mesh_name, [{"data"}, {?}]>]
manual_axes={"data"}
(%arg1: tensor<8x32xf32>) {
// body
return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
Invarian
Semua
in_shardings
,out_shardings
, danmanual_axes
harus merujuk ke mesh yang sama.manual_axes
diurutkan berdasarkan mesh.manual_axes
harus digunakan secara eksplisit dalam semua sharding masuk/keluar, yaitu, untuk setiap sharding, semua sumbu manual harus mengelompokkan dimensi atau direplikasi secara eksplisit.Jika sumbu bebas (sumbu mesh apa pun yang tidak ada di
manual_axes
) ada di salah satu sharding masuk/keluar, sumbu tersebut harus lebih kecil dari sumbu manual apa pun dalam sharding dimensi yang sama (dalam contoh di atas, sharding dimensi{"model", "data"}
akan tidak valid).Region/isi komputasi adalah komputasi lokal (misalnya, termasuk kolektif yang ditentukan pengguna). Ini harus bersifat lokal sehubungan dengan sharding masuk/keluar di sepanjang sumbu manual (lihat catatan di atas).
Menempatkan komputasi manual dalam bertingkat
Anda dapat menyusun bertingkat beberapa komputasi manual satu sama lain selama setiap komputasi beroperasi pada kumpulan sumbu manualnya yang unik.