Interprète

Modèle de données

Les programmes StableHLO sont des calculs sur des Tensors (tableaux à n dimensions) qui, dans le modèle actuel, sont implémentés à l'aide de la classe Tensor. La classe de stockage sous-jacente d'un objet Tensor, detail::Buffer, stocke le mlir::ShapedType du Tensor avec un objet mlir::HeapAsmResourceBlob représentant un blob modifiable de données de Tensor présentées sous forme de tableau d'octets contigus dans l'ordre croissant/mineur. Les objets detail::Buffer sont comptabilisés en référence pour simplifier la gestion de la mémoire.

Les éléments individuels d'un Tensor sont représentés à l'aide de la classe Element, qui utilise une union discriminée contenant APInt, APFloat ou pair<APFloat,APFloat> pour le stockage. Le dernier est utilisé pour stocker des éléments de types complexes.

Tensor dispose des API suivantes pour interagir avec ses éléments individuels:

  • Element Tensor::get(llvm::ArrayRef<int64_t> index): pour extraire un élément de Tensor individuel avec l'indice multidimensionnel index en tant qu'objet Element.
  • void Tensor::set(llvm::ArrayRef<int64_t> index, Element element); : pour mettre à jour un objet Element element en un Tensor avec l'indice multidimensionnel index.

Fonctionnement de l'interprète

La fonction d'entrée de l'interpréteur est

SmallVector<Tensor> eval(func::FuncOp func, ArrayRef<Tensor> args);

qui effectue les opérations suivantes:

  1. Suit les arguments SSA de func et les valeurs Tensor d'exécution associées, fournies dans args, à l'aide d'un mappage de tables de symboles, M.
  2. Pour chaque opération dans func, dans l'ordre SSACFG :
    • Invoque eval sur l'opération. Pour chaque opérande SSA de l'opération, extrayez sa valeur d'exécution à partir de M, qui sera fournie en tant qu'argument à l'appel de eval.
    • Effectue le suivi du ou des résultats SSA de l'opération et de la valeur évaluée dans M.

Le eval au niveau de l'opération mentionné à la section (2) est responsable de l'implémentation de la sémantique d'exécution de l'opération. Voici un exemple pour stablehlo::AddOp. Dans l'exemple, les éléments individuels des Tensors lhs et rhs sont extraits par paire en tant qu'objets Element, puis ajoutés. Le résultat de l'ajout, un objet Element, est stocké dans le Tensor result final.

Tensor eval(AddOp op, const Tensor &lhs, const Tensor &rhs) {
  Tensor result(op.getType());

  for (auto it = result.index_begin(); it != result.index_end(); ++it)
    result.set(*it, lhs.get(*it) + rhs.get(*it));

  return result;
}

Dans l'ensemble, la conception de l'interpréteur est optimisée pour la lisibilité des implémentations des fonctions eval pour des opérations individuelles, car elle est destinée à servir d'implémentation de référence pour StableHLO. Par exemple, au lieu de définir eval en tant que fonction de modèle et de le paramétrer avec des types d'éléments, nous encapsulons des détails sur la façon dont les différents types d'éléments sont traités dans Element::operator+, etc., ce qui simplifie l'implémentation de eval.

Utiliser l'interpréteur pour le pliage constant

Nous pouvons utiliser le mécanisme d'interpréteur pour plier les opérations avec des valeurs d'opérande constantes. L'extrait de code suivant illustre l'implémentation du pliage de stablehlo::AddOp avec des opérandes de type à virgule flottante:

OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto attrs = adaptor.getOperands();
  DenseElementsAttr lhsData = attrs[0].dyn_cast<DenseElementsAttr>();
  DenseElementsAttr rhsData = attrs[1].dyn_cast<DenseElementsAttr>();
  if (!lhsData || !rhsData) return {};

  auto lhs = Tensor(lhsData);
  auto rhs = Tensor(rhsData);
  auto result = eval(*this, lhs, rhs);

  SmallVector<APFloat> values;
  for (auto i = 0; i < result.getNumElements(); ++i) {
    Element element = result.get(i);
    values.push_back(element.getValue().cast<FloatAttr>().getValue());
  }

  return DenseElementsAttr::get(result.getType(), values);
}

