L'état actuel du dynamisme est décrit plus formellement dans la RFC sur le dynamisme. Cette page fournit une vue d'ensemble de la RFC et présente les API et les outils importants pour interagir avec les programmes dynamiques.
Présentation de la terminologie et de l'assistance pour le dynamisme
Tout d'abord, voici quelques termes qui apparaîtront dans ce document, ainsi qu'une brève présentation de leur prise en charge dans StableHLO:
Dimensions dynamiques
Les dimensions dynamiques désignent toute dimension dont la taille est inconnue.
Dans StableHLO, nous représentons les dimensions dynamiques à l'aide de ?
, c'est-à-dire tensor<16x?xf32>
.
dynamisme limité
Le dynamisme limité fait référence à une dimension dynamique dont la valeur a une borne supérieure connue. Cela est généralement utile pour remplir le Tensor pendant l'exécution.
Dans StableHLO, nous représentons le dynamisme limité en utilisant #stablehlo.bounds
comme encodage de Tensor, c'est-à-dire un Tensor de rang 2 avec une dimension dynamique limitée à 16, et l'autre sans limite peut être représenté par tensor<?x?xf32, #stablehlo.bounds<16, ?>>
.
StableHLO peut représenter un dynamisme limité, mais la compatibilité avec les frameworks est limitée, car elle provient de TensorFlow et, dans une certaine mesure, est compatible avec PyTorch/XLA.
Dynamisme illimité
Comme son nom l'indique, le dynamisme illimité fait référence à une dimension dynamique sans limite de taille connue. Ce type de dynamisme est très courant dans StableHLO, avec la prise en charge de JAX, PyTorch/XLA et TF, souvent utilisée pour exporter des modèles avec une taille de lot ou une longueur de séquence dynamique.
Dans StableHLO, il suffit d'éliminer les limites d'encodage pour cette forme de dynamisme, à savoir tensor<?x?xf32>
.
Polymorphisme de forme
Le polymorphisme de forme est un terme que nous avons hérité de JAX.
Le polymorphisme de forme a deux implications principales:
- Tout le dynamisme du programme remonte à ses arguments d'entrée.
- Tout dynamisme concerne uniquement les formes des tenseurs, c'est-à-dire qu'il n'est pas dépendant des données.
Avec ces deux règles, une fois que les formes statiques d'un programme sont connues, nous pouvons prendre un programme dynamique et l'affiner complètement en un programme statique pour la compilation (voir "Passages de compilation pour affiner les programmes dynamiques").
En règle générale, le polymorphisme de forme utilise un dynamisme illimité. Si des formes d'argument connues peuvent conduire à un programme entièrement statique, il n'est pas nécessaire de deviner comment lier les valeurs.
Dynamisme dépendant des données
Le dynamisme dépendant des données fait référence aux tailles de dimensions dynamiques qui se rapportent aux données à l'intérieur d'un Tensor. L'exemple canonique est une fonction nonzeros
qui renvoie les indices de tous les éléments qui sont 0
dans une valeur de tenseur. Il est impossible de connaître la forme sans évaluer les données, mais elle peut souvent être compilée à l'aide d'un dynamisme limité, en utilisant de la mémoire supplémentaire pour réduire la taille potentielle du Tensor de sortie.
De nombreuses opérations dynamiques dépendantes des données peuvent être modélisées à l'aide d'un dynamisme limité, où une limite supérieure à la taille d'un tenseur est spécifiée, et le matériel l'implémente généralement via le remplissage de tenseur. Aujourd'hui, le dynamisme dépendant des données est partiellement pris en charge dans PyTorch/XLA et TensorFlow, mais JAX ne suit actuellement pas les opérations qui entraînent un dynamisme dépendant des données.
Exporter des programmes avec des dimensions dynamiques
Pour savoir comment exporter des programmes avec des tailles de lot ou des longueurs de séquence dynamiques, consultez nos tutoriels StableHLO:
- Tutoriel JAX > Exportation avec taille de lot dynamique
- Tutoriel PyTorch/XLA > Exportation avec taille de lot dynamique
Passages du compilateur pour affiner les programmes dynamiques
Supprimer le pipeline de la passe de dynamisme
Il existe quelques passes utiles pour affiner les formes. Elles sont toutes regroupées dans un pipeline de passes createStablehloRemoveDynamismPipeline
:
void createStablehloRemoveDynamismPipeline(OpPassManager &pm,
TypeRange refinedTypes);
Cartes individuelles pour le affinage du dynamisme
Individuellement, les cartes qui s'avèrent utiles pour affiner les formes sont les suivantes:
stablehlo-refine-arguments
pour remplacer les arguments d'entrée par des types de tenseurs concrets.stablehlo-refine-shapes
pour propager les nouvelles informations de forme d'argument d'entrée dans l'ensemble du programme.stablehlo-canonicalize-dynamism
pour remplacer les opérations dynamiques par leurs variantes statiques.
Consultez la documentation associée pour obtenir des informations à jour et des exemples.
Exemple: En quoi le dynamisme est-il utile et comment puis-je l'utiliser ?
Le dynamisme a de nombreuses utilisations. Ici, nous nous concentrerons principalement sur le cas d'utilisation courant du polymorphisme de forme : la création d'une représentation de modèle exportée flexible, généralement utilisée pour représenter la taille de lot dynamique ou la longueur de séquence.
Modèle add_one statique
Pour illustrer cela, nous allons utiliser le modèle add_one
simple suivant:
def add_one(x):
return x + 1
Lors du traçage à l'aide d'un tensor<4xf32>
, nous obtenons le programme StableHLO suivant:
// File: add_one.mlir
func.func @add_one(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<4xf32>
%0 = stablehlo.add %arg0, %cst : tensor<4xf32>
return %0 : tensor<4xf32>
}
Ce modèle fonctionnera uniquement pour les arguments d'entrée ayant une forme tensor<4xf32>
. Si nous modifions la taille de lot ou la longueur de séquence, nous devrons retracer le code source et le ramener à StableHLO. Il n'est pas garanti que nous ayons encore accès au code source.
Modèle add_one dynamique
C'est là que le dynamisme polymorphe des formes entre en jeu. À la place, JAX et PyTorch/XLA peuvent émettre le modèle add_one
avec un IR valide de manière dynamique, qui diffusera la constante pour qu'elle corresponde à la forme d'entrée dynamique comme suit:
// File: add_one_dynamic.mlir
func.func public @main(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%cst = stablehlo.constant dense<1.0> : tensor<f32>
%0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
%1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
%2 = stablehlo.dynamic_broadcast_in_dim %cst, %1, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%3 = stablehlo.add %arg0, %2 : tensor<?xf32>
return %3 : tensor<?xf32>
}
Cette représentation du modèle est beaucoup plus flexible et permet de spécifier des valeurs telles que la taille de lot ou la longueur de séquence de manière différée. Ce modèle peut être déployé sur des plates-formes compatibles avec les formes dynamiques (comme AI Edge) ou affiné à l'aide des passes de dynamisme mentionnées dans cette documentation.
Affiner le modèle dynamique
Par exemple, l'ordre des cartes suivant peut affiner complètement ce programme:
stablehlo-opt add_one_dynamic.mlir \
--stablehlo-refine-arguments='types=tensor<16xf32>' \
--stablehlo-refine-shapes \
--stablehlo-canonicalize-dynamism
Voici comment le programme est transformé de manière incrémentielle:
// After stablehlo-refine-arguments: Inputs updated, shapes not propagated
func.func public @main(%arg0: tensor<16xf32>) -> tensor<?xf32> {
%c = stablehlo.constant dense<16> : tensor<1xi64>
%0 = stablehlo.custom_call @stablehlo.shape_refinement_operand_wrapper(%arg0, %c) {indices_of_shape_operands = dense<1> : tensor<1xi64>} : (tensor<16xf32>, tensor<1xi64>) -> tensor<?xf32>
...
%3 = stablehlo.dynamic_broadcast_in_dim %cst, %2, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
%4 = stablehlo.add %0, %3 : tensor<?xf32>
return %4 : tensor<?xf32>
}
// After stablehlo-refine-shapes: Shapes propagated, dynamic ops still exist
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%c = stablehlo.constant dense<16> : tensor<1xi32>
%0 = stablehlo.dynamic_broadcast_in_dim %cst, %c, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// After stablehlo-canonicalize-dynamism: Dynamic ops replaced with static ops
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
%1 = stablehlo.add %arg0, %0 : tensor<16xf32>
return %1 : tensor<16xf32>
}
// (Bonus) Use ` --stablehlo-aggressive-simplification` pass to canonicalize the
// constant broadcast, leaving us with the original static program in this case.
func.func public @main(%arg0: tensor<16xf32>) -> tensor<16xf32> {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<16xf32>
%0 = stablehlo.add %arg0, %cst : tensor<16xf32>
return %0 : tensor<16xf32>
}