API du compilateur

Contexte

Nous supposons que les lecteurs connaissent au moins les principes de base de la représentation du fractionnement, qui décrit comment le fractionnement d'un tenseur peut être exprimé dans Shardy. Ce document explique comment les représentations de fractionnement peuvent être utilisées dans un programme, par exemple pour associer un fractionnement à un tenseur spécifique du programme.

La propagation du partitionnement consiste à décider d'un partitionnement pour chaque tenseur d'un programme en fonction des contraintes de partitionnement d'un sous-ensemble des tenseurs. L'API du compilateur de Shardy propose plusieurs façons d'influencer/contrôler la propagation du sharding. De plus, il permet aux utilisateurs d'insérer des calculs partitionnés manuellement dans leurs programmes.

Objectif

Ce document décrit la conception de ces composants d'API dans Shardy et explique leur comportement et leurs invariants. Notez que cette API est utilisée pour contrôler la propagation du sharding, mais que ce document NE traite pas du comportement de la propagation ni de sa conception.

Présentation

  • Division en parties d'entrée/sortie : joignez une division en parties à une entrée ou une sortie de la fonction principale pour indiquer que c'est ainsi que le tenseur d'entrée/sortie doit être divisé en parties lorsqu'il est transmis à la fonction ou renvoyé par celle-ci.

  • Contrainte de fractionnement : associez un fractionnement à un tenseur intermédiaire (par exemple, le résultat d'une multiplication matricielle) pour indiquer que c'est ainsi que ce tenseur, ou un sous-ensemble de ses utilisations, doit être fractionné.

  • Groupe de fractionnement : regroupez plusieurs tenseurs par ID pour indiquer qu'ils doivent être fractionnés de la même manière.

  • Calcul manuel : enferme un sous-calcul partitionné manuellement à l'aide d'un sous-ensemble d'axes de maillage, où les fractionnements le long de ces axes manuels sont spécifiés pour toutes les entrées et sorties, et à l'intérieur du sous-calcul, les types de tenseurs sont locaux par rapport à ces fractionnements.

Conception détaillée

Divisions d'entrée/sortie

Permet aux utilisateurs de spécifier un fractionnement pour les entrées et les sorties de la fonction principale.

Dans MLIR, des attributs peuvent être associés aux arguments et aux résultats de la fonction. Les utilisateurs peuvent donc associer des attributs de fractionnement à la fonction de cette manière.

Exemple :

@mesh_xy = <["x"=2, "y"=2]>

// The 1st input has a sharding specified, but the 2nd input doesn't.
// The output has a sharding specified.
func @main(%arg0: tensor<8x8xf32>
            {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"}, {}]>},
            %arg1: tensor<8x16xf32>)
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_xy, [{}, {"y"}]>}) {
  ...
}

Contrainte de segmentation

Permet aux utilisateurs d'associer un fractionnement à un tenseur intermédiaire dans leur programme, ce qui indique au partitionneur que c'est ainsi que ce tenseur, ou un sous-ensemble de ses utilisations, doit être fractionné.

Il s'agit d'une opération MLIR qui utilise le tenseur en entrée et à laquelle est associé un attribut de fractionnement. L'opération peut:

  • Ne pas être utilisé (dangling) : cela signifie que le fractionnement associé est la façon dont le tenseur lui-même doit être fractionné.
  • Avoir des utilisations : cela signifie que le partitionnement associé est la façon dont les utilisations de l'opération de contrainte de partitionnement doivent être partitionnées, tandis que d'autres utilisations du tenseur d'entrée peuvent avoir un partitionnement différent (si le tenseur d'entrée n'a pas d'autres utilisations, le comportement est le même que dans le cas où il n'y a pas d'utilisations). La propagation déterminera le fractionnement du tenseur lui-même et le refragmentera si nécessaire.

Il peut comporter des fractionnements de dimension ouverts, ce qui signifie que l'opérande peut être fractionné davantage selon les axes disponibles.

@mesh_xy = <["x"=2, "y"=2]>

%0 = ... : tensor<8x8xf32>
%1 = sdy.sharding_constraint %0 <@mesh_xy, [{"x"}, {?}]> : tensor<8x8xf32>

Groupe de segmentation

Dans les cas où il n'y a pas de dépendances de données ou de dépendances de données fortes entre deux ou plusieurs tenseurs, alors que les utilisateurs savent que ces tenseurs doivent être partitionnés de la même manière ou de manière similaire, l'API Shardy permet de spécifier cette relation. Cela permet aux utilisateurs de spécifier explicitement que les tenseurs doivent être partitionnés les uns par rapport aux autres.

