Le dialecte Shardy (SDY) définit une représentation de fractionnement de tenseur basée sur les axes et des composants d'API supplémentaires pour associer des fractionnements à des tenseurs.
Opérations
sdy.all_gather
(sdy::AllGatherOp)
Effectue une communication all-gather le long des axes.
Syntaxe :
operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Réunit des segments d'un tenseur le long des axes spécifiés dans gathering_axes
.
gathering_axes
est une liste de listes d'axes. La liste externe dépasse les dimensions du tenseur. Chaque liste interne spécifie les axes sur lesquels un rassemblement distinct doit être effectué sur la dimension correspondante. Il sera appliqué au fractionnement de l'opérande (tensor
) pour obtenir le fractionnement du résultat (out_sharding
).
Notez que out_sharding
n'est pas utilisé pour déterminer le fractionnement du résultat. Au lieu de cela, le fractionnement du résultat est déterminé par le fractionnement de l'opérande et de gathering_axes
, et out_sharding
doit correspondre à ce fractionnement inféré.
Exemple :
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>
Contraintes :
- doit respecter les contraintes listées dans
Sdy_CollectiveOpInterface
. - Les éléments de
gathering_axes
doivent respecter les contraintes listées dansAxisRefListAttr
. - L'application de
gathering_axes
au fractionnement de l'opérande génèreout_sharding
.
Caractéristiques: SameOperandsAndResultType
Interfaces: InferTypeOpInterface
, Sdy_CollectiveOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
gathering_axes | ::mlir::sdy::ListOfAxisRefListsAttr | Liste des listes de références d'axe |
out_sharding | ::mlir::sdy::TensorShardingAttr | Segmentation de tenseurs |
Opérandes:
Opérande | Description |
---|---|
tensor |
tenseur de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
tenseur de valeurs de n'importe quel type |
sdy.all_reduce
(sdy::AllReduceOp)
Effectuer une communication all-reduce le long des axes
Syntaxe :
operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Réduit les segments d'un tenseur le long des axes spécifiés dans reduction_axes
.
L'ordre de reduction_axes
n'a pas d'importance pour le résultat, mais il peut affecter l'ordre des groupes de réplicas correspondants.
Contraintes :
- doit respecter les contraintes listées dans
Sdy_CollectiveOpInterface
. reduction_axes
doit respecter les contraintes listées dansAxisRefListAttr
.reduction_axes
ne doit pas chevaucher les axes de fractionnement de l'opérande.
Caractéristiques: SameOperandsAndResultType
Interfaces: CollectiveOpInterface
, InferTypeOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
reduction_axes | ::mlir::sdy::AxisRefListAttr | Liste des références d'axe |
out_sharding | ::mlir::sdy::TensorShardingAttr | Segmentation de tenseurs |
Opérandes:
Opérande | Description |
---|---|
tensor |
tenseur de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
tenseur de valeurs de n'importe quel type |
sdy.all_slice
(sdy::AllSliceOp)
Effectue une opération de tranche dynamique le long des axes.
Syntaxe :
operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Tranche des segments d'un tenseur selon les axes spécifiés dans slicing_axes
. Il existe une dualité algébrique entre sdy.all_slice
et sdy.all_gather
.
slicing_axes
est une liste de listes d'axes. La liste externe dépasse les dimensions du tenseur. Chaque liste interne spécifie les axes sur lesquels une tranche doit être effectuée pour la dimension correspondante. Il sera appliqué au fractionnement de l'opérande (tensor
) pour obtenir le fractionnement du résultat (out_sharding
).
Notez que out_sharding
n'est pas utilisé pour déterminer le fractionnement du résultat. Au lieu de cela, le fractionnement du résultat est déterminé par le fractionnement de l'opérande et de slicing_axes
, et out_sharding
doit correspondre à ce fractionnement inféré.
Exemple :
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>
Contraintes :
- Les éléments de
slicing_axes
doivent respecter les contraintes listées dansAxisRefListAttr
. - doit respecter les contraintes listées dans
Sdy_CollectiveOpInterface
. - L'application de
slicing_axes
au fractionnement de l'opérande génèreout_sharding
.
Caractéristiques: SameOperandsAndResultType
Interfaces: CollectiveOpInterface
, InferTypeOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
slicing_axes | ::mlir::sdy::ListOfAxisRefListsAttr | Liste des listes de références d'axe |
out_sharding | ::mlir::sdy::TensorShardingAttr | Segmentation de tenseurs |
Opérandes:
Opérande | Description |
---|---|
tensor |
tenseur de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
tenseur de valeurs de n'importe quel type |
sdy.all_to_all
(sdy::AllToAllOp)
Effectue une communication de type "tout-à-tous" le long des axes.
Syntaxe :
operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Pour chaque tuple (axes, src_dim, tgt_dim) de la liste des paramètres, cette opération découpe des segments d'un tenseur selon la dimension tgt_dim
et les axes spécifiés dans axes
, les disperse selon les axes et les concatène selon la dimension src_dim
.
Cette opération est essentiellement une combinaison d'une opération de collecte complète sur src_dim
et axes
, suivie d'une tranche complète sur tgt_dim
et axes
, c'est-à-dire qu'un suffixe de la dimension de fractionnement des axes src_dim
sur le tenseur d'entrée est ajouté à la dimension de fractionnement des axes tgt_dim
sur le tenseur de sortie.
La distribution de tous les nœuds vers tous les nœuds sera appliquée au fractionnement de l'opérande (tensor
) pour obtenir le fractionnement du résultat (out_sharding
).
Notez que out_sharding
n'est pas utilisé pour déterminer le fractionnement du résultat. À la place, le fractionnement du résultat est déterminé par le fractionnement de l'opérande, src_dim
, tgt_dim
et axes
, et out_sharding
doit correspondre à ce fractionnement inféré.
Exemple :
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>
Contraintes :
- doit respecter les contraintes listées dans
Sdy_CollectiveOpInterface
. - La liste des paramètres est obligatoire.
- Pour chaque paramètre de
params
:- Les éléments de
axes
doivent respecter les contraintes deAxisRefAttr
. src_dim
ettgt_dim
doivent être des dimensions valides (non négatives et inférieures au rang du tenseur).- Les
src_dim
outgt_dim
doivent être uniques pour tous les paramètres. src_dim
doit être trié par ordre croissant pour tous les paramètres.
- Les éléments de
- Déplacer
axes
desrc_dim
verstgt_dim
dans le fractionnement des opérandes génèreout_sharding
.
Caractéristiques: SameOperandsAndResultType
Interfaces: InferTypeOpInterface
, Sdy_CollectiveOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
params | ::mlir::sdy::AlltoAllParamListAttr | Liste des paramètres de type "tout-à-tous" |
out_sharding | ::mlir::sdy::TensorShardingAttr | Segmentation de tenseurs |
Opérandes:
Opérande | Description |
---|---|
tensor |
tenseur de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
tenseur de valeurs de n'importe quel type |
sdy.collective_permute
(sdy::CollectivePermuteOp)
Effectue une communication de permutation collective pour remplacer les axes.
Syntaxe :
operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)
Envoie un segment du tenseur d'entrée de chaque appareil à un autre pour réorganiser/remplacer les axes qui divisent le tenseur.
Une permutation collective peut transformer le fractionnement d'entrée de sorte que chaque dimension doit être fractionnée comme avant, c'est-à-dire qu'elle doit être fractionnée selon des axes dont le produit des tailles correspond à celui des axes qui fractionnaient auparavant le tenseur.
Cela permet de réorganiser les axes dans une seule dimension ou dans différentes dimensions, et d'échanger des axes partitionnés contre des axes répliqués.
Dans l'exemple ci-dessous, la taille du tenseur fractionné est tensor<1x4x2xf32>
, et elle est conservée par la permutation collective.
Exemple :
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>
Contraintes :
- doit respecter les contraintes listées dans
Sdy_CollectiveOpInterface
. - Si le fractionnement des entrées et des sorties comporte des maillages différents, ces maillages doivent avoir exactement les mêmes axes et un ordre différent des ID d'appareil.
- Pour chaque dimension, le produit des tailles d'axe de fractionnement dans
out_sharding
doit correspondre à celui du fractionnement de la dimension d'opérande correspondante.
Caractéristiques: SameOperandsAndResultType
Interfaces: CollectiveOpInterface
, InferTypeOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
out_sharding | ::mlir::sdy::TensorShardingAttr | Segmentation de tenseurs |
Opérandes:
Opérande | Description |
---|---|
tensor |
tenseur de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
tenseur de valeurs de n'importe quel type |
sdy.constant
(sdy::ConstantOp)
Fonctionnement constant
Génère un tenseur output
à partir d'une constante value
.
Consultez la page https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant.
Exemple :
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
Caractéristiques: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Effets: MemoryEffects::Effect{}
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
value | ::mlir::ElementsAttr | Attribut de vecteur/tenseur constant |
Résultats :
Résultat | Description |
---|---|
output |
Tensor de forme statique de valeurs de tout type |
sdy.data_flow_edge
(sdy::DataFlowEdgeOp)
Opération de bord de flux de données.
Syntaxe :
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
Un bord de flux de données d'une opération X définit un pont entre un ensemble de sources (chacune est un opérande de X ou un opérande du terminateur de bloc de X) et un ensemble de cibles (chacune est un résultat de X ou un argument de bloc de X), de sorte que toutes les sources et cibles doivent être fractionnées de la même manière.
Une opération peut comporter plusieurs arêtes de flux de données qui sont orthogonales les unes aux autres.
Exemple :
y_0, ..., y_n = while (x_0, ..., x_n)
((pred_arg_0,... , pred_arg_n) { ... })
((body_arg_0,..., body_arg_n) {
...
return return_value_0, ..., return_value_n
})
Cette opération while comporte n arcs de flux de données, les arcs de flux de données i se situant entre les sources x_i
, return_value_i
et les cibles y_i
, pred_arg_i
, body_arg_i
.
Un sdy.data_flow_edge
prend en entrée le propriétaire d'un arc (il peut s'agir de l'une des cibles, mais de préférence d'un résultat d'opération plutôt que d'un argument de bloc), qui ne doit pas avoir d'autres utilisations. Cette opération n'est pas pure, car elle peut accepter une entrée qui n'avait initialement aucune utilité.
sdy.data_flow_edge
contient également un fractionnement facultatif pour toutes les cibles du bord. Ce fractionnement doit être mis à jour au lieu du fractionnement des cibles (s'il peut être associé) lors de la propagation. Cette fonctionnalité est utile lorsqu'une opération comporte de nombreux bords, car il est beaucoup plus efficace de:
- se propagent séparément dans chaque arête.
- mettez à jour le fractionnement de chaque arête séparément au lieu de toutes les cibles à la fois (par exemple, une opération ne comporte qu'un seul
TensorShardingPerValueAttr
immuable pour les fractionnements de résultats). - Ajoutez chaque arête à la liste de travail séparément lorsque le fractionnement d'une source a changé.
La propagation propage les fractionnements entre toutes les sources et cibles d'un sdy.data_flow_edge
comme s'il s'agissait d'une opération régulière avec les sources comme opérandes et les cibles comme résultats, et une identité sdy.op_sharding_rule
. Cela signifie que la propagation avant va des sources aux cibles, et la rétropropagation des cibles aux sources.
Nous n'autorisons pas l'entrée d'un sdy.data_flow_edge
à être définie par une opération SdyDialect
. Nous pouvons donc supposer qu'elle est définie par une opération dont l'attribut sdy.sharding
n'est pas enregistré.
Caractéristiques: SameOperandsAndResultType
Interfaces: InferTypeOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Segmentation de tenseurs |
Opérandes:
Opérande | Description |
---|---|
input |
de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
de valeurs de n'importe quel type |
sdy.manual_computation
(sdy::ManualComputationOp)
Opération de parallélisme multi-appareil avec des collectifs manuels
Syntaxe :
operation ::= `sdy.manual_computation` `(`operands`)`
`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
`manual_axes````=```$manual_axes
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:`
functional-type(operands, results)
Accédez à une région écrite en termes de code local par appareil avec des collectifs explicites, où les formes logiques correspondent aux formes de tampon physiques locales par appareil et les collectifs correspondent exactement à la communication physique entre les appareils.
Le corps est local par rapport aux axes manuels. La propagation se produira à travers le corps sur toutes les axes libres (celles qui ne figurent pas dans la liste "manual_axes").
Contraintes :
- Les éléments de
in_shardings
etout_shardings
doivent respecter les contraintes listées dansTensorShardingAttr
. - Le nombre d'entrées/sorties de tenseur globales et locales de la région d'opération doit correspondre.
- Les axes manuels doivent précéder les axes libres dans chaque partitionnement de dimension.
- Les axes manuels ne peuvent pas ajouter de marge intérieure. Plus précisément, la taille de la dimension doit être divisible par la taille des axes manuels correspondants.
- Les formes globales et locales des arguments/résultats des régions d'opération doivent correspondre.
- Aucune axe manuel n'est divisé.
Caractéristiques: IsolatedFromAbove
, RecursiveMemoryEffects
, SingleBlockImplicitTerminator<ReturnOp>
, SingleBlock
Interfaces: ShardableDataFlowOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Segmentation de tensor par opérande/résultat d'une opération |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Segmentation de tensor par opérande/résultat d'une opération |
manual_axes | ::mlir::sdy::ManualAxesAttr | Liste des axes pour lesquels une opération de calcul manuel est manuelle |
Opérandes:
Opérande | Description |
---|---|
tensors |
Variadique de tensor classé de valeurs de tout type |
Résultats :
Résultat | Description |
---|---|
results |
Variadique de tensor classé de valeurs de tout type |
sdy.mesh
(sdy::MeshOp)
Maillage nommé
Syntaxe :
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
Définit un nouveau maillage nommé. Tous les maillages d'un module doivent avoir le même nombre d'appareils (à l'exception des maillages avec un seul device_id).
Le maillage est une opération Symbol
qui apparaît dans le SymbolTable
du module et peut être référencé par son name
.
Caractéristiques: HasParent<ModuleOp>
Interfaces: Symbol
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
sym_name | ::mlir::StringAttr | attribut de chaîne |
mesh | ::mlir::sdy::MeshAttr | Maillage des axes et liste des appareils |
sdy.named_computation
(sdy::NamedComputationOp)
Opération de calcul nommée
Syntaxe :
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
(`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
(`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
Regroupe un calcul, c'est-à-dire un bloc d'opérations, et lui attribue un nom. La propagation s'effectue dans/hors de la région comme si tout était intégré.
Vous pouvez l'utiliser pour gérer la propagation via des instructions d'appel vers d'autres fonctions. Tous les utilisateurs de Shardy doivent écrire un passage d'importation/d'exportation qui convertit leurs opérations d'appel en opérations sdy.named_computation
, en dupliquant/copiant le corps de la fonction appelée dans le corps de la named_computation
.
Le type de chaque argument de bloc et des valeurs renvoyées dans la région doit être le même que le type des opérandes et le type de résultats de l'opération.
Exemple :
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
Caractéristiques: IsolatedFromAbove
, RecursiveMemoryEffects
, RecursivelySpeculatableImplTrait
, SingleBlockImplicitTerminator<ReturnOp>
et SingleBlock
Interfaces: ConditionallySpeculatable
, InferTypeOpInterface
, ShardableDataFlowOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
name | ::mlir::StringAttr | attribut de chaîne |
in_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Segmentation de tensor par opérande/résultat d'une opération |
out_shardings | ::mlir::sdy::TensorShardingPerValueAttr | Segmentation de tensor par opérande/résultat d'une opération |
Opérandes:
Opérande | Description |
---|---|
operands |
variadique de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
"unnamed" | variadique de n'importe quel type |
sdy.propagation_barrier
(sdy::PropagationBarrierOp)
Opération de barrière de propagation
Syntaxe :
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
Cette opération fonctionne comme une opération d'identité, en affichant la même valeur qu'elle a prise en entrée. Toutefois, en termes de propagation, cela ne permet qu'à la propagation de s'écouler dans une certaine direction.
Cela empêche la propagation des fractionnements entre les utilisations du résultat de l'opération de barrière et de son operand.
FORWARD
signifie que les fractionnements ne peuvent passer que de l'opérande au résultat.BACKWARD
signifie que les fractionnements ne peuvent passer que du résultat à l'opérande.NONE
signifie qu'aucun fractionnement ne peut se propager via cette opération.- Vous ne pouvez pas spécifier
BOTH
, car cette opération serait redondante.
Caractéristiques: AlwaysSpeculatableImplTrait
, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Effets: MemoryEffects::Effect{}
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
allowed_direction | ::mlir::sdy::PropagationDirectionAttr | énumération de la direction de propagation |
Opérandes:
Opérande | Description |
---|---|
input |
tenseur classé de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
tenseur classé de valeurs de n'importe quel type |
sdy.reshard
(sdy::ReshardOp)
Répartitionne un tenseur dans un autre schéma de partitionnement.
Syntaxe :
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
Répartitionne le tenseur d'entrée avec la répartition spécifiée, qui est différente de la répartition existante du tenseur d'entrée.
ShardingConstraintOp et ReshardOp associent un fractionnement à un tenseur. Leur durée de vie est la suivante:
- Avant la propagation du fractionnement, ShardingConstraintOp est ajouté par les utilisateurs.
- La propagation de la segmentation consomme ShardingConstraintOp. Aucun ShardingConstraintOp n'est indiqué dans les résultats de la propagation du sharding. À la place, ReshardOp peut être ajouté si nécessaire.
- Un partitionneur convertit une opération ReshardOp en opération collective (ou opération d'identité). Aucun ReshardOp ne doit figurer dans les résultats du partitionneur.
// TODO(b/331680067). Ajout d'un modèle de canonisation pour supprimer les opérations de reshard redondantes.
Caractéristiques: AlwaysSpeculatableImplTrait
, SameOperandsAndResultType
Interfaces: ConditionallySpeculatable
, InferTypeOpInterface
, NoMemoryEffect (MemoryEffectOpInterface)
Effets: MemoryEffects::Effect{}
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Segmentation de tenseurs |
Opérandes:
Opérande | Description |
---|---|
input |
tenseur de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
tenseur de valeurs de n'importe quel type |
sdy.return
(sdy::ReturnOp)
L'opération sdy.return
met fin aux régions associées aux opérations basées sur les régions sdy
et à toute autre opération basée sur les régions Shardy. Il est variadique: il prend en argument une liste de valeurs dont les types peuvent être n'importe lesquels (mais du même type, par exemple AnyTensor
) et peut donc être réutilisé à différents niveaux de la pile d'IR Shardy.
Syntaxe :
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
Caractéristiques: AlwaysSpeculatableImplTrait
, Terminator
Interfaces: ConditionallySpeculatable
, NoMemoryEffect (MemoryEffectOpInterface)
Effets: MemoryEffects::Effect{}
Opérandes:
Opérande | Description |
---|---|
results |
variadique de n'importe quel type |
sdy.sharding_constraint
(sdy::ShardingConstraintOp)
Contraint un tenseur au fractionnement spécifié
Syntaxe :
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
Associe un fractionnement à un tenseur intermédiaire (par exemple, le résultat d'une multiplication matricielle) pour indiquer que c'est ainsi que ce tenseur, ou un sous-ensemble de ses utilisations, doit être fractionné.
Si le fractionnement comporte des dimensions ouvertes et des axes non contraints, cela signifie que le tenseur peut être fractionné davantage selon les dimensions ouvertes.
Cette opération peut:
- Ne pas être utilisé (dangling) : cela signifie que le fractionnement associé est la façon dont le tenseur d'entrée lui-même doit être fractionné.
- Avoir des utilisations : cela signifie que le partitionnement associé est la façon dont les utilisations de l'opération de contrainte de partitionnement doivent être partitionnées, tandis que d'autres utilisations du tenseur d'entrée peuvent avoir un partitionnement différent (si le tenseur d'entrée n'a pas d'autres utilisations, le comportement est le même que dans le cas où il n'y a pas d'utilisations).
Caractéristiques: SameOperandsAndResultType
Interfaces: InferTypeOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
sharding | ::mlir::sdy::TensorShardingAttr | Segmentation de tenseurs |
Opérandes:
Opérande | Description |
---|---|
input |
tenseur de valeurs de n'importe quel type |
Résultats :
Résultat | Description |
---|---|
result |
tenseur de valeurs de n'importe quel type |
sdy.sharding_group
(sdy::ShardingGroupOp)
Contraint les tenseurs du groupe à avoir le même fractionnement.
Syntaxe :
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
Cette opération fournit une interface permettant d'attribuer des tenseurs à des groupes de fractionnement (groupes de tenseurs qui devront avoir des fractionnements identiques). Lors de la propagation, dès qu'un élément de groupe est partitionné, tous les autres membres sont partitionnés exactement de la même manière. Cette opération prend l'ID de groupe d'arguments et ne renvoie aucun résultat, mais modifie plutôt la représentation du groupe de fractionnement interne pour ajouter le tenseur d'entrée au groupe avec l'ID donné.
Interfaces: InferTypeOpInterface
Attributs :
Attribut | Type MLIR | Description |
---|---|---|
group_id | ::mlir::IntegerAttr | Attribut d'entier non signé de 64 bits |
Opérandes:
Opérande | Description |
---|---|
input |
tenseur classé de valeurs de n'importe quel type |
Attributs
AllToAllParamAttr
Paramètre "Tout-à-tout"
Syntaxe :
#sdy.all_to_all_param<
::llvm::ArrayRef<AxisRefAttr>, # axes
int64_t, # src_dim
int64_t # tgt_dim
>
Tuple contenant les axes et les dimensions source/cible sur lesquels effectuer une opération de type "tout-à-tout".
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
axes | ::llvm::ArrayRef<AxisRefAttr> |
les axes sur lesquels effectuer une opération de type "tout-à-tous" |
src_dim | int64_t |
l'indice de la dimension source ; |
tgt_dim | int64_t |
l'indice de la dimension cible ; |
AlltoAllParamListAttr
Liste des paramètres de communication de type "tout-à-tous"
Syntaxe :
#sdy.all_to_all_param_list<
::llvm::ArrayRef<AllToAllParamAttr> # value
>
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
valeur | ::llvm::ArrayRef<AllToAllParamAttr> |
AxisRefAttr
Référence à une axe complète ou à une sous-axe fractionnée
Syntaxe :
#sdy.axis_ref<
::llvm::StringRef, # name
SubAxisInfoAttr # sub_axis_info
>
Contraintes :
name
doit être présent dans leMeshAttr
lié.- Si
sub_axis_info
est présent, il doit respecter les contraintes deSubAxisInfoAttr
.
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
nom | ::llvm::StringRef |
nom de cet axe |
sub_axis_info | SubAxisInfoAttr |
Informations supplémentaires si l'axe est un sous-axe |
AxisRefListAttr
Liste des références d'axe
Syntaxe :
#sdy.axis_ref_list<
::llvm::ArrayRef<AxisRefAttr> # value
>
Contraintes :
- Les éléments de
value
doivent respecter les contraintes deAxisRefAttr
. - Il n'existe pas de références d'axe ou de sous-axes en double qui se chevauchent.
- Deux références d'axe adjacentes ne sont pas des sous-axes consécutifs de ce même axe complet, c'est-à-dire qu'elles peuvent être fusionnées en un seul sous-axe ou en un axe complet.
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
valeur | ::llvm::ArrayRef<AxisRefAttr> |
DimMappingAttr
Liste des indices de facteur pour une dimension
Une liste vide indique qu'il s'agit d'un mappage nul (analysé/imprimé avec *
), c'est-à-dire que la dimension n'est mappée sur aucun facteur.
Contraintes :
- Il existe au moins un indice de facteur.
- Les indices de facteur doivent être compris dans la plage [0,
$factor_sizes
). - Si plusieurs facteurs sont pris en compte, aucun d'entre eux ne peut avoir une taille 1.
- Aucun indice de facteur en double.
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
factor_indices | ::llvm::ArrayRef<int64_t> |
facteurs auxquels cette dimension est mappée |
DimensionShardingAttr
Segmentation des dimensions
Liste des noms d'axes sur lesquels diviser une dimension de tenseur, de l'axe principal à l'axe secondaire, un booléen indiquant si la dimension peut être divisée davantage et un entier facultatif indiquant la priorité de ce fractionnement de dimension, qui sera respecté lors de la propagation du fractionnement. Les priorités proviennent des annotations de fractionnement des utilisateurs. Une valeur plus basse indique une priorité plus élevée. La priorité la plus élevée est supposée lorsque la priorité est manquante dans l'annotation.
Contraintes :
- Les éléments de
axes
doivent respecter les contraintes listées dansAxisRefListAttr
. - Si un fractionnement de dimension a une priorité :
- La priorité est supérieure ou égale à 0.
- La dimension comporte au moins un axe si elle est fermée.
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
axes | ::llvm::ArrayRef<AxisRefAttr> |
références d'axe |
is_closed | bool |
si cette dimension ne peut pas être partitionnée davantage |
priorité | std::optional<int64_t> |
Priorité utilisée lors de la propagation basée sur la priorité de l'utilisateur |
ListOfAxisRefListsAttr
Liste des listes de références d'axe
Syntaxe :
#sdy.list_of_axis_ref_lists<
::llvm::ArrayRef<AxisRefListAttr> # value
>
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
valeur | ::llvm::ArrayRef<AxisRefListAttr> |
ManualAxesAttr
Liste des axes pour lesquels une opération ManualComputationOp est manuelle
Syntaxe :
#sdy.manual_axes<
::llvm::ArrayRef<StringAttr> # value
>
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
valeur | ::llvm::ArrayRef<StringAttr> |
MeshAttr
Maillage des axes et liste des appareils
Syntaxe :
#sdy.mesh<
::llvm::ArrayRef<MeshAxisAttr>, # axes
::llvm::ArrayRef<int64_t> # device_ids
>
Une maille est une liste d'axes et une liste facultative d'ID d'appareils spécifiant l'ordre des appareils.
Si la liste des axes est vide, le maillage comporte une taille d'axe implicite sans nom de 1. Dans ce cas, si aucune liste d'ID d'appareil n'est fournie, la liste d'ID d'appareil implicite est [0]. Si une liste d'ID d'appareil est fournie, elle doit contenir un seul entier de valeur non négative. C'est ce que nous appelons le cas de partitionnement maximal.
Pour tous les cas de partitionnement non maximal, si une liste d'ID d'appareil est spécifiée, le produit des tailles d'axe doit correspondre au nombre d'appareils. Si aucune liste d'ID d'appareil n'est spécifiée, la liste d'ID d'appareil implicite est iota(product(axes)). Pour simplifier, nous n'autorisons pas non plus de spécifier une liste d'ID d'appareil identique à iota(product(axes)). Dans ce cas, une liste d'ID d'appareil ne doit pas être spécifiée.
Voici quelques exemples de maillages:
- Un maillage vide représente un maillage d'espace réservé pouvant être remplacé lors de la propagation: <[]>
- Une maille avec une axe sans nom et un ID d'appareil explicite, qui est généralement utilisé pour représenter le fractionnement maximal: <[], device_ids=[3]>
- Une maille avec deux axes et des ID d'appareil implicites iota(6): <["a"=2, "b"=3]>
- Une maille avec deux axes et des ID d'appareil explicites spécifiant l'ordre des appareils: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
Contraintes :
- Les éléments de
axes
ne doivent pas porter de noms en double. - Si
device_ids
est spécifié :- Le produit des tailles d'axe doit correspondre au nombre d'appareils.
- Tous ses éléments doivent être non négatifs.
device_ids
ne doit pas être égal àiota(product(axis_sizes))
.device_ids
trié doit êtreiota(product(axis_sizes))
.
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
axes | ::llvm::ArrayRef<MeshAxisAttr> |
axes de maillage |
device_ids | ::llvm::ArrayRef<int64_t> |
Ordre d'appareil explicite ou ID d'appareil maximal |
MeshAxisAttr
Axe nommé dans un maillage
Syntaxe :
#sdy.mesh_axis<
::llvm::StringRef, # name
int64_t # size
>
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
nom | ::llvm::StringRef |
nom |
taille | int64_t |
taille de cet axe |
OpShardingRuleAttr
Spécifie comment une opération peut être partitionnée.
Syntaxe :
#sdy.op_sharding_rule<
::llvm::ArrayRef<int64_t>, # factor_sizes
::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
::llvm::ArrayRef<int64_t>, # reduction_factors
::llvm::ArrayRef<int64_t>, # need_replication_factors
::llvm::ArrayRef<int64_t>, # permutation_factors
::llvm::ArrayRef<int64_t>, # blocked_propagation_factors
bool # is_custom_rule
>
Une règle de partitionnement spécifie comment une opération peut être partitionnée en fonction de diverses propriétés de l'opération (attributs, forme des opérandes, forme des résultats, etc.). Par exemple:
%0 = stablehlo.add %arg0, %arg1 {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, j],[i, j])->([i, j])
{i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
sdy.sharding_rule = #sdy.op_sharding_rule<
([i, k],[k, j])->([i, j])
{i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
Notez que nous autorisons les facteurs de taille 1, même s'ils ne peuvent pas être partitionnés. Cela est principalement pour des raisons de complétude, car de nombreuses opérations, telles que les opérations ponctuelles, ont des dimensions de taille 1 qui correspondent entre les opérandes et les résultats.
Types de facteurs :
reduction_factors
contient les indices des facteurs nécessitant une réduction, tels que les dimensions de contraction dans une opération de point.need_replication_factors
contient les indices des facteurs nécessitant une réplication complète, tels que la dimension triée dans une opération de tri.permutation_factors
contient les indices des facteurs nécessitant une permutation collective s'ils sont partitionnés, tels que les dimensions de remplissage dans une opération de remplissage.- Tous les autres facteurs sont considérés comme des facteurs de passthrough, c'est-à-dire des facteurs qui ne nécessitent aucune communication s'ils sont partitionnés de la même manière pour tous les tenseurs qui leur sont mappés.
blocked_propagation_factors
contient les facteurs pour lesquels la propagation des fractionnements n'est pas autorisée. Il est orthogonal aux types de facteurs. Plus précisément, un facteur de propagation bloqué peut être n'importe quel type de facteur.
is_custom_rule
indique s'il s'agit d'une règle définie par un utilisateur. Les utilisateurs peuvent définir des règles de fractionnement pour leurs appels personnalisés ou écraser les règles de fractionnement prédéfinies pour les opérations standards. Une règle personnalisée est toujours conservée et jamais supprimée.
Contraintes :
- Le nombre de mappages opérande/résultat doit correspondre au nombre d'opérandes/résultats de l'opération.
- Il existe au moins un mappage (vous ne pouvez pas avoir de règle pour une opération sans opérandes/résultats).
- Le rang de chaque
TensorMappingAttr
correspond au rang du type de tenseur correspondant. - Pour chaque groupe de facteurs (
reduction_factors
,need_replication_factors
,permutation_factors
) :- Les éléments doivent être compris dans la plage [0,
$factor_sizes
]. - Aucun indice de facteur en double dans chaque groupe et entre les groupes.
- Les éléments doivent être compris dans la plage [0,
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
factor_sizes | ::llvm::ArrayRef<int64_t> |
tailles de tous les facteurs de cette règle |
operand_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
mappages d'opérandes |
result_mappings | ::llvm::ArrayRef<TensorMappingAttr> |
mappages des résultats |
reduction_factors | ::llvm::ArrayRef<int64_t> |
facteurs nécessitant une réduction |
need_replication_factors | ::llvm::ArrayRef<int64_t> |
facteurs nécessitant une réplication complète |
permutation_factors | ::llvm::ArrayRef<int64_t> |
facteurs nécessitant une permutation collective |
blocked_propagation_factors | ::llvm::ArrayRef<int64_t> |
facteurs selon lesquels les fractionnements ne sont pas propagés |
is_custom_rule | bool |
si la règle concerne un stablehlo.custom_call |
SubAxisInfoAttr
Informations sur la façon dont ce sous-axe est dérivé de l'axe complet
Syntaxe :
#sdy.sub_axis_info<
int64_t, # pre_size
int64_t # size
>
Lorsque vous divisez un axe complet en n sous-axes, l'axe est restructuré en [k_1,...,k_n], et le sous-axe i peut être exprimé par le produit de toutes les tailles d'axe à sa gauche m=prod(k_1,...,k_(i-1))
(également appelée pré-taille) et de la taille k_i. Par conséquent, l'attribut sub-axis-info contient ces deux nombres et est indiqué comme suit: (m)k
pour la pré-taille m et la taille k.
Contraintes :
pre-size
est au moins égal à 1.size
est supérieur à 1.pre-size
doit diviser la taille de l'axe complet, c'est-à-dire quepre-size
etsize
divisent la taille de l'axe complet, et que l'axe secondaire ne dépasse pas l'axe complet.- La taille du sous-axe n'est pas égale à celle de l'axe complet correspondant. Dans ce cas, l'axe complet doit être utilisé à la place.
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
pre_size | int64_t |
produit des tailles des sous-axes à gauche de ce sous-axe |
taille | int64_t |
taille de ce sous-axe |
TensorMappingAttr
Mappages de facteurs pour chaque dimension d'un tenseur.
Syntaxe :
#sdy.tensor_mapping<
::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
Contraintes :
- Les éléments de
dim_mappings
doivent respecter les contraintes deDimMappingAttr
. - Aucun indice de facteur en double entre les dimensions.
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
dim_mappings | ::llvm::ArrayRef<DimMappingAttr> |
mappages de dimensions |
TensorShardingAttr
Division des tenseurs
Syntaxe :
#sdy.sharding<
::mlir::Attribute, # mesh_or_ref
::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
Un fractionnement de tenseur est lié à un maillage spécifique et ne peut faire référence qu'aux noms d'axes de ce maillage. Les fractionnements de dimension nous indiquent, pour chaque dimension du tenseur, le long desquels axes (ou sous-axes) il est fractionné de majeur à mineur. Tous les autres axes qui ne fractionnent pas une dimension sont répliqués implicitement ou explicitement (s'ils figurent dans la liste des axes répliqués).
Le maillage auquel ce fractionnement est lié peut être spécifié par un nom de symbole, en référence à un symbole MeshOp
correspondant ou à un MeshAttr
intégré.
Contraintes :
- Les éléments de
dim_shardings
doivent respecter les contraintes listées dansDimensionShardingAttr
. - Les éléments de
replicated_axes
doivent respecter les contraintes listées dansAxisRefListAttr
. - Si le type de tenseur correspondant n'est pas un
ShapedType
, le fractionnement doit avoir un rang 0 et ne comporter aucune axe répliqué. - Le tenseur doit avoir un rang.
- Le nombre de partitions de dimension est égal au rang du tenseur.
- Les dimensions de taille 0 ne sont pas fragmentées.
- Les éléments de
replicated_axes
sont triés par rapport àmesh_or_ref
(voirAxisRefAttr::getMeshComparator
).
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
mesh_or_ref | ::mlir::Attribute |
Attribut de maillage ou attribut de référence de symbole de maillage plat |
dim_shardings | ::llvm::ArrayRef<DimensionShardingAttr> |
segmentations de dimension |
replicated_axes | ::llvm::ArrayRef<AxisRefAttr> |
références d'axe |
TensorShardingPerValueAttr
Division des tensors par operande/résultat d'une opération
Syntaxe :
#sdy.sharding_per_value<
::llvm::ArrayRef<TensorShardingAttr> # shardings
>
Liste de TensorShardingAttr
, une pour chaque operand/résultat d'une opération.
Contraintes :
- Les éléments de
shardings
doivent respecter les contraintes deTensorShardingAttr
.
Paramètres :
Paramètre | Type C++ | Description |
---|---|---|
segmentations | ::llvm::ArrayRef<TensorShardingAttr> |
sharding par valeur |
Enums
PropagationDirection
Énumération de la direction de propagation
Étuis :
Symbole | Valeur | Chaîne |
---|---|---|
AUCUN | 0 |
AUCUN |
AVANCE | 1 |
AVANCE |
ARRIÈRE | 2 |
ARRIÈRE |
TOUS LES MODÈLES | 3 |
TOUS LES MODÈLES |