Dialecte

Le dialecte Shardy (SDY) définit une représentation de fractionnement de tenseur basée sur les axes et des composants d'API supplémentaires pour associer des fractionnements à des tenseurs.

Opérations

sdy.all_gather (sdy::AllGatherOp)

Effectue une communication all-gather le long des axes.

Syntaxe :

operation ::= `sdy.all_gather` $gathering_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Réunit des segments d'un tenseur le long des axes spécifiés dans gathering_axes.

gathering_axes est une liste de listes d'axes. La liste externe dépasse les dimensions du tenseur. Chaque liste interne spécifie les axes sur lesquels un rassemblement distinct doit être effectué sur la dimension correspondante. Il sera appliqué au fractionnement de l'opérande (tensor) pour obtenir le fractionnement du résultat (out_sharding).

Notez que out_sharding n'est pas utilisé pour déterminer le fractionnement du résultat. Au lieu de cela, le fractionnement du résultat est déterminé par le fractionnement de l'opérande et de gathering_axes, et out_sharding doit correspondre à ce fractionnement inféré.

Exemple :

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>

Contraintes :

  • doit respecter les contraintes listées dans Sdy_CollectiveOpInterface.
  • Les éléments de gathering_axes doivent respecter les contraintes listées dans AxisRefListAttr.
  • L'application de gathering_axes au fractionnement de l'opérande génère out_sharding.

Caractéristiques: SameOperandsAndResultType

Interfaces: InferTypeOpInterface, Sdy_CollectiveOpInterface

Attributs :

AttributType MLIRDescription
gathering_axes::mlir::sdy::ListOfAxisRefListsAttrListe des listes de références d'axe
out_sharding::mlir::sdy::TensorShardingAttrSegmentation de tenseurs

Opérandes:

Opérande Description
tensor tenseur de valeurs de n'importe quel type

Résultats :

Résultat Description
result tenseur de valeurs de n'importe quel type

sdy.all_reduce (sdy::AllReduceOp)

Effectuer une communication all-reduce le long des axes

Syntaxe :

operation ::= `sdy.all_reduce` $reduction_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Réduit les segments d'un tenseur le long des axes spécifiés dans reduction_axes. L'ordre de reduction_axes n'a pas d'importance pour le résultat, mais il peut affecter l'ordre des groupes de réplicas correspondants.

Contraintes :

  • doit respecter les contraintes listées dans Sdy_CollectiveOpInterface.
  • reduction_axes doit respecter les contraintes listées dans AxisRefListAttr.
  • reduction_axes ne doit pas chevaucher les axes de fractionnement de l'opérande.

Caractéristiques: SameOperandsAndResultType

Interfaces: CollectiveOpInterface, InferTypeOpInterface

Attributs :

AttributType MLIRDescription
reduction_axes::mlir::sdy::AxisRefListAttrListe des références d'axe
out_sharding::mlir::sdy::TensorShardingAttrSegmentation de tenseurs

Opérandes:

Opérande Description
tensor tenseur de valeurs de n'importe quel type

Résultats :

Résultat Description
result tenseur de valeurs de n'importe quel type

sdy.all_slice (sdy::AllSliceOp)

Effectue une opération de tranche dynamique le long des axes.

Syntaxe :

operation ::= `sdy.all_slice` $slicing_axes $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Tranche des segments d'un tenseur selon les axes spécifiés dans slicing_axes. Il existe une dualité algébrique entre sdy.all_slice et sdy.all_gather.

slicing_axes est une liste de listes d'axes. La liste externe dépasse les dimensions du tenseur. Chaque liste interne spécifie les axes sur lesquels une tranche doit être effectuée pour la dimension correspondante. Il sera appliqué au fractionnement de l'opérande (tensor) pour obtenir le fractionnement du résultat (out_sharding).

Notez que out_sharding n'est pas utilisé pour déterminer le fractionnement du résultat. Au lieu de cela, le fractionnement du résultat est déterminé par le fractionnement de l'opérande et de slicing_axes, et out_sharding doit correspondre à ce fractionnement inféré.

Exemple :

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>

Contraintes :

  • Les éléments de slicing_axes doivent respecter les contraintes listées dans AxisRefListAttr.
  • doit respecter les contraintes listées dans Sdy_CollectiveOpInterface.
  • L'application de slicing_axes au fractionnement de l'opérande génère out_sharding.

Caractéristiques: SameOperandsAndResultType

Interfaces: CollectiveOpInterface, InferTypeOpInterface

Attributs :

AttributType MLIRDescription
slicing_axes::mlir::sdy::ListOfAxisRefListsAttrListe des listes de références d'axe
out_sharding::mlir::sdy::TensorShardingAttrSegmentation de tenseurs

Opérandes:

Opérande Description
tensor tenseur de valeurs de n'importe quel type

Résultats :

Résultat Description
result tenseur de valeurs de n'importe quel type

sdy.all_to_all (sdy::AllToAllOp)

Effectue une communication de type "tout-à-tous" le long des axes.

Syntaxe :

operation ::= `sdy.all_to_all` $params $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Pour chaque tuple (axes, src_dim, tgt_dim) de la liste des paramètres, cette opération découpe des segments d'un tenseur selon la dimension tgt_dim et les axes spécifiés dans axes, les disperse selon les axes et les concatène selon la dimension src_dim.

Cette opération est essentiellement une combinaison d'une opération de collecte complète sur src_dim et axes, suivie d'une tranche complète sur tgt_dim et axes, c'est-à-dire qu'un suffixe de la dimension de fractionnement des axes src_dim sur le tenseur d'entrée est ajouté à la dimension de fractionnement des axes tgt_dim sur le tenseur de sortie.

La distribution de tous les nœuds vers tous les nœuds sera appliquée au fractionnement de l'opérande (tensor) pour obtenir le fractionnement du résultat (out_sharding).

Notez que out_sharding n'est pas utilisé pour déterminer le fractionnement du résultat. À la place, le fractionnement du résultat est déterminé par le fractionnement de l'opérande, src_dim, tgt_dim et axes, et out_sharding doit correspondre à ce fractionnement inféré.

Exemple :

%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4x32>
%2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4x32>

Contraintes :

  • doit respecter les contraintes listées dans Sdy_CollectiveOpInterface.
  • La liste des paramètres est obligatoire.
  • Pour chaque paramètre de params :
    • Les éléments de axes doivent respecter les contraintes de AxisRefAttr.
    • src_dim et tgt_dim doivent être des dimensions valides (non négatives et inférieures au rang du tenseur).
    • Les src_dim ou tgt_dim doivent être uniques pour tous les paramètres.
    • src_dim doit être trié par ordre croissant pour tous les paramètres.
  • Déplacer axes de src_dim vers tgt_dim dans le fractionnement des opérandes génère out_sharding.

Caractéristiques: SameOperandsAndResultType

Interfaces: InferTypeOpInterface, Sdy_CollectiveOpInterface

Attributs :

AttributType MLIRDescription
params::mlir::sdy::AlltoAllParamListAttrListe des paramètres de type "tout-à-tous"
out_sharding::mlir::sdy::TensorShardingAttrSegmentation de tenseurs

Opérandes:

Opérande Description
tensor tenseur de valeurs de n'importe quel type

Résultats :

Résultat Description
result tenseur de valeurs de n'importe quel type

sdy.collective_permute (sdy::CollectivePermuteOp)

Effectue une communication de permutation collective pour remplacer les axes.

Syntaxe :

operation ::= `sdy.collective_permute` $tensor `out_sharding````=```$out_sharding attr-dict `:` type($result)

Envoie un segment du tenseur d'entrée de chaque appareil à un autre pour réorganiser/remplacer les axes qui divisent le tenseur.

Une permutation collective peut transformer le fractionnement d'entrée de sorte que chaque dimension doit être fractionnée comme avant, c'est-à-dire qu'elle doit être fractionnée selon des axes dont le produit des tailles correspond à celui des axes qui fractionnaient auparavant le tenseur.

Cela permet de réorganiser les axes dans une seule dimension ou dans différentes dimensions, et d'échanger des axes partitionnés contre des axes répliqués.

Dans l'exemple ci-dessous, la taille du tenseur fractionné est tensor<1x4x2xf32>, et elle est conservée par la permutation collective.

Exemple :

sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>

Contraintes :

  • doit respecter les contraintes listées dans Sdy_CollectiveOpInterface.
  • Si le fractionnement des entrées et des sorties comporte des maillages différents, ces maillages doivent avoir exactement les mêmes axes et un ordre différent des ID d'appareil.
  • Pour chaque dimension, le produit des tailles d'axe de fractionnement dans out_sharding doit correspondre à celui du fractionnement de la dimension d'opérande correspondante.

Caractéristiques: SameOperandsAndResultType

Interfaces: CollectiveOpInterface, InferTypeOpInterface

Attributs :

AttributType MLIRDescription
out_sharding::mlir::sdy::TensorShardingAttrSegmentation de tenseurs

Opérandes:

Opérande Description
tensor tenseur de valeurs de n'importe quel type

Résultats :

Résultat Description
result tenseur de valeurs de n'importe quel type

sdy.constant (sdy::ConstantOp)

Fonctionnement constant

Génère un tenseur output à partir d'une constante value.

Consultez la page https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant.

Exemple :

%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>

Caractéristiques: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effets: MemoryEffects::Effect{}

Attributs :

AttributType MLIRDescription
value::mlir::ElementsAttrAttribut de vecteur/tenseur constant

Résultats :

Résultat Description
output Tensor de forme statique de valeurs de tout type

sdy.data_flow_edge (sdy::DataFlowEdgeOp)

Opération de bord de flux de données.

Syntaxe :

operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)

Un bord de flux de données d'une opération X définit un pont entre un ensemble de sources (chacune est un opérande de X ou un opérande du terminateur de bloc de X) et un ensemble de cibles (chacune est un résultat de X ou un argument de bloc de X), de sorte que toutes les sources et cibles doivent être fractionnées de la même manière.

Une opération peut comporter plusieurs arêtes de flux de données qui sont orthogonales les unes aux autres.

Exemple :

  y_0, ..., y_n = while (x_0, ..., x_n)
                  ((pred_arg_0,... , pred_arg_n) { ... })
                  ((body_arg_0,..., body_arg_n) {
                    ...
                    return return_value_0, ..., return_value_n
                  })

Cette opération while comporte n arcs de flux de données, les arcs de flux de données i se situant entre les sources x_i, return_value_i et les cibles y_i, pred_arg_i, body_arg_i.

Un sdy.data_flow_edge prend en entrée le propriétaire d'un arc (il peut s'agir de l'une des cibles, mais de préférence d'un résultat d'opération plutôt que d'un argument de bloc), qui ne doit pas avoir d'autres utilisations. Cette opération n'est pas pure, car elle peut accepter une entrée qui n'avait initialement aucune utilité.

sdy.data_flow_edge contient également un fractionnement facultatif pour toutes les cibles du bord. Ce fractionnement doit être mis à jour au lieu du fractionnement des cibles (s'il peut être associé) lors de la propagation. Cette fonctionnalité est utile lorsqu'une opération comporte de nombreux bords, car il est beaucoup plus efficace de:

  • se propagent séparément dans chaque arête.
  • mettez à jour le fractionnement de chaque arête séparément au lieu de toutes les cibles à la fois (par exemple, une opération ne comporte qu'un seul TensorShardingPerValueAttr immuable pour les fractionnements de résultats).
  • Ajoutez chaque arête à la liste de travail séparément lorsque le fractionnement d'une source a changé.

La propagation propage les fractionnements entre toutes les sources et cibles d'un sdy.data_flow_edge comme s'il s'agissait d'une opération régulière avec les sources comme opérandes et les cibles comme résultats, et une identité sdy.op_sharding_rule. Cela signifie que la propagation avant va des sources aux cibles, et la rétropropagation des cibles aux sources.

Nous n'autorisons pas l'entrée d'un sdy.data_flow_edge à être définie par une opération SdyDialect. Nous pouvons donc supposer qu'elle est définie par une opération dont l'attribut sdy.sharding n'est pas enregistré.

Caractéristiques: SameOperandsAndResultType

Interfaces: InferTypeOpInterface

Attributs :

AttributType MLIRDescription
sharding::mlir::sdy::TensorShardingAttrSegmentation de tenseurs

Opérandes:

Opérande Description
input de valeurs de n'importe quel type

Résultats :

Résultat Description
result de valeurs de n'importe quel type

sdy.manual_computation (sdy::ManualComputationOp)

Opération de parallélisme multi-appareil avec des collectifs manuels

Syntaxe :

operation ::= `sdy.manual_computation` `(`operands`)`
              `in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
              `out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
              `manual_axes````=```$manual_axes
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:`
              functional-type(operands, results)

Accédez à une région écrite en termes de code local par appareil avec des collectifs explicites, où les formes logiques correspondent aux formes de tampon physiques locales par appareil et les collectifs correspondent exactement à la communication physique entre les appareils.

Le corps est local par rapport aux axes manuels. La propagation se produira à travers le corps sur toutes les axes libres (celles qui ne figurent pas dans la liste "manual_axes").

Contraintes :

  • Les éléments de in_shardings et out_shardings doivent respecter les contraintes listées dans TensorShardingAttr.
  • Le nombre d'entrées/sorties de tenseur globales et locales de la région d'opération doit correspondre.
  • Les axes manuels doivent précéder les axes libres dans chaque partitionnement de dimension.
  • Les axes manuels ne peuvent pas ajouter de marge intérieure. Plus précisément, la taille de la dimension doit être divisible par la taille des axes manuels correspondants.
  • Les formes globales et locales des arguments/résultats des régions d'opération doivent correspondre.
  • Aucune axe manuel n'est divisé.

Caractéristiques: IsolatedFromAbove, RecursiveMemoryEffects, SingleBlockImplicitTerminator<ReturnOp>, SingleBlock

Interfaces: ShardableDataFlowOpInterface

Attributs :

AttributType MLIRDescription
in_shardings::mlir::sdy::TensorShardingPerValueAttrSegmentation de tensor par opérande/résultat d'une opération
out_shardings::mlir::sdy::TensorShardingPerValueAttrSegmentation de tensor par opérande/résultat d'une opération
manual_axes::mlir::sdy::ManualAxesAttrListe des axes pour lesquels une opération de calcul manuel est manuelle

Opérandes:

Opérande Description
tensors Variadique de tensor classé de valeurs de tout type

Résultats :

Résultat Description
results Variadique de tensor classé de valeurs de tout type

sdy.mesh (sdy::MeshOp)

Maillage nommé

Syntaxe :

operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict

Définit un nouveau maillage nommé. Tous les maillages d'un module doivent avoir le même nombre d'appareils (à l'exception des maillages avec un seul device_id). Le maillage est une opération Symbol qui apparaît dans le SymbolTable du module et peut être référencé par son name.

Caractéristiques: HasParent<ModuleOp>

Interfaces: Symbol

Attributs :

AttributType MLIRDescription
sym_name::mlir::StringAttrattribut de chaîne
mesh::mlir::sdy::MeshAttrMaillage des axes et liste des appareils

sdy.named_computation (sdy::NamedComputationOp)

Opération de calcul nommée

Syntaxe :

operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
              (`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
              (`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
              custom<SingleBlockRegionNoBlockId>($body)
              attr-dict
              `:` functional-type($operands, results)

Regroupe un calcul, c'est-à-dire un bloc d'opérations, et lui attribue un nom. La propagation s'effectue dans/hors de la région comme si tout était intégré.

Vous pouvez l'utiliser pour gérer la propagation via des instructions d'appel vers d'autres fonctions. Tous les utilisateurs de Shardy doivent écrire un passage d'importation/d'exportation qui convertit leurs opérations d'appel en opérations sdy.named_computation, en dupliquant/copiant le corps de la fonction appelée dans le corps de la named_computation.

Le type de chaque argument de bloc et des valeurs renvoyées dans la région doit être le même que le type des opérandes et le type de résultats de l'opération.

Exemple :

%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

Caractéristiques: IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatableImplTrait, SingleBlockImplicitTerminator<ReturnOp> et SingleBlock

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, ShardableDataFlowOpInterface

Attributs :

AttributType MLIRDescription
name::mlir::StringAttrattribut de chaîne
in_shardings::mlir::sdy::TensorShardingPerValueAttrSegmentation de tensor par opérande/résultat d'une opération
out_shardings::mlir::sdy::TensorShardingPerValueAttrSegmentation de tensor par opérande/résultat d'une opération

Opérandes:

Opérande Description
operands variadique de n'importe quel type

Résultats :

Résultat Description
"unnamed" variadique de n'importe quel type

sdy.propagation_barrier (sdy::PropagationBarrierOp)

Opération de barrière de propagation

Syntaxe :

operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)

Cette opération fonctionne comme une opération d'identité, en affichant la même valeur qu'elle a prise en entrée. Toutefois, en termes de propagation, cela ne permet qu'à la propagation de s'écouler dans une certaine direction.

Cela empêche la propagation des fractionnements entre les utilisations du résultat de l'opération de barrière et de son operand.

  • FORWARD signifie que les fractionnements ne peuvent passer que de l'opérande au résultat.
  • BACKWARD signifie que les fractionnements ne peuvent passer que du résultat à l'opérande.
  • NONE signifie qu'aucun fractionnement ne peut se propager via cette opération.
  • Vous ne pouvez pas spécifier BOTH, car cette opération serait redondante.

Caractéristiques: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effets: MemoryEffects::Effect{}

Attributs :

AttributType MLIRDescription
allowed_direction::mlir::sdy::PropagationDirectionAttrénumération de la direction de propagation

Opérandes:

Opérande Description
input tenseur classé de valeurs de n'importe quel type

Résultats :

Résultat Description
result tenseur classé de valeurs de n'importe quel type

sdy.reshard (sdy::ReshardOp)

Répartitionne un tenseur dans un autre schéma de partitionnement.

Syntaxe :

operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)

Répartitionne le tenseur d'entrée avec la répartition spécifiée, qui est différente de la répartition existante du tenseur d'entrée.

ShardingConstraintOp et ReshardOp associent un fractionnement à un tenseur. Leur durée de vie est la suivante:

  1. Avant la propagation du fractionnement, ShardingConstraintOp est ajouté par les utilisateurs.
  2. La propagation de la segmentation consomme ShardingConstraintOp. Aucun ShardingConstraintOp n'est indiqué dans les résultats de la propagation du sharding. À la place, ReshardOp peut être ajouté si nécessaire.
  3. Un partitionneur convertit une opération ReshardOp en opération collective (ou opération d'identité). Aucun ReshardOp ne doit figurer dans les résultats du partitionneur.

// TODO(b/331680067). Ajout d'un modèle de canonisation pour supprimer les opérations de reshard redondantes.

Caractéristiques: AlwaysSpeculatableImplTrait, SameOperandsAndResultType

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effets: MemoryEffects::Effect{}

Attributs :

AttributType MLIRDescription
sharding::mlir::sdy::TensorShardingAttrSegmentation de tenseurs

Opérandes:

Opérande Description
input tenseur de valeurs de n'importe quel type

Résultats :

Résultat Description
result tenseur de valeurs de n'importe quel type

sdy.return (sdy::ReturnOp)

L'opération sdy.return met fin aux régions associées aux opérations basées sur les régions sdy et à toute autre opération basée sur les régions Shardy. Il est variadique: il prend en argument une liste de valeurs dont les types peuvent être n'importe lesquels (mais du même type, par exemple AnyTensor) et peut donc être réutilisé à différents niveaux de la pile d'IR Shardy.

Syntaxe :

operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?

Caractéristiques: AlwaysSpeculatableImplTrait, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effets: MemoryEffects::Effect{}

Opérandes:

Opérande Description
results variadique de n'importe quel type

sdy.sharding_constraint (sdy::ShardingConstraintOp)

Contraint un tenseur au fractionnement spécifié

Syntaxe :

operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)

Associe un fractionnement à un tenseur intermédiaire (par exemple, le résultat d'une multiplication matricielle) pour indiquer que c'est ainsi que ce tenseur, ou un sous-ensemble de ses utilisations, doit être fractionné.

Si le fractionnement comporte des dimensions ouvertes et des axes non contraints, cela signifie que le tenseur peut être fractionné davantage selon les dimensions ouvertes.

Cette opération peut:

  • Ne pas être utilisé (dangling) : cela signifie que le fractionnement associé est la façon dont le tenseur d'entrée lui-même doit être fractionné.
  • Avoir des utilisations : cela signifie que le partitionnement associé est la façon dont les utilisations de l'opération de contrainte de partitionnement doivent être partitionnées, tandis que d'autres utilisations du tenseur d'entrée peuvent avoir un partitionnement différent (si le tenseur d'entrée n'a pas d'autres utilisations, le comportement est le même que dans le cas où il n'y a pas d'utilisations).

Caractéristiques: SameOperandsAndResultType

Interfaces: InferTypeOpInterface

Attributs :

AttributType MLIRDescription
sharding::mlir::sdy::TensorShardingAttrSegmentation de tenseurs

Opérandes:

Opérande Description
input tenseur de valeurs de n'importe quel type

Résultats :

Résultat Description
result tenseur de valeurs de n'importe quel type

sdy.sharding_group (sdy::ShardingGroupOp)

Contraint les tenseurs du groupe à avoir le même fractionnement.

Syntaxe :

operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)

Cette opération fournit une interface permettant d'attribuer des tenseurs à des groupes de fractionnement (groupes de tenseurs qui devront avoir des fractionnements identiques). Lors de la propagation, dès qu'un élément de groupe est partitionné, tous les autres membres sont partitionnés exactement de la même manière. Cette opération prend l'ID de groupe d'arguments et ne renvoie aucun résultat, mais modifie plutôt la représentation du groupe de fractionnement interne pour ajouter le tenseur d'entrée au groupe avec l'ID donné.

Interfaces: InferTypeOpInterface

Attributs :

AttributType MLIRDescription
group_id::mlir::IntegerAttrAttribut d'entier non signé de 64 bits

Opérandes:

Opérande Description
input tenseur classé de valeurs de n'importe quel type

Attributs

AllToAllParamAttr

Paramètre "Tout-à-tout"

Syntaxe :

#sdy.all_to_all_param<
  ::llvm::ArrayRef<AxisRefAttr>,   # axes
  int64_t,   # src_dim
  int64_t   # tgt_dim
>

Tuple contenant les axes et les dimensions source/cible sur lesquels effectuer une opération de type "tout-à-tout".

Paramètres :

Paramètre Type C++ Description
axes ::llvm::ArrayRef<AxisRefAttr> les axes sur lesquels effectuer une opération de type "tout-à-tous"
src_dim int64_t l'indice de la dimension source ;
tgt_dim int64_t l'indice de la dimension cible ;

AlltoAllParamListAttr

Liste des paramètres de communication de type "tout-à-tous"

Syntaxe :

#sdy.all_to_all_param_list<
  ::llvm::ArrayRef<AllToAllParamAttr>   # value
>

Paramètres :

Paramètre Type C++ Description
valeur ::llvm::ArrayRef<AllToAllParamAttr>

AxisRefAttr

Référence à une axe complète ou à une sous-axe fractionnée

Syntaxe :

#sdy.axis_ref<
  ::llvm::StringRef,   # name
  SubAxisInfoAttr   # sub_axis_info
>

Contraintes :

  • name doit être présent dans le MeshAttr lié.
  • Si sub_axis_info est présent, il doit respecter les contraintes de SubAxisInfoAttr.

Paramètres :

Paramètre Type C++ Description
nom ::llvm::StringRef nom de cet axe
sub_axis_info SubAxisInfoAttr Informations supplémentaires si l'axe est un sous-axe

AxisRefListAttr

Liste des références d'axe

Syntaxe :

#sdy.axis_ref_list<
  ::llvm::ArrayRef<AxisRefAttr>   # value
>

Contraintes :

  • Les éléments de value doivent respecter les contraintes de AxisRefAttr.
  • Il n'existe pas de références d'axe ou de sous-axes en double qui se chevauchent.
  • Deux références d'axe adjacentes ne sont pas des sous-axes consécutifs de ce même axe complet, c'est-à-dire qu'elles peuvent être fusionnées en un seul sous-axe ou en un axe complet.

Paramètres :

Paramètre Type C++ Description
valeur ::llvm::ArrayRef<AxisRefAttr>

DimMappingAttr

Liste des indices de facteur pour une dimension

Une liste vide indique qu'il s'agit d'un mappage nul (analysé/imprimé avec *), c'est-à-dire que la dimension n'est mappée sur aucun facteur.

Contraintes :

  • Il existe au moins un indice de facteur.
  • Les indices de facteur doivent être compris dans la plage [0, $factor_sizes).
  • Si plusieurs facteurs sont pris en compte, aucun d'entre eux ne peut avoir une taille 1.
  • Aucun indice de facteur en double.

Paramètres :

Paramètre Type C++ Description
factor_indices ::llvm::ArrayRef<int64_t> facteurs auxquels cette dimension est mappée

DimensionShardingAttr

Segmentation des dimensions

Liste des noms d'axes sur lesquels diviser une dimension de tenseur, de l'axe principal à l'axe secondaire, un booléen indiquant si la dimension peut être divisée davantage et un entier facultatif indiquant la priorité de ce fractionnement de dimension, qui sera respecté lors de la propagation du fractionnement. Les priorités proviennent des annotations de fractionnement des utilisateurs. Une valeur plus basse indique une priorité plus élevée. La priorité la plus élevée est supposée lorsque la priorité est manquante dans l'annotation.

Contraintes :

  • Les éléments de axes doivent respecter les contraintes listées dans AxisRefListAttr.
  • Si un fractionnement de dimension a une priorité :
    • La priorité est supérieure ou égale à 0.
    • La dimension comporte au moins un axe si elle est fermée.

Paramètres :

Paramètre Type C++ Description
axes ::llvm::ArrayRef<AxisRefAttr> références d'axe
is_closed bool si cette dimension ne peut pas être partitionnée davantage
priorité std::optional<int64_t> Priorité utilisée lors de la propagation basée sur la priorité de l'utilisateur

ListOfAxisRefListsAttr

Liste des listes de références d'axe

Syntaxe :

#sdy.list_of_axis_ref_lists<
  ::llvm::ArrayRef<AxisRefListAttr>   # value
>

Paramètres :

Paramètre Type C++ Description
valeur ::llvm::ArrayRef<AxisRefListAttr>

ManualAxesAttr

Liste des axes pour lesquels une opération ManualComputationOp est manuelle

Syntaxe :

#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr>   # value
>

Paramètres :

Paramètre Type C++ Description
valeur ::llvm::ArrayRef<StringAttr>

MeshAttr

Maillage des axes et liste des appareils

Syntaxe :

#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>,   # axes
  ::llvm::ArrayRef<int64_t>   # device_ids
>

Une maille est une liste d'axes et une liste facultative d'ID d'appareils spécifiant l'ordre des appareils.

Si la liste des axes est vide, le maillage comporte une taille d'axe implicite sans nom de 1. Dans ce cas, si aucune liste d'ID d'appareil n'est fournie, la liste d'ID d'appareil implicite est [0]. Si une liste d'ID d'appareil est fournie, elle doit contenir un seul entier de valeur non négative. C'est ce que nous appelons le cas de partitionnement maximal.

Pour tous les cas de partitionnement non maximal, si une liste d'ID d'appareil est spécifiée, le produit des tailles d'axe doit correspondre au nombre d'appareils. Si aucune liste d'ID d'appareil n'est spécifiée, la liste d'ID d'appareil implicite est iota(product(axes)). Pour simplifier, nous n'autorisons pas non plus de spécifier une liste d'ID d'appareil identique à iota(product(axes)). Dans ce cas, une liste d'ID d'appareil ne doit pas être spécifiée.

Voici quelques exemples de maillages:

  • Un maillage vide représente un maillage d'espace réservé pouvant être remplacé lors de la propagation: <[]>
  • Une maille avec une axe sans nom et un ID d'appareil explicite, qui est généralement utilisé pour représenter le fractionnement maximal: <[], device_ids=[3]>
  • Une maille avec deux axes et des ID d'appareil implicites iota(6): <["a"=2, "b"=3]>
  • Une maille avec deux axes et des ID d'appareil explicites spécifiant l'ordre des appareils: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>

Contraintes :

  • Les éléments de axes ne doivent pas porter de noms en double.
  • Si device_ids est spécifié :
    • Le produit des tailles d'axe doit correspondre au nombre d'appareils.
    • Tous ses éléments doivent être non négatifs.
    • device_ids ne doit pas être égal à iota(product(axis_sizes)).
    • device_ids trié doit être iota(product(axis_sizes)).

Paramètres :

Paramètre Type C++ Description
axes ::llvm::ArrayRef<MeshAxisAttr> axes de maillage
device_ids ::llvm::ArrayRef<int64_t> Ordre d'appareil explicite ou ID d'appareil maximal

MeshAxisAttr

Axe nommé dans un maillage

Syntaxe :

#sdy.mesh_axis<
  ::llvm::StringRef,   # name
  int64_t   # size
>

Paramètres :

Paramètre Type C++ Description
nom ::llvm::StringRef nom
taille int64_t taille de cet axe

OpShardingRuleAttr

Spécifie comment une opération peut être partitionnée.

Syntaxe :

#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>,   # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>,   # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>,   # result_mappings
  ::llvm::ArrayRef<int64_t>,   # reduction_factors
  ::llvm::ArrayRef<int64_t>,   # need_replication_factors
  ::llvm::ArrayRef<int64_t>,   # permutation_factors
  ::llvm::ArrayRef<int64_t>,   # blocked_propagation_factors
  bool   # is_custom_rule
>

Une règle de partitionnement spécifie comment une opération peut être partitionnée en fonction de diverses propriétés de l'opération (attributs, forme des opérandes, forme des résultats, etc.). Par exemple:

%0 = stablehlo.add %arg0, %arg1 {
    sdy.sharding_rule = #sdy.op_sharding_rule<
        ([i, j],[i, j])->([i, j])
        {i=8, j=8}>
} : tensor<8x8xf32>
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>

Notez que nous autorisons les facteurs de taille 1, même s'ils ne peuvent pas être partitionnés. Cela est principalement pour des raisons de complétude, car de nombreuses opérations, telles que les opérations ponctuelles, ont des dimensions de taille 1 qui correspondent entre les opérandes et les résultats.

Types de facteurs :

  • reduction_factors contient les indices des facteurs nécessitant une réduction, tels que les dimensions de contraction dans une opération de point.
  • need_replication_factors contient les indices des facteurs nécessitant une réplication complète, tels que la dimension triée dans une opération de tri.
  • permutation_factors contient les indices des facteurs nécessitant une permutation collective s'ils sont partitionnés, tels que les dimensions de remplissage dans une opération de remplissage.
  • Tous les autres facteurs sont considérés comme des facteurs de passthrough, c'est-à-dire des facteurs qui ne nécessitent aucune communication s'ils sont partitionnés de la même manière pour tous les tenseurs qui leur sont mappés.

blocked_propagation_factors contient les facteurs pour lesquels la propagation des fractionnements n'est pas autorisée. Il est orthogonal aux types de facteurs. Plus précisément, un facteur de propagation bloqué peut être n'importe quel type de facteur.

is_custom_rule indique s'il s'agit d'une règle définie par un utilisateur. Les utilisateurs peuvent définir des règles de fractionnement pour leurs appels personnalisés ou écraser les règles de fractionnement prédéfinies pour les opérations standards. Une règle personnalisée est toujours conservée et jamais supprimée.

Contraintes :

  • Le nombre de mappages opérande/résultat doit correspondre au nombre d'opérandes/résultats de l'opération.
  • Il existe au moins un mappage (vous ne pouvez pas avoir de règle pour une opération sans opérandes/résultats).
  • Le rang de chaque TensorMappingAttr correspond au rang du type de tenseur correspondant.
  • Pour chaque groupe de facteurs (reduction_factors, need_replication_factors, permutation_factors) :
    • Les éléments doivent être compris dans la plage [0, $factor_sizes].
    • Aucun indice de facteur en double dans chaque groupe et entre les groupes.

Paramètres :

Paramètre Type C++ Description
factor_sizes ::llvm::ArrayRef<int64_t> tailles de tous les facteurs de cette règle
operand_mappings ::llvm::ArrayRef<TensorMappingAttr> mappages d'opérandes
result_mappings ::llvm::ArrayRef<TensorMappingAttr> mappages des résultats
reduction_factors ::llvm::ArrayRef<int64_t> facteurs nécessitant une réduction
need_replication_factors ::llvm::ArrayRef<int64_t> facteurs nécessitant une réplication complète
permutation_factors ::llvm::ArrayRef<int64_t> facteurs nécessitant une permutation collective
blocked_propagation_factors ::llvm::ArrayRef<int64_t> facteurs selon lesquels les fractionnements ne sont pas propagés
is_custom_rule bool si la règle concerne un stablehlo.custom_call

SubAxisInfoAttr

Informations sur la façon dont ce sous-axe est dérivé de l'axe complet

Syntaxe :

#sdy.sub_axis_info<
  int64_t,   # pre_size
  int64_t   # size
>

Lorsque vous divisez un axe complet en n sous-axes, l'axe est restructuré en [k_1,...,k_n], et le sous-axe i peut être exprimé par le produit de toutes les tailles d'axe à sa gauche m=prod(k_1,...,k_(i-1)) (également appelée pré-taille) et de la taille k_i. Par conséquent, l'attribut sub-axis-info contient ces deux nombres et est indiqué comme suit: (m)k pour la pré-taille m et la taille k.

Contraintes :

  • pre-size est au moins égal à 1.
  • size est supérieur à 1.
  • pre-size doit diviser la taille de l'axe complet, c'est-à-dire que pre-size et size divisent la taille de l'axe complet, et que l'axe secondaire ne dépasse pas l'axe complet.
  • La taille du sous-axe n'est pas égale à celle de l'axe complet correspondant. Dans ce cas, l'axe complet doit être utilisé à la place.

Paramètres :

Paramètre Type C++ Description
pre_size int64_t produit des tailles des sous-axes à gauche de ce sous-axe
taille int64_t taille de ce sous-axe

TensorMappingAttr

Mappages de facteurs pour chaque dimension d'un tenseur.

Syntaxe :

#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr>   # dim_mappings
>

Contraintes :

  • Les éléments de dim_mappings doivent respecter les contraintes de DimMappingAttr.
  • Aucun indice de facteur en double entre les dimensions.

Paramètres :

Paramètre Type C++ Description
dim_mappings ::llvm::ArrayRef<DimMappingAttr> mappages de dimensions

TensorShardingAttr

Division des tenseurs

Syntaxe :

#sdy.sharding<
  ::mlir::Attribute,   # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>,   # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr>   # replicated_axes
>

Un fractionnement de tenseur est lié à un maillage spécifique et ne peut faire référence qu'aux noms d'axes de ce maillage. Les fractionnements de dimension nous indiquent, pour chaque dimension du tenseur, le long desquels axes (ou sous-axes) il est fractionné de majeur à mineur. Tous les autres axes qui ne fractionnent pas une dimension sont répliqués implicitement ou explicitement (s'ils figurent dans la liste des axes répliqués).

Le maillage auquel ce fractionnement est lié peut être spécifié par un nom de symbole, en référence à un symbole MeshOp correspondant ou à un MeshAttr intégré.

Contraintes :

  • Les éléments de dim_shardings doivent respecter les contraintes listées dans DimensionShardingAttr.
  • Les éléments de replicated_axes doivent respecter les contraintes listées dans AxisRefListAttr.
  • Si le type de tenseur correspondant n'est pas un ShapedType, le fractionnement doit avoir un rang 0 et ne comporter aucune axe répliqué.
  • Le tenseur doit avoir un rang.
  • Le nombre de partitions de dimension est égal au rang du tenseur.
  • Les dimensions de taille 0 ne sont pas fragmentées.
  • Les éléments de replicated_axes sont triés par rapport à mesh_or_ref (voir AxisRefAttr::getMeshComparator).

Paramètres :

Paramètre Type C++ Description
mesh_or_ref ::mlir::Attribute Attribut de maillage ou attribut de référence de symbole de maillage plat
dim_shardings ::llvm::ArrayRef<DimensionShardingAttr> segmentations de dimension
replicated_axes ::llvm::ArrayRef<AxisRefAttr> références d'axe

TensorShardingPerValueAttr

Division des tensors par operande/résultat d'une opération

Syntaxe :

#sdy.sharding_per_value<
  ::llvm::ArrayRef<TensorShardingAttr>   # shardings
>

Liste de TensorShardingAttr, une pour chaque operand/résultat d'une opération.

Contraintes :

  • Les éléments de shardings doivent respecter les contraintes de TensorShardingAttr.

Paramètres :

Paramètre Type C++ Description
segmentations ::llvm::ArrayRef<TensorShardingAttr> sharding par valeur

Enums

PropagationDirection

Énumération de la direction de propagation

Étuis :

Symbole Valeur Chaîne
AUCUN 0 AUCUN
AVANCE 1 AVANCE
ARRIÈRE 2 ARRIÈRE
TOUS LES MODÈLES 3 TOUS LES MODÈLES