Sfondo
Presumiamo che i lettori abbiano familiarità almeno con le nozioni di base della rappresentazione dello sharding, che descrive come lo sharding di un tensore può essere espresso in Shardy. Questo documento mostra come le rappresentazioni di suddivisione possono essere utilizzate in un programma, ad esempio per associare una suddivisione a un tensore specifico del programma.
La propagazione dello sharding è il processo di decisione di uno sharding per ogni tensore in un programma, dati i vincoli di sharding per un sottoinsieme di tensori. L'API compiler di Shardy offre diversi modi per influenzare/controllare la propagazione dello sharding. Inoltre, consente agli utenti di inserire calcoli suddivisi manualmente nei propri programmi.
Obiettivo
Questo documento descrive il design di questi componenti dell'API in Shardy e ne spiega il comportamento e le invarianti. Tieni presente che, sebbene questa API venga utilizzata per controllare la propagazione dello sharding, in questo documento NON verrà discusso il comportamento della propagazione né la sua progettazione.
Panoramica
Sharding di input/output: associa uno sharding a un input o un output della funzione principale per indicare in che modo il tensore di input/output deve essere suddiviso quando viene assegnato alla funzione o restituito dalla funzione.
Limite di suddivisione: associa una suddivisione a un tensore intermedio (ad es. il risultato di una moltiplicazione matriciale) per indicare come deve essere suddiviso il tensore o un sottoinsieme dei relativi utilizzi.
Gruppo di suddivisione in parti: raggruppa più tensori in base a un ID per indicare che devono essere suddivisi nello stesso modo.
Calcolo manuale: racchiude un sottocalcolo partizionato manualmente utilizzando un sottoinsieme di assi del mesh, dove gli sharding lungo questi assi manuali sono specificati per tutti gli input e le uscite, e all'interno del sottocalcolo i tipi di tensore sono locali rispetto a questi sharding.
Progetto dettagliato
Sharding di input/output
Consente agli utenti di specificare uno sharding per gli input e le uscite della funzione principale.
In MLIR, gli attributi possono essere associati agli argomenti e ai risultati delle funzioni, pertanto gli utenti possono associare gli attributi di sharding alla funzione in questo modo.
Ad esempio:
@mesh_xy = <["x"=2, "y"=2]>
// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
{sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
%arg1: tensor<8x16xf32>)
-> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
...
}
Vincolo di suddivisione
Consente agli utenti di associare uno sharding a un tensore intermedio nel programma, indicando al partizionatore in che modo deve essere suddiviso il tensore o un sottoinsieme dei suoi usi.
Si tratta di un'operazione MLIR che prende il tensore come input e ha un attributo di suddivisione associato. L'operazione può:
- Non avere utilizzi (non collegati), il che significa che lo sharding allegato è il modo in cui deve essere suddiviso il tensore stesso.
- Hanno utilizzi: significa che lo sharding allegato è il modo in cui devono essere suddivisi gli utilizzi dell'operazione di vincolo di sharding, mentre altri utilizzi del tensore di input potrebbero avere uno sharding diverso (se il tensore di input non ha altri utilizzi, il comportamento è lo stesso del caso senza utilizzi). La propagazione determinerà lo sharding del tensore stesso e lo sharding se necessario.
Può avere suddivisioni delle dimensioni aperte, il che significa che l'operando può essere suddiviso ulteriormente in base agli assi disponibili.
@mesh_xy = <["x"=2, "y"=2]>
%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>
Gruppo di shard
Nei casi in cui non esistono dipendenze di dati o dipendenze di dati forti tra due o più tensori, mentre gli utenti sanno che questi tensori devono essere suddivisi nello stesso modo o in modo simile, l'API Shardy offre un modo per specificare questa relazione. In questo modo, gli utenti hanno la libertà di specificare esplicitamente che i tensori devono essere partizionati in modo uguale.
Per farlo, introduciamo il concetto di gruppi di shard, in cui ogni gruppo contiene un numero qualsiasi di istruzioni associate allo stesso ID gruppo di shard. I gruppi di suddivisione in parti forzano la suddivisione in parti all'interno dello stesso gruppo.
Ad esempio, in un programma utente ipotetico come quello mostrato di seguito, vogliamo suddividere l'output del programma esattamente come l'input del programma senza che esistano dipendenze di dati tra i due.
Se eseguiamo questo programma, la propagazione dello sharding non potrà dedurre lo sharding dei tensori %1
e %2
, che finiranno per essere replicati.
Tuttavia, associando un attributo shard_group
che indica che l'input %0
e l'output %2
si trovano nello stesso shard_group
, consentiamo la propagazione del frammento
@mesh_xy,
[{"x"},{"y"}]>
dall'input %0
all'output
%2
e, a sua volta, al resto del grafo, che viene trasmesso come costante %1
qui. Possiamo assegnare un valore a un gruppo con l'operazione sdy.sharding_group
.
@mesh_xy = <["x"=2, "y"=2]>
module @"jit_zeros_like" {
func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
%0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
%1 = stablehlo.constant dense<0> : tensor<8x2xi64>
%2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
return %2 : tensor<8x2xi64>
}
}
In questo semplice esempio riportato sopra, in alternativa, avremmo potuto specificare esplicitamente lo stesso suddivisione in output e in input, ottenendo lo stesso effetto, poiché sappiamo già quale shard vogliamo assegnare all'input in anticipo, ma in casi più realistici, utilizziamo lo shard per mantenere sincronizzato lo sharding di più tensori senza necessariamente conoscere lo sharding per nessuno di essi, mentre Shardy si occuperà del resto e troverà lo sharding migliore da assegnare.
Calcolo manuale
Gli utenti potrebbero voler avere un controllo esplicito sulla modalità di suddivisione delle parti del calcolo e sui collettivi utilizzati. Ad esempio, alcuni utenti vogliono applicare manualmente la moltiplicazione matriciale collettiva (dall'API frontend) anziché posticipare al compilatore. Forniamo un'API di calcolo manuale che consente di eseguire questa operazione.
Questa è l'operazione MLIR con una singola regione per il sottocalcolo manuale. Gli utenti specificano gli shard di input/output per questo sottocalcolo utilizzando un sottoinsieme (possibilmente tutti) degli assi della maglia. Il sottocalcolo sarà locale/manuale rispetto agli assi del mesh specificati (ovvero assi manuali) e globale/non partizionato rispetto a quelli non specificati (ovvero assi liberi). Il calcolo parziale può essere ulteriormente suddiviso lungo gli assi liberi durante la propagazione, come avviene per il calcolo al di fuori di questa operazione.
Ad esempio:
@mesh_name = <["data"=2, "model"=2]>
%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
out_shardings=[<@mesh_name, [{"data"}, {?}]>]
manual_axes={"data"}
(%arg1: tensor<8x32xf32>) {
// body
return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
Invariati
Tutti i valori
in_shardings
,out_shardings
emanual_axes
devono fare riferimento alla stessa maglia.manual_axes
è ordinato rispetto alla mesh.manual_axes
deve essere utilizzato esplicitamente in tutti gli shard in/out, ovvero per ogni shard tutti gli assi manuali devono suddividere una dimensione o essere replicati esplicitamente.Se in uno dei suddivisioni in parti in/out esiste un asse libero (qualsiasi asse del mesh non in
manual_axes
), deve essere secondario rispetto a qualsiasi asse manuale nello stesso suddivisione in parti della dimensione (nell'esempio precedente, un suddivisione in parti della dimensione{"model", "data"}
non sarebbe valida).La regione/il corpo del calcolo è il calcolo locale (ad es. inclusi i collettivi specificati dall'utente). Deve essere locale rispetto allo sharding in/out lungo gli assi manuali (vedi nota sopra).
Nidificazione di calcoli manuali
Puoi nidificare più calcoli manuali l'uno dentro l'altro, a condizione che ciascuno operi su un proprio insieme univoco di assi manuali.