Shardy est un système de partitionnement de tenseur basé sur MLIR pour tous les dialectes. Conçue grâce à la collaboration des équipes GSPMD et PartIR, elle intègre le meilleur des deux systèmes, ainsi que l'expérience partagée des équipes et des utilisateurs.
Avantages
- Contrôle et prévisibilité accrus pour les utilisateurs en combinant la propagation GSPMD avec le partitionnement incrémentiel de PartIR.
- Nouvelles fonctionnalités basées sur l'expérience partagée, par exemple la prise en charge innovante des refontes, qui génèrent généralement des communications supplémentaires, sauf si les utilisateurs savent comment les contourner.
- Amélioration de la facilité d'utilisation et de la facilité de débogage pour augmenter la vitesse de l'utilisateur final, par exemple en utilisant une représentation de fractionnement basée sur des axes.
- Un codebase Open Source simple utilisant MLIR, avec un ensemble plus large de contributeurs actifs (internes, externes et dans différents fuseaux horaires) pour aider les utilisateurs.
Composants
- Représentation du fractionnement: représentation du fractionnement basée sur des axes, liée à un maillage logique spécifique (parmi plusieurs maillages potentiels) et compatible avec les fractionnements et axes de dimension contraignants, les axes de fractionnement pour des opérations telles que la refonte, les priorités de partitionnement incrémentiel, etc.
- API du compilateur: ensemble de composants de compilateur pouvant être utilisés avec la représentation du fractionnement pour influencer la propagation du fractionnement.
- Divisions d'entrée/sortie : joignez une division à une entrée ou une sortie de la fonction principale pour indiquer que le tenseur d'entrée/sortie doit être divisé de cette manière lorsqu'il est transmis à la fonction ou renvoyé par celle-ci.
- Contrainte de fractionnement : joignez un fractionnement à un tenseur intermédiaire (par exemple, le résultat d'une multiplication matricielle) pour indiquer que c'est ainsi que ce tenseur, ou un sous-ensemble de ses utilisations, doit être fractionné.
- Shard As/Like (Diviser comme/à la manière de) : regroupez plusieurs tenseurs par ID pour indiquer qu'ils doivent être divisés de la même manière.
- Calcul manuel : inclut un sous-calcul partitionné manuellement à l'aide d'un sous-ensemble d'axes de maillage, où les fractionnements le long de ces axes manuels sont spécifiés pour toutes les entrées et sorties, et dans le sous-calcul, les types de tenseur sont locaux par rapport à ces fractionnements.
- Propagation de fractionnement: algorithme de propagation qui combine les priorités utilisateur et les contraintes de fractionnement, avec des modèles de coût et des heuristiques de compilation :
- Priorités définies par l'utilisateur, par exemple effectuer le parallélisme par lot, puis ZeRO
- Priorités basées sur les opérations, par exemple les opérations élémentaires d'abord, puis les multiplications matricielles, etc.
- Heuristiques plus précises, par exemple, privilégier les dimensions de lot.
- Partitionneur SPMD: composant qui réduit les décisions de propagation du fractionnement en partitionnant le programme en un programme SPMD, en ajoutant le mouvement/le formatage des données et les opérations collectives nécessaires au cours du processus.
- À court terme, l'implémentation initiale utilisera le partitionneur SPMD GSPMD actuel.
- À long terme, nous prévoyons de créer un nouveau partitionneur SPMD basé sur MLIR.
Dépôt du code
Le projet Shardy est en cours de développement et nous sollicitons les commentaires de la communauté Open Source. Le code Shardy est disponible sur https://github.com/openxla/shardy.