Pour le moment, nous ne travaillons pas activement sur l'intégration de l'interpréteur dans le pliage constant, car nous n'envisageons pas d'implémenter le dossier pour StableHLO. Toutefois, à l'avenir, nous prévoyons d'utiliser l'interpréteur pour un pliage constant en MHLO. Nous améliorerons alors l'ergonomie de l'extrait de code ci-dessus (par exemple, nous pourrions avoir une fonction d'assistance qui regroupe des opérandes constants dans des objets Tensor et décompresse les résultats Tensor en OpFoldResult).

Tester l'interpréteur StableHLO

L'interpréteur reçoit en entrée (A) un programme StableHLO et (B) des valeurs de données à transmettre au programme. Il génère des valeurs de données de sortie, qui sont mises en correspondance avec les valeurs de données attendues fournies par l'utilisateur. Les valeurs de données (B) sont codées en dur dans le programme lui-même à l'aide d'opérations stablehlo.constant. L'interpréteur évalue le programme d'entrée. La ou les sorties de l'opération testée sont vérifiées par des vérifications (par exemple, check.expect_eq ou check.expect_almost_eq), comme indiqué ci-dessous. check.expect_eq et check.expect_eq_const vérifient l'égalité au niveau du bit pour tous les types compatibles, et check.expect_almost_eq et check.expect_almost_eq_const vérifient la quasi-égalité dans une tolérance, comme expliqué dans les consignes de test (G6), pour les types à virgule flottante et complexes.

// CHECK-LABEL: Evaluated results of function: add_op_test_ui4
func.func @add_op_test_ui4() {
  %0 = stablehlo.constant dense<[0, 2]> : tensor<2xui4>
  %1 = stablehlo.constant dense<[15, 3]> : tensor<2xui4>
  %2 = stablehlo.add %0, %1 : tensor<2xui4>
  check.expect_eq_const %2, [15, 5] : tensor<2xui4>
  func.return
}

Un utilitaire de test stablehlo-translate --interpret (code) est chargé d'analyser le programme et d'interpréter chaque fonction, y compris les opérations qui la constituent. Nous disposons d'une suite de tests dédiée, composée de plusieurs tests effectuant différents comportements d'exécution, pour chaque opération StableHLO. Cliquez ici pour accéder aux tests (ex. interpréteur_*.mlir).

Consignes de test

(G1) Devons-nous tester tous les types compatibles pour chaque opération ?

Nous pouvons utiliser une combinaison des règles suivantes pour prendre une décision:

  1. Lors de l'implémentation d'une opération, s'il existe du code dans la fonction eval correspondante pour gérer un type particulier, il est impératif de disposer de tests pour couvrir ce type. Par exemple, pour l'opération add, il existe un code exclusif pour gérer les types entiers, booléens, à virgule flottante et complexes. Nous avons donc besoin d'un test pour chaque catégorie de types.

  2. Si un ensemble de types est géré de manière uniforme dans la fonction eval correspondante, un seul test pour tous ces types devrait suffire. Par exemple, pour l'opération add, toutes les variantes de types entiers (si4, u4, si8, u8, etc.) sont traitées de la même manière à l'aide des API llvm::APInt. Nous pouvons donc ignorer l'ajout de tests pour chacune de ces variantes et ajouter un seul test représentatif à la place. Pour éviter toute ambiguïté concernant le choix du représentant, nous devons utiliser les consignes suivantes:

    • Si tous les types, gérés de manière uniforme, ont le même type primitif (c'est-à-dire s'ils sont tous des nombres entiers, à virgule flottante ou complexes), choisissez celui qui a la largeur de bits maximale.
    • Si tous les types, gérés de manière uniforme, possèdent un mélange de types primitifs, choisissez celui avec le type primitif suivant, par ordre décroissant de préférence: entier, à virgule flottante, booléen, complexe.

(G2) Comment déterminons-nous le nombre de tests nécessaires pour couvrir le comportement d'une opération ?