Pour ce faire, nous introduisons la notion de groupes de fragments, où chaque groupe contient un nombre illimité d'instructions associées au même ID de groupe de fragments. Les groupes de fragmentation imposent que les fragmentations d'un même groupe soient identiques.

Par exemple, dans un programme utilisateur hypothétique tel que celui illustré ci-dessous, nous souhaitons tronquer la sortie du programme exactement comme l'entrée du programme, alors qu'il n'y a pas de dépendances de données entre les deux.

Si nous exécutons ce programme, la propagation du fractionnement ne pourra pas inférer le fractionnement des tenseurs %1 et %2, et ils finiront par être répliqués. Toutefois, en associant un attribut shard_group indiquant que l'%0 d'entrée et l'%2 de sortie se trouvent dans le même shard_group, nous permettons la propagation du @mesh_xy, [{"x"},{"y"}]> de fractionnement de l'%0 d'entrée à l'%2 de sortie, puis au reste du graphique, qui est diffusé en tant que constante %1 ici. Nous pouvons attribuer une valeur à un groupe avec l'opération sdy.sharding_group.

@mesh_xy = <["x"=2, "y"=2]>

module @"jit_zeros_like" {
  func.func @main(%arg0: tensor<8x2xi64> {sdy.sharding = #sdy.sharding<@mesh_xy, [{"x"},{"y"}]>} }) -> (tensor<8x2xi64>) {
    %0 = sdy.sharding_group %arg0, id=0 : tensor<8x2xi64>
    %1 = stablehlo.constant dense<0> : tensor<8x2xi64>
    %2 = sdy.sharding_group %1, id=0 : tensor<8x2xi64>
    return %2 : tensor<8x2xi64>
  }
}

Dans cet exemple simple ci-dessus, nous aurions également pu spécifier explicitement le même fractionnement à la sortie que l'entrée, ce qui aurait le même effet, car nous savons déjà quel fragment nous voulons attribuer à l'entrée à l'avance. Toutefois, dans des cas plus réalistes, nous utilisons le fractionnement pour synchroniser le fractionnement de plusieurs tenseurs sans nécessairement connaître le fractionnement de l'un d'eux, tandis que Shardy s'occupera du reste et trouvera le meilleur fractionnement à leur attribuer.

Calcul manuel

Les utilisateurs peuvent souhaiter contrôler explicitement la façon dont certaines parties de leur calcul sont partitionnées et les collectifs utilisés. Par exemple, certains utilisateurs souhaitent appliquer manuellement la multiplication matricielle collective (à partir de l'API de présentation) plutôt que de la différer au compilateur. Nous fournissons une API de calcul manuel qui leur permet de le faire.

Il s'agit de l'opération MLIR avec une seule région pour le sous-calcul manuel. Les utilisateurs spécifient les fractionnements d'entrée/sortie de ce sous-calcul à l'aide d'un sous-ensemble (y compris éventuellement tous) des axes de maillage. Le sous-calcul serait local/manuel par rapport aux axes de maillage spécifiés (axes manuels) et global/non partitionné par rapport aux axes non spécifiés (axes libres). Le sous-calcul peut être fractionné davantage le long des axes libres lors de la propagation, de la même manière que le calcul en dehors de cette opération.

Exemple :

@mesh_name = <["data"=2, "model"=2]>

%0 = ... : tensor<16x32xf32>
%1 = sdy.manual_computation(%0)
    in_shardings=[<@mesh_name, [{"data"}, {"model",?}]>]
    out_shardings=[<@mesh_name, [{"data"}, {?}]>]
    manual_axes={"data"}
    (%arg1: tensor<8x32xf32>) {
  // body
  return %42 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>

Invariants

  1. Tous les in_shardings, out_shardings et manual_axes doivent faire référence à la même maille. manual_axes est trié par rapport au maillage.

  2. manual_axes doit être utilisé explicitement dans tous les fractionnements entrants/sortants, c'est-à-dire que pour chaque fractionnement, tous les axes manuels doivent fractionner une dimension ou être répliqués explicitement.

  3. Si une axe libre (tout axe de maillage qui ne figure pas dans manual_axes) existe dans l'une des segmentations entrantes/sortantes, il doit être inférieur à tout axe manuel de la même segmentation de dimension (dans l'exemple ci-dessus, une segmentation de dimension {"model", "data"} serait non valide).

  4. La région/le corps du calcul est le calcul local (par exemple, y compris les collectifs spécifiés par l'utilisateur). Il doit être local par rapport au fractionnement des entrées/sorties le long des axes manuels (voir la note ci-dessus).

Enchâssement de calculs manuels

Vous pouvez imbriquer plusieurs calculs manuels les uns dans les autres, à condition que chacun d'eux fonctionne sur son propre ensemble d'axes manuels.