L'objectif est de couvrir de manière exhaustive la logique de l'interpréteur pour l'opération (c'est-à-dire tous les cas particuliers de l'implémentation) avec un nombre minimal de tests. Il est important de réduire le nombre de tests pour faciliter la gestion. Moins nous avons de tests, plus il est facile de les examiner et de s'assurer qu'ils couvrent l'opération de manière exhaustive. Par conséquent, nous nous attendons à ce que la plupart des opérations les plus simples finissent par n'avoir qu'un seul test. Si, pour une raison quelconque, une couverture complète n'est pas pratique, il est acceptable de s'arrêter à 90 % ou plus. Ce choix sera décidé au cas par cas lors de l'examen des demandes d'extraction.

(G3) Pourquoi ne pas ajouter des tests pour l'infrastructure de l'interpréteur ?

L'infrastructure d'interprétation est principalement simple et peut être ajoutée à notre base de confiance. La seule partie importante est la façon dont les différents types sont empaquetés et décompressés dans le stockage de l'interpréteur sous-jacent. Comme indiqué dans la section (G1), nous ne testerons que les types d'opérations gérés différemment. Il est donc possible que le code de packaging/dépaquetage, correspondant à différentes variantes de types entiers/à virgule flottante, ne soit pas entièrement couvert lors des tests. Pour assurer une couverture complète, nous pouvons choisir une opération comme constant qui prend en charge tous les types d'éléments StableHLO et écrire des tests exhaustifs.

(G4) Si l'implémentation d'une opération dépend d'autres opérations, devons-nous écrire des tests pour cette dernière ?

Non. Par exemple, l'implémentation de batch_norm_grad peut être basée sur divide, subtract, multiply et d'autres éléments. Nous devons éviter de tester la deuxième opération tout en testant la première.

(G5) Devons-nous écrire des tests pour vérifier les comportements définis ou non définis par l'implémentation ?

Nous ne devons pas écrire de tests qui présentent les comportements définis ou non définis par l'implémentation de l'opération. Les tests effectués montrant des comportements définis par l'implémentation montrent un comportement local de l'interpréteur qui ne doit pas être généralisé. Les tests avec un comportement non défini ne contribuent pas à la compréhension du comportement de l'opération.

(G6) Lors de l'écriture des tests pour les types à virgule flottante, à quel degré de précision le résultat attendu doit-il être spécifié dans les vérifications ?

Pour les opérations élémentaires (addition, soustraction, multiplication, division et carré), une implémentation conforme à la spécification IEEE doit fournir un résultat arrondi à 0,5 ULP du résultat mathématique exact. Cela dit, nous pouvons sans risque imaginer que le résultat attendu de ces opérations ne soit pas séparé de plus d'un nœud de page de destination. Toutefois, cela peut ne pas fonctionner pour les fonctions transcendantes (sine, cosine, etc.) pour lesquelles les garanties de précision sont définies par l'implémentation (rationale).

L'implémentation actuelle utilise une valeur de tolérance unique de 0,0001. L'exemple suivant illustre la tolérance ci-dessus en action.

func.func @check_tolerance() {
  %0 = stablehlo.constant dense<0.2> : tensor<f32>

  // The following check succeeds as %0 is almost equal to the provided
  // constant modulo the tolerance, mentioned above.
  check.expect_almost_eq_const %0, dense<0.19999> : tensor<f32>

  // The following check fails as %0 is not bitwise equal to the provided
  // constant.
  check.expect_eq_const %0, dense<0.19999> : tensor<f32>

  func.return
}

Il ne s'agit que de la première étape pour tester la précision numérique des opérations StableHLO. Pour le moment, il s'agit d'une partie sous-spécifiée de la spécification StableHLO, et des efforts continus sont en cours pour la résoudre (#1156) sur la base de notre expérience pratique de l'utilisation de StableHLO et des commentaires des partenaires. Nous mettrons à jour l'infrastructure en conséquence au fur et à mesure de l'avancement du processus.

(G7) Avez-vous quelque chose à dire sur le codage des tests ?

  1. Assurez-vous d'utiliser le nom réel des entrées/sorties au lieu d'utiliser par défaut les valeurs SSA (par exemple, %0, %1, etc.).
  2. Assurez-vous que les tests utilisent le format plutôt imprimé, le cas échéant.

(G8) Devons-nous inclure l'exemple déjà fourni dans les spécifications ? Oui (pour des tests complets).