Spécification StableHLO

StableHLO est un ensemble d'opérations pour les opérations de haut niveau (HLO) dans les modèles de ML (ML). StableHLO fonctionne comme une couche de portabilité entre différents Frameworks et compilateurs de ML: frameworks de ML qui produisent des programmes StableHLO sont compatibles avec les compilateurs de ML qui consomment des programmes StableHLO.

Notre objectif est de simplifier et d'accélérer le développement du ML en créant l'interopérabilité entre différents frameworks de ML (tels que TensorFlow, JAX et PyTorch) et les compilateurs de ML (tels que XLA et IREE). Dans cette optique, fournit une spécification pour le langage de programmation StableHLO.

Cette spécification contient trois sections principales. Tout d'abord, le La section Programmes décrit la structure des programmes StableHLO. qui se composent de fonctions StableHLO qui sont elles-mêmes des opérations StableHLO. Au sein de cette structure, la section Ops spécifie la sémantique opérations individuelles. La section Exécution fournit une sémantique pour toutes ces opérations s'exécutent ensemble dans un programme. Enfin, la fonction La section Notation présente la notation utilisée tout au long de la spécifique.

Pour afficher la spécification d'une version précédente de StableHLO, ouvrez le dépôt à l'adresse version taguée qui vous intéresse. (par exemple, la spécification StableHLO v0.19.0). Pour afficher les modifications survenues lors de chaque chargement de version mineure de StableHLO, consultez le journal des versions dans VhloDialect.td.

Programmes

Program ::= {Func}

Les programmes StableHLO consistent en un nombre arbitraire de fonctions StableHLO. Vous trouverez ci-dessous un exemple de programme avec une fonction @main comportant trois entrées (%image, %weights et %bias) et un résultat. Corps de la fonction comporte 6 opérations.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Fonctions

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Les fonctions StableHLO (également appelées fonctions nommées) ont un identifiant, des entrées/sorties et un corps. À l'avenir, nous prévoyons de Intégrer des métadonnées supplémentaires pour les fonctions afin d'améliorer la compatibilité avec HLO (#425, #626, n° 740, n° 744).

Identifiants

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Les identifiants StableHLO sont semblables aux identifiants dans de nombreux programmes avec deux particularités: 1) tous les identifiants ont des sigils distinguer différents types d'identifiants, 2) les identifiants de valeur peuvent être entièrement numérique pour simplifier la génération de programmes StableHLO.

Types

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Les types StableHLO sont classés en types de valeurs (également appelés types de première classe), qui représentent des valeurs StableHLO et des types autres que des valeurs qui décrivent d'autres éléments du programme. Les types StableHLO sont similaires aux types dans de nombreux langages de programmation, la principale particularité étant domaine spécifique qui donne des résultats inhabituels (par exemple, des types scalaires ne sont pas des types de valeur).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

Les types de Tensor représentent des Tensors, c'est-à-dire des tableaux multidimensionnels. Ils ont un forme et un type d'élément, où une forme représente des valeurs non négatives ou tailles de dimensions inconnues dans l'ordre croissant des valeurs dimensions (également appelées axes) numérotées de 0 à R-1. La Le nombre de dimensions R est appelé classement. Par exemple, tensor<2x3xf32> est un type de Tensor avec la forme 2x3 et le type d'élément f32. Il a deux dimensions (ou, en d'autres termes, deux axes) - 0e dimension et 1re dimension - dont les tailles sont 2 et 3. Son classement est de 2.

Les formes peuvent être partiellement ou totalement inconnues (dynamiques). Ex. : tensor<?x2xf64> est en partie inconnue, tandis que tensor<?x?xf64> l'est complètement. Dynamique les dimensions sont représentées à l'aide d'un ?. Les formes ne peuvent pas être désclassées.

À l'avenir, nous prévoyons d'étendre les types de Tensor au-delà les tailles de dimension et les types d'éléments, par exemple, pour inclure des mises en page (#629) et la parcimonie (#1078)

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nom Type Contraintes
storage_type type d'entier (C1-C3), (C8)
storage_min constante d'entier (C1), (C3), (C7)
storage_max constante d'entier (C2), (C3), (C7)
expressed_type type à virgule flottante (C4)
quantization_dimension constante facultative entière (C10-C12)
scales nombre variadique de constantes à virgule flottante (C4-C6), (C9), (C10), (C13)
zero_points nombre variadique de constantes entières (C7-C9)

Les types d'éléments quantifiés représentent des valeurs entières d'un type de stockage dans la plage comprise entre storage_min et storage_max (inclus) correspondant à valeurs à virgule flottante d'un type exprimé. Pour une valeur entière donnée i, la valeur à virgule flottante correspondante f peut être calculée comme suit : f = (i - zero_point) * scale, où scale et zero_point sont appelés paramètres de quantification. Les champs storage_min et storage_max sont facultatifs dans la grammaire, mais leur valeur par défaut est min_value(storage_type) et max_value(storage_type) respectivement. Les types d'éléments quantifiés ont la les contraintes suivantes:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) Si is_empty(quantization_dimension), alors size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Pour le moment, QuantizationScale est une constante à virgule flottante, mais il existe un fort intérêt pour les échelles basées sur des entiers, représentées par des multiplicateurs et changements de direction. Nous prévoyons de l'examiner prochainement. (#1404)

Une discussion est en cours sur la sémantique de QuantizationZeroPoint, y compris le type, les valeurs et s'il ne peut y avoir qu'un seul ou potentiellement plusieurs points zéro dans un type de Tensor quantifié. D'après les résultats de cette discussion, la spécification autour de zéro point peut changer ultérieurement (#1405).

Une autre discussion en cours concerne la sémantique de QuantizationStorageMin. et QuantizationStorageMax pour déterminer si des contraintes doivent être sur ces valeurs et sur les valeurs des Tensors quantifiés (#1406)

Enfin, nous prévoyons de représenter les échelles inconnues et les valeurs zéro de la même manière que nous prévoyons d'explorer la représentation des (#1407)

Les types de Tensors quantifiés représentent des Tensors avec des éléments quantifiés. Ces Les Tensors sont exactement les mêmes que les Tensors standards, si ce n'est que leurs éléments utilisent des types d'éléments quantifiés au lieu de types d'éléments standards.

Dans les tenseurs quantifiés, la quantification peut être par Tensor, ce qui signifie que une valeur scale et une zero_point pour l'intégralité du Tensor ou par axe, c'est-à-dire qu'avoir plusieurs scales et zero_points, une paire par tranche de une dimension particulière quantization_dimension. Plus formellement, dans un Tensor t avec la quantification par axe, il existe dim(t, quantization_dimension) tranches pour quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], etc. Tous les éléments de la ie tranche utilisent scales[i] et zero_points[i] comme leurs paramètres de quantification. Les types de Tensor quantifiés ont les caractéristiques suivantes : contraintes:

  • Pour la quantification par Tensor: <ph type="x-smartling-placeholder">
      </ph>
    • Aucune contrainte supplémentaire.
  • Pour la quantification par axe: <ph type="x-smartling-placeholder">
      </ph>
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Les types de jetons représentent les jetons, c'est-à-dire les valeurs opaques produites et consommées par certaines opérations. Les jetons sont utilisés pour imposer un ordre d'exécution aux opérations comme décrit dans la section Exécution.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Les types de tuples représentent des tuples, c'est-à-dire des listes hétérogènes. Les Tuples sont un héritage fonctionnalité qui n'existe que pour la compatibilité avec HLO. Dans HLO, les tuples sont utilisée pour représenter des entrées et sorties variables. Dans StableHLO, les entrées variables et les sorties sont prises en charge de manière native, et la seule utilisation de tuples dans StableHLO consiste à représenter de manière exhaustive l'ABI HLO T, tuple<T> et tuple<tuple<T>> peut être sensiblement différente en fonction d'une la mise en œuvre. Nous prévoyons d'apporter des modifications à l'ABI HLO à l'avenir. ce qui peut nous permettre de supprimer les types de tuples de StableHLO (#598)

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Les types d'éléments représentent les éléments des types de Tensor. Contrairement à de nombreuses fonctions langues, ces types ne sont pas de première classe dans StableHLO. Cela signifie que Les programmes StableHLO ne peuvent pas représenter directement des valeurs de ces types (par conséquent, représenter des valeurs scalaires de type T avec un Tensor à 0 dimensions est idiomatique. de type tensor<T>).

  • Le type booléen représente les valeurs booléennes true et false.
  • Les types d'entiers peuvent être signés (si) ou non signés (ui), et ont l'une des largeurs de bits acceptées (2, 4, 8, 16, 32 ou 64) ; Les types siN signés représentent des valeurs entières comprises entre -2^(N-1) et 2^(N-1)-1 inclus, et les types uiN non signés représentent des valeurs entières comprises entre 0 et 2^N-1 inclus.
  • Il existe plusieurs types à virgule flottante: <ph type="x-smartling-placeholder">
  • Les types complexes représentent des valeurs complexes ayant une partie réelle. et une partie imaginaire du même type d'élément. Complexe compatible les types sont complex<f32> (les deux parties sont de type f32) et complex<f64> (les deux parties sont de type f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Les types de fonctions représentent les fonctions nommées et anonymes. Ils ont les types d'entrée (liste des types à gauche de ->) et les types de sortie (la liste des types sur la droite de ->). Dans de nombreux langages de programmation, langages, les types de fonctions sont de première classe, mais pas dans StableHLO.

StringType ::= 'string'

Le type de chaîne représente des séquences d'octets. Contrairement à de nombreuses fonctions langues, le type de chaîne n'est pas la première classe dans StableHLO et n'est utilisé spécifier des métadonnées statiques pour les éléments du programme.

Opérations

Les opérations StableHLO (également appelées opérations) représentent un ensemble fermé. d'opérations de haut niveau dans des modèles de machine learning. Comme indiqué ci-dessus, La syntaxe StableHLO s'inspire fortement de MLIR, qui n'est pas nécessairement la méthode la plus alternative ergonomique, mais c'est sans doute la meilleure solution pour permettre à StableHLO de ce qui renforce l'interopérabilité entre les frameworks de ML et les compilateurs de ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Les opérations StableHLO (également appelées opérations) ont un nom, des entrées/sorties et une signature. Ce nom se compose du préfixe stablehlo. et Un mnémonique qui identifie de manière unique l'une des opérations prises en charge. Voir ci-dessous pour une liste complète de toutes les opérations prises en charge.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Les opérations utilisent les entrées et produisent des sorties. Les entrées sont classées en les valeurs d'entrée (calculées lors de l'exécution), les fonctions d'entrée (fournies de manière statique, car dans StableHLO, les fonctions ne sont pas des valeurs de première classe) et d'entrée (également fournis de manière statique). Le type d'entrées et de sorties consommée et produite par une opération dépend de son mode mnémotechnique. Par exemple, add "op" consomme deux valeurs d'entrée et génère une valeur de sortie. En comparaison, les L'opération select_and_scatter consomme 3 valeurs d'entrée, 2 fonctions d'entrée et 3 attributs d'entrée.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Les fonctions d'entrée (également appelées fonctions anonymes) sont très semblables aux fonctions nommées, sauf que: 1) elles n'ont pas d'identifiant (donc le nom "anonyme"), 2) ils ne déclarent pas de types de sortie (les types de sortie sont inférée à partir de l'opération return dans la fonction).

La syntaxe des fonctions d'entrée inclut une partie actuellement inutilisée (voir les Unused ci-dessus), qui assure la compatibilité avec MLIR. Dans MLIR, il existe un concept plus général de « régions » qui peut comporter plusieurs "blocs" reliés entre eux par des jump ops. Ces blocs ont des ID qui correspondent à l'environnement de production Unused, afin de pouvoir les distinguer les uns des autres. StableHLO n'a pas de jump ops. La partie correspondante de la syntaxe MLIR est donc unused (mais est toujours là).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Les attributs d'entrée ont un nom et une valeur faisant partie des constantes. Elles constituent le principal moyen de spécifier des métadonnées statiques pour un programme éléments. Par exemple, l'opération concatenate utilise l'attribut dimension pour spécifier la dimension selon laquelle ses valeurs d'entrée sont concaténées. De même, L'opération slice utilise plusieurs attributs tels que start_indices et limit_indices. pour spécifier les limites utilisées pour segmenter la valeur d'entrée.

Actuellement, les programmes StableHLO contiennent parfois des attributs qui ne sont pas décrites dans ce document. À l'avenir, nous prévoyons de absorber ces attributs dans l'opset StableHLO ou les interdit dans les programmes StableHLO. En attendant, voici la liste de ces Attributs:

  • layout (#629)
  • mhlo.frontend_attributes (#628)
  • mhlo.sharding (#619)
  • output_operand_aliases (#740)
  • Métadonnées de localisation (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

La signature d'opération comprend les types de toutes les valeurs d'entrée (la liste des types sur à gauche de ->) et les types de toutes les valeurs de sortie (liste des à droite de ->). À proprement parler, les types d'entrée sont redondants, et les types de sortie le sont presque toujours également (parce que pour la plupart des opérations StableHLO, les types de sortie peuvent être déduits des entrées). Néanmoins, l'opération la signature fait délibérément partie de la syntaxe StableHLO pour assurer la compatibilité avec MLIR.

Vous trouverez ci-dessous un exemple d'opération dont l'expression mnémotechnique est select_and_scatter. Il en consomme 3 valeurs d'entrée (%operand, %source et %init_value), deux fonctions d'entrée et trois attributs d'entrée (window_dimensions, window_strides et padding). Notez que la signature de l'opération n'inclut que les types de ses valeurs d'entrée. (mais pas les types de fonctions d'entrée et d'attributs qui sont fournis de façon intégrée).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Constantes

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Les constantes StableHLO ont un littéral et un type qui représentent ensemble une valeur StableHLO. En général, le type fait partie de la syntaxe constante, sauf lorsqu'elle ne présente aucune ambiguïté (par exemple, une constante booléenne a sans ambiguïté le type i1, tandis qu'une constante entière peut avoir plusieurs types possibles).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Les constantes booléennes représentent les valeurs booléennes true et false. Booléen sont de type i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Les constantes entières représentent des valeurs entières via des chaînes utilisant des nombres décimaux ou en notation hexadécimale. Autres bases (par exemple, binaire ou octale, ne sont pas prises en charge. Les constantes entières présentent les contraintes suivantes:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Les constantes à virgule flottante représentent des valeurs à virgule flottante via des chaînes qui utiliser la notation décimale ou scientifique. De plus, la notation hexadécimale peut être utilisé pour spécifier directement les bits sous-jacents au format à virgule flottante de le type correspondant. Les constantes à virgule flottante présentent les contraintes suivantes:

  • (C1) Si vous utilisez la notation non hexadécimale, is_wellformed(float_literal, float_type)
  • (C2) Si la notation hexadécimale est utilisée, size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Les constantes complexes représentent des valeurs complexes à l'aide de listes d'une partie réelle (en premier) et une partie imaginaire (en deuxième). Par exemple : (1.0, 0.0) : complex<f32> représente 1.0 + 0.0i, et (0.0, 1.0) : complex<f32> représente 0.0 + 1.0i. L'ordre dans lequel ces sont ensuite stockées en mémoire et définies par l'implémentation. Constantes complexes présentent les contraintes suivantes:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Les constantes Tensor représentent les valeurs des Tensors au moyen de listes imbriquées spécifiées via Numpy. Exemple : dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> représente une valeur de Tensor avec le mappage suivant entre les index et les éléments: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5 {1, 2} => 6 L'ordre dans lequel ces éléments sont ensuite stockés en mémoire est définies par l'implémentation. Les constantes Tensor présentent les contraintes suivantes:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), où: <ph type="x-smartling-placeholder">
      </ph>
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), où: <ph type="x-smartling-placeholder">
      </ph>
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • sinon false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Les constantes de Tensors quantifiées représentent les valeurs de Tensor quantifiées à l'aide des mêmes en tant que constantes de Tensor, les éléments étant spécifiés comme des constantes de leur type de stockage. Les constantes de Tensor quantifiées présentent les contraintes suivantes:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Les littéraux de chaîne sont composés d'octets spécifiés à l'aide de caractères ASCII et et d'échappement. Comme ils sont indépendants de l'encodage, leur interprétation octets est défini par l'implémentation. Les littéraux de chaîne sont de type string.

Opérations

abs

Sémantique

Effectue une opération des abscisses par élément sur le Tensor operand et génère une result. Tensor. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les entiers signés: module entier.
  • Pour les nombres à virgule flottante: abs d'IEEE-754.
  • Pour les nombres complexes: module complexe.
  • Pour les types quantifiés: dequantize_op_quantize(abs, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1-C2)

Sorties

Nom Type Contraintes
result Tensor de type entier signé ou à virgule flottante, ou Tensor quantifié par Tensor (C1-C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • baseline_element_type(operand) dans les autres cas.

Exemples

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

Autres exemples

add

Sémantique

Effectue l'addition par élément de deux Tensors lhs et rhs, et génère une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: opérateur logique OU.
  • Pour les entiers: addition d'entiers.
  • Pour les nombres à virgule flottante: addition d'IEEE-754.
  • Pour les nombres complexes: addition complexe.
  • Pour les types quantifiés: dequantize_op_quantize(add, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié (C1-C6)
(I2) rhs Tensor ou Tensor quantifié (C1-C5), (C7)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1-C7)

Contraintes

  • Si l'opération utilise des Tensors non quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • (C1) type(lhs) = type(rhs) = type(result).
  • Si l'opération utilise des Tensors quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Si is_per_axis_quantized(lhs), alors quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) = quantization_dimension(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Autres exemples

after_all

Sémantique

Assurez-vous que les opérations à l'origine de l'inputs sont exécutées avant toute opérations qui dépendent de result. L'exécution de cette opération n'a aucun effet, il n'existe que pour établir les dépendances de données de result à inputs.

Entrées

Libellé Nom Type
(I1) inputs nombre variadique de token

Sorties

Nom Type
result token

Exemples

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

Autres exemples

all_gather

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, concatène les valeurs des Tensors operands de chaque processus le long de all_gather_dim et produit results.

L'opération divise la grille de processus StableHLO en process_groups, qui est définis comme suit:

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group:

  • operands...@receiver = [operand@sender for sender in process_group] pour tous receiver dans process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) pour tous process dans process_group.

Entrées

Libellé Nom Type Contraintes
(I1) operands nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1), (C6)
(I2) all_gather_dim constante de type si64 (C1), (C6)
(I3) replica_groups Constante de Tensor bidimensionnelle de type si64 (C2-C4)
(I4) channel_id constante de type si64 (C5)
(I5) use_global_device_ids constante de type i1 (C5)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C6)

Contraintes

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Si use_global_device_ids = true, alors channel_id > 0.
  • (C6) type(results...) = type(operands...), sauf: <ph type="x-smartling-placeholder">
      </ph>
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

Autres exemples

all_reduce

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, applique une réduction fonction computation sur les valeurs des Tensors operands de chaque processus et produit des Tensors results.

L'opération divise la grille de processus StableHLO en process_groups, qui est définis comme suit:

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group:

  • results...@process[result_index] = exec(schedule) pour une arborescence binaire schedule où: <ph type="x-smartling-placeholder">
      </ph>
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule est une arborescence binaire définie par l'implémentation, dont l'ordre le balayage est to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

Entrées

Libellé Nom Type Contraintes
(I1) operands nombre variadique de Tensors ou Tensors quantifiés par Tensor (C5), (C6)
(I2) replica_groups nombre variadique de constantes de Tensor unidimensionnelles de type si64 (C1-C3)
(I3) channel_id constante de type si64 (C4)
(I4) use_global_device_ids constante de type i1 (C4)
(I5) computation fonction (C5)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C6-C7)

Contraintes

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Si use_global_device_ids = true, alors channel_id > 0.
  • (C5) computation est de type (tensor<E>, tensor<E>) -> (tensor<E>), où is_promotable(element_type(operand), E)
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

Autres exemples

all_to_all

Sémantique

all_to_all

Dans chaque groupe de processus de la grille de processus StableHLO, divise les valeurs de les Tensors operands le long de split_dimension en plusieurs parties, dispersent la division entre les processus, concatène les parties dispersées concat_dimension et génère des Tensors results. L'opération divise la grille de processus StableHLO en process_groups, qui est définis comme suit:

  • cross_replica(replica_groups) si channel_id <= 0.
  • cross_partition(replica_groups) si channel_id > 0.

Ensuite, dans chaque process_group:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) pour les sender de process_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]receiver_index = process_group.index(receiver)
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

Entrées

Libellé Nom Type Contraintes
(I1) operands nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1-C3), (C9)
(I2) split_dimension constante de type si64 (C1), (C2), (C9)
(I3) concat_dimension constante de type si64 (C3), (C9)
(I4) split_count constante de type si64 (C2), (C4), (C8), (C9)
(I5) replica_groups Constante de Tensor bidimensionnelle de type si64 (C5-C8)
(I6) channel_id constante de type si64

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C9)

Contraintes

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...) sauf, si split_dimension != concat_dimension: <ph type="x-smartling-placeholder">
      </ph>
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

Autres exemples

et

Sémantique

Effectue l'opérateur AND par élément de deux Tensors lhs et rhs et génère une result Tensor. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: l'opérateur logique AND.
  • Pour les entiers: AND (ET) bit à bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type booléen ou entier (C1)
(I2) rhs Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

Autres exemples

atan2

Sémantique

Effectue des opérations atan2 au niveau des éléments sur les Tensors lhs et rhs, et génère une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: atan2 d'IEEE-754.
  • Pour les nombres complexes: complexe atan2.
  • Pour les types quantifiés: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Autres exemples

batch_norm_grad

Sémantique

Calcule les gradients de plusieurs entrées de la rétropropagation de batch_norm_training. à partir de grad_output, et génère grad_operand, grad_scale et grad_offset Tensors. Plus formellement, cette opération peut être exprimée sous la forme d'une décomposition Opérations StableHLO existantes à l'aide de la syntaxe Python, comme suit:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Pour les types quantifiés, exécute dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1-C3), (C5)
(I2) scale Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C2), (C4), (C5)
(I3) mean Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C2), (C4)
(I4) variance Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C2), (C4)
(I5) grad_output Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C2), (C3)
(I6) epsilon constante de type f32
(I7) feature_index constante de type si64 (C1), (C5)

Sorties

Nom Type Contraintes
grad_operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C2), (C3)
grad_scale Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C2), (C4)
grad_offset Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C2), (C4)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand grad_scale et grad_offset ont le même baseline_element_type.
  • (C3) operand, grad_output et grad_operand ont la même forme.
  • (C4) scale, mean, variance, grad_scale et grad_offset ont même forme.
  • (C5) size(scale) = dim(operand, feature_index).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Sémantique

Normalise le Tensor operand pour toutes les dimensions, à l'exception du feature_index et génère un Tensor result. Plus formellement, cela peut être exprimée sous la forme d'une décomposition en opérations StableHLO existantes à l'aide de la syntaxe Python suivante:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Pour les types quantifiés, exécute dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1-C7)
(I2) scale Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C2), (C3)
(I3) offset Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C2), (C4)
(I4) mean Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C5)
(I5) variance Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor (C2), (C6)
(I6) epsilon constante de type f32
(I7) feature_index constante de type si64 (C1), (C3-C6)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C2), (C7)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance et result ont même baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Sémantique

Calcule la moyenne et la variance pour toutes les dimensions, à l'exception de feature_index et normalise le Tensor operand générant output, batch_mean et batch_var. Plus formellement, cette opération peut être exprimée sous la forme décomposition en opérations StableHLO existantes à l'aide de la syntaxe Python en tant que ce qui suit:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Pour les types quantifiés, exécute dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)
(I2) scale Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor (C2), (C3)
(I3) offset Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor (C2), (C4)
(I4) epsilon constante de type f32 (C1), (C3-C6)
(I5) feature_index constante de type si64 (C1), (C3-C6)

Sorties

Nom Type Contraintes
output Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C7)
batch_mean Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor (C2), (C5)
batch_var Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor (C2), (C6)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var et output ont le même baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Sémantique

Effectue une opération de diffusion de bits sur le Tensor operand et génère un Tensor result. où les bits de l'ensemble du Tensor operand sont réinterprétés à l'aide du Type de Tensor result.

Plus formellement, étant donné que E = element_type(operand), E' = element_type(result), et R = rank(operand):

  • Si la valeur est num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
  • Si la valeur est num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
  • Si la valeur est num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])

bits renvoie la représentation en mémoire d'une valeur donnée et de son comportement. est définie par l'implémentation, car la représentation exacte des Tensors est définie par l'implémentation, et la représentation exacte des types d'éléments définies par l'implémentation.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C2)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1-C2)

Contraintes

  • (C1) Avec E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) et R = rank(operand): <ph type="x-smartling-placeholder">
      </ph>
    • Si num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Si la valeur est num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) pour tous les 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Si la valeur est num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) pour tous les 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Si is_complex(operand) or is_complex(result), alors is_complex(operand) and is_complex(result)

Exemples

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Autres exemples

broadcast_in_dim

Sémantique

Développe les dimensions et/ou le rang d'un Tensor d'entrée en dupliquant les données dans le Tensor operand et produit un Tensor result. Plus formellement, result[result_index] = operand[operand_index] où pour les d de axes(operand):

  • operand_index[d] = 0 si dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C2), (C5-C6)
(I2) broadcast_dimensions Constante de Tensor unidimensionnelle de type si64 (C2-C6)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1), (C3), (C5-C6)

Contraintes

  • (C1) La valeur element_type(result) est donnée par: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), à l'exception de quantization_dimension(operand), scales(operand) et zero_points(operand) peuvent être différents de quantization_dimension(result), scales(result) et zero_points(result) resp., sinon.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Pour tous les d de axes(operand): <ph type="x-smartling-placeholder">
      </ph>
    • dim(operand, d) = 1 ou
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Si is_per_axis_quantized(result): <ph type="x-smartling-placeholder">
      </ph>
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Si la valeur est dim(operand, quantization_dimension(operand)) = 1, alors scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))

Exemples

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Autres exemples

coque

Sémantique

Génère le résultat de l'exécution d'une seule fonction de branches en fonction de la valeur de index. Plus formellement, result = selected_branch() où:

  • selected_branch = branches[index] si 0 <= index < size(branches).
  • selected_branch = branches[-1] dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) index Tensor à 0 dimensions de type si32
(I2) branches nombre variadique de fonctions (C1-C4)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors, Tensors quantifiés ou jetons (C4)

Contraintes

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Exemples

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Autres exemples

cbrt

Sémantique

Effectue une opération racine cubique par élément sur le Tensor operand et génère une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: rootn(x, 3) d'IEEE-754.
  • Pour les nombres complexes: racine cubique complexe.
  • Pour les types quantifiés: dequantize_op_quantize(cbrt, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

Autres exemples

ceil

Sémantique

Effectue un ceil par élément du Tensor operand et produit un Tensor result. Elle met en œuvre l'opération roundToIntegralTowardPositive de la norme IEEE-754. spécifique. Pour les types quantifiés, exécute dequantize_op_quantize(ceil, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Autres exemples

Cholesky

Sémantique

Calcule la décomposition par Cholesky d'un lot de matrices.

Plus formellement, pour tous les i de index_space(result), result[i0, ..., iR-3, :, :] est une décomposition cholesky de a[i0, ..., iR-3, :, :], qui se présente sous la forme d'un triangle inférieur (si lower est true) ou une matrice triangulaire supérieure (si lower est false). Les valeurs de sortie dans le triangle opposé, c'est-à-dire le triangle supérieur strict le triangle inférieur strict, sont définies par l'implémentation.

S'il existe des valeurs i où la matrice d'entrée n'est pas de type hermitien , le comportement n'est pas défini.

Pour les types quantifiés, exécute dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1-C3)
(I2) lower Constante de Tensor à 0 dimensions de type i1

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Exemples

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

limiter

Sémantique

Bloque chaque élément du Tensor operand entre une valeur minimale et une valeur maximale et produit un Tensor result. Plus formellement, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), où min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index] Pour les types quantifiés, exécute dequantize_op_quantize(clamp, min, operand, max, type(result)).

Ordonner des nombres complexes implique une sémantique surprenante, C'est pourquoi nous prévoyons de ne plus prendre en charge les nombres complexes à l'avenir. pour cette opération (#560).

Entrées

Libellé Nom Type Contraintes
(I1) min Tensor ou Tensor quantifié par Tensor (C1), (C3)
(I2) operand Tensor ou Tensor quantifié par Tensor (C1-C4)
(I3) max Tensor ou Tensor quantifié par Tensor (C2), (C3)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C4)

Contraintes

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Exemples

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Autres exemples

collective_broadcast

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, envoyez la valeur du Tensor operand du processus source aux processus cibles et produit une Tensor result.

L'opération divise la grille de processus StableHLO en process_groups, qui est définis comme suit:

  • cross_replica(replica_groups) si channel_id <= 0.
  • cross_partition(replica_groups) si channel_id > 0.

Ensuite, result@process est donné par:

  • operand@process_groups[i, 0] s'il existe un i de sorte que le processus soit dans le pays suivant : process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) sinon.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C3)
(I2) replica_groups nombre variadique de constantes de Tensor unidimensionnelles de type si64 (C1), (C2)
(I3) channel_id constante de type si64

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C3)

Contraintes

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, où N est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C3) type(result) = type(operand).

Exemples

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, envoie la valeur du Tensor operand du processus source au processus cible et produit une Tensor result.

L'opération divise la grille de processus StableHLO en process_groups, qui est définis comme suit:

  • cross_replica(source_target_pairs) si channel_id <= 0.
  • cross_partition(source_target_pairs) si channel_id > 0.

Ensuite, result@process est donné par:

  • operand@process_groups[i, 0], s'il existe un i tel que process_groups[i, 1] = process
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) sinon.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C5)
(I2) source_target_pairs Constante de Tensor bidimensionnelle de type si64 (C1-C4)
(I3) channel_id constante de type si64

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, où N est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C5) type(result) = type(operand).

Exemples

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Autres exemples

compare

Sémantique

Effectue une comparaison par élément des Tensors lhs et rhs en fonction comparison_direction et compare_type, et génère un Tensor result.

Les valeurs de comparison_direction et compare_type ont les valeurs suivantes : sémantique:

Pour les éléments de type booléen et entier:

  • EQ : lhs = rhs.
  • NE : lhs != rhs.
  • GE : lhs >= rhs.
  • GT : lhs > rhs.
  • LE : lhs <= rhs.
  • LT : lhs < rhs.

Pour les types d'éléments à virgule flottante avec compare_type = FLOAT, l'opération implémente les opérations IEEE-754 suivantes:

  • EQ : compareQuietEqual.
  • NE : compareQuietNotEqual.
  • GE : compareQuietGreaterEqual.
  • GT : compareQuietGreater.
  • LE : compareQuietLessEqual.
  • LT : compareQuietLess.

Pour les types d'éléments à virgule flottante avec compare_type = TOTALORDER, l'opération utilise la combinaison des opérations totalOrder et compareQuietEqual de IEEE-754.

Pour les types d'éléments complexes, la comparaison lexicographique des paires (real, imag) est effectuées à l'aide des méthodes comparison_direction et compare_type fournies. Ordonner des nombres complexes implique une sémantique surprenante, C'est pourquoi nous prévoyons de ne plus prendre en charge les nombres complexes à l'avenir. lorsque comparison_direction est GE, GT, LE ou LT (560)

Pour les types quantifiés. exécute dequantize_compare(lhs, rhs, comparison_direction).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1-C3)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1-C2)
(I3) comparison_direction énumération de EQ, NE, GE, GT, LE et LT
(I4) compare_type énumération de FLOAT, TOTALORDER, SIGNED et UNSIGNED (C3)

Sorties

Nom Type Contraintes
result Tensor de type booléen (C2)

Contraintes

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • SIGNED si is_signed_integer(element_type(lhs)).
    • UNSIGNED si is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT ou TOTALORDER si is_float(element_type(lhs)).
    • FLOAT si is_complex(element_type(lhs)).

Exemples

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

Autres exemples

complexe

Sémantique

Effectue une conversion au niveau des éléments en une valeur complexe à partir d'une paire de valeurs réelles et des valeurs imaginaires, lhs et rhs, et produit un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type f32 ou f64 (C1-C3)
(I2) rhs Tensor de type f32 ou f64 (C1)

Sorties

Nom Type Contraintes
result Tensor de type complexe (C2), (C3)

Contraintes

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) est de type complex<E>, où E = element_type(lhs)

Exemples

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

Autres exemples

composite

Sémantique

Encapsule une opération composée d'autres opérations StableHLO. en prenant inputs et composite_attributes, et en produisant results. La de l'opération est implémentée par l'attribut decomposition. La L'opération composite peut être remplacée par sa décomposition sans modifier le programme la sémantique. Dans les cas où l'intégration de la décomposition ne fournit pas utilisez plutôt custom_call.

Le champ version (par défaut, 0) est utilisé pour indiquer quand un élément composite est changement de sémantique.

Entrées

Libellé Nom Type
(I1) inputs nombre variadique de valeurs
(I2) name constante de type string
(I3) composite_attributes dictionnaire d'attributs
(I4) decomposition constante de type string
(I5) version constante de type si32

Sorties

Nom Type
results nombre variadique de valeurs

Contraintes

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Exemples

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

Autres exemples

concatenate

Sémantique

Concatène inputs avec la dimension dimension dans le même ordre que celui donné et génère un Tensor result. Plus formellement, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], où:

  1. id = d0 + ... + dk-1 + kd.
  2. d est égal à dimension, et d0... correspond à la de taille de dimension sur inputs.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1-C6)
(I2) dimension constante de type si64 (C2), (C4), (C6)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C5-C6)

Contraintes

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)), sauf dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]), sauf: <ph type="x-smartling-placeholder">
      </ph>
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Exemples

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Autres exemples

constante

Sémantique

Génère un Tensor output à partir d'une constante value.

Entrées

Libellé Nom Type Contraintes
(I1) value constante (C1)

Sorties

Nom Type Contraintes
output Tensor ou Tensor quantifié (C1)

Contraintes

  • (C1) type(value) = type(output).

Exemples

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Autres exemples

d'effectuer une conversion

Sémantique

Effectue une conversion au niveau des éléments d'un type d'élément à un autre sur operand et produit un Tensor result.

Pour les conversions boolean-to-any-supported-type, la valeur false est convertie en zéro, et la valeur true en un. Pour any-supported-type-to-boolean, une valeur nulle est convertie en false et les valeurs non nulles sont converties en true. Découvrez ci-dessous comment cela fonctionnent pour des types complexes.

Pour les conversions impliquant entier en entier, entier en virgule flottante ou floating-point-to-floating-point, si la valeur source peut être exactement représentée dans le type "destination", la valeur du résultat correspond exactement représentation. Sinon, le comportement reste à déterminer (#180)

Pour les conversions impliquant une conversion floating-point-to-integer, la partie fractionnaire est la suivante : tronquées. Si la valeur tronquée ne peut pas être représentée dans le type de destination, comportement à déterminer (#180).

Les conversions complexes à complexes suivent le même comportement : les conversions floating-point-to-floating-point pour convertir des valeurs réelles et des parties imaginaires.

Pour les conversions de type complex-to-any-other-type et complex-to-any-other-type : la valeur imaginaire source est ignorée, ou la valeur imaginaire de destination est mis à zéro, respectivement. La conversion de la partie réelle suit les conversions à virgule flottante.

En principe, cette opération peut exprimer une déquantification (conversion à partir de Tensors quantifiés en Tensors standards), la quantification (conversion des Tensors en Tensors quantifiés) et requantification (conversion entre des Tensors quantifiés) mais pour le moment, nous avons des opérations dédiées. uniform_dequantize pour le premier cas d'utilisation et uniform_quantize pour le deuxième et troisième cas d'utilisation. À l'avenir, ces deux opérations pourront être fusionnées dans convert (#1576).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor (C1)

Contraintes

  • (C1) shape(operand) = shape(result).

Exemples

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Autres exemples

Convolution

Sémantique

Calcule les produits scalaires entre les fenêtres de lhs et des tranches de rhs, puis produit result Le schéma suivant montre comment les éléments de result sont calculés à partir de lhs et rhs à l'aide d'un exemple concret.

Convolution

Plus formellement, envisagez de recadrer les entrées ci-dessous en termes de lhs. pour pouvoir exprimer des fenêtres de lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1)
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0])
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Ce recadrement utilise les fonctions d'assistance suivantes:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1], où j[d] = i[permutation[d]].

Si feature_group_count = 1 et batch_group_count = 1, alors pour toutes output_spatial_index à index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product, où:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]) Il semble que cette fonctionnalité ne soit plus utilisée. Nous prévoyons donc de la supprimer à l'avenir. (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Si la valeur est feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension).

Si la valeur est batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension)

Pour les types quantifiés, exécute dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Pour les types hybrides quantifiés, exécute hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs Tensor ou Tensor quantifié (C1), (C14-C16), (C25), (C27-C29), (C31-C34)
(I3) window_strides Constante de Tensor unidimensionnelle de type si64 (C2-C3), (C25)
(I4) padding Constante de Tensor bidimensionnelle de type si64 (C4), (C25)
(I5) lhs_dilation Constante de Tensor unidimensionnelle de type si64 (C5-C6), (C25)
(I6) rhs_dilation Constante de Tensor unidimensionnelle de type si64 (C7-C8), (C25)
(I7) window_reversal Constante de Tensor unidimensionnelle de type i1 (C9)
(I8) input_batch_dimension constante de type si64 (C10), (C13), (C25)
(I9) input_feature_dimension constante de type si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C12), (C13), (C25)
(I11). kernel_input_feature_dimension constante de type si64 (C14), (C18)
(I12). kernel_output_feature_dimension constante de type si64 (C15-C16), (C18), (C25), (C29)
(I13). kernel_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C17-C18), (C25)
(I14). output_batch_dimension constante de type si64 (C20), (C25)
(I15). output_feature_dimension constante de type si64 (C20), (C25), (C30)
(I16). output_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C19-C20), (C25)
(I17) feature_group_count constante de type si64 (C11), (C14), (C16), (C21), (C23)
(I18). batch_group_count constante de type si64 (C10), (C15), (C22), (C23), (C25)
(I19). precision_config nombre variadique d'énumérations de DEFAULT, HIGH et HIGHEST (C24)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C25-C28), (C30), (C32-34)

Contraintes

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Avec input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Avec kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Donnée output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • dim(lhs, input_batch_dimension) / batch_group_count si result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) si result_dim = output_feature_dimension.
    • num_windows dans les autres cas, où:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Si l'opération utilise des Tensors non quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Si l'opération utilise des Tensors quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Si is_per_axis_quantized(rhs), puis sur quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Si is_per_axis_quantized(result), alors quantization_dimension(result) = output_feature_dimension
    • Si la valeur est is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Si is_per_tensor_quantized(rhs), alors is_per_tensor_quantized(result)
    • Si la valeur est !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Exemples

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

Autres exemples

cosinus

Sémantique

Effectue une opération cosinus par élément sur le Tensor operand et génère une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: cos d'IEEE-754.
  • Pour les nombres complexes: cosinus complexe.
  • Pour les types quantifiés: dequantize_op_quantize(cosine, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Autres exemples

count_leading_zeros

Sémantique

Effectue un comptage par élément du nombre de bits zéro au début dans operand. et produit un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Autres exemples

custom_call

Sémantique

Encapsule une opération call_target_name définie par l'implémentation qui accepte inputs et called_computations, et génère results. has_side_effect, backend_config et api_version peuvent être utilisés pour fournir définies par l'implémentation.

Pour le moment, cette opération contient une collection assez désorganisée de de métadonnées qui reflètent l'évolution organique de son fonctionnement équivalent dans le compilateur XLA. À l'avenir, nous prévoyons d'unifier ces métadonnées (#741)

Entrées

Libellé Nom Type
(I1) inputs nombre variadique de valeurs
(I2) call_target_name constante de type string
(I3) has_side_effect constante de type i1
(I4) backend_config constante de type string ou dictionnaire d'attributs
(I5) api_version constante de type si32
(I6) called_computations nombre variadique de constantes de type string

Sorties

Nom Type
results nombre variadique de valeurs

Exemples

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

diviser

Sémantique

Effectue une division par élément des Tensors lhs et diviseur rhs. et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les entiers: division des entiers, qui produit le quotient algébrique avec n'importe lequel partie fractionnaire supprimée.
  • Pour les nombres à virgule flottante: division d'IEEE-754.
  • Pour les nombres complexes: division complexe.
  • Pour les types quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Autres exemples

dot_general

Sémantique

Calcule les produits scalaires entre les tranches de lhs et les tranches de rhs, et génère une Tensor result.

Plus formellement, result[result_index] = dot_product, où:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_indexsize(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) et size(result_rhs_index) = size(rhs_result_dimensions)
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))

Pour les types quantifiés, exécute dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Pour les types hybrides quantifiés, exécute hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config contrôle le compromis entre vitesse et précision pour sur les backends d'accélérateur. Il peut s'agir de l'un des éléments suivants (au niveau Dans un premier temps, la sémantique de ces valeurs d'énumération est sous-spécifiée, mais nous n'avons prévoient de résoudre ce problème dans #755):

  • DEFAULT: calcul le plus rapide, mais approximation la moins précise du numéro d'origine.
  • HIGH: calcul plus lent, mais approximation plus précise du numéro d'origine.
  • HIGHEST: calcul le plus lent, mais approximation la plus précise du numéro d'origine.

Un DotAlgorithm définit les principales propriétés de l'algorithme utilisé pour implémenter l'opération par point, qui définit également la précision. Si l'attribut de l'algorithme sont définis, l'élément precision_config doit être défini sur DEFAULT. DotAlgorithms n'ont pas de valeur par défaut, car les paramètres par défaut sont définis. Par conséquent, tous les champs de l'algorithme peuvent être définis sur None pour spécifier une algorithme vide avec points, qui utilisera plutôt la valeur precision_config.

Les champs DotAlgorithm incluent:

  • lhs_precision_type et rhs_precision_type, la précision des fonctions LHS et à droite de l'opération sont arrondies. Les types de précision sont indépendants des de stockage des entrées et des sorties.
  • accumulation_type est la précision utilisée pour l'accumulation.
  • lhs_component_count, rhs_component_count et num_primitive_operations s'appliquent lorsque nous faisons un algorithme qui décompose le LHS et/ou le RHS en plusieurs composants et effectue plusieurs tâches sur ces valeurs : généralement pour émuler une précision plus élevée (par exemple, Exploiter le type de données bfloat16 de l'intelligence artificielle pour effectuer des calculs de plus grande précision: bf16_6x, tf32_3x, etc.). Pour les algorithmes sans décomposition, ces valeurs doit être défini sur 1.
  • allow_imprecise_accumulation pour spécifier si l'accumulation dans une précision inférieure est autorisé pour certaines étapes (par exemple, CUBLASLT_MATMUL_DESC_FAST_ACCUM).

Exemples d'attributs DotAlgorithm:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

C'est aux implémentations de décider quelles combinaisons sont acceptées. Dans général, il n'est pas garanti que chaque algorithme soit compatible type d'accélérateur par le consommateur de StableHLO. Si un algorithme donné n'est pas une erreur doit être générée, plutôt que de revenir à une alternative. La vérification StableHLO s'efforcera au mieux, empêchant ainsi les algorithmes qui ne sont pas compatibles avec aucun matériel.

Voir xla_data.proto > Algorithm pour certaines valeurs d'algorithme acceptées. Le ticket n° 2483 capture le plan pour créer un document centralisé sur les algorithmes pris en charge par le backend.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs Tensor ou Tensor quantifié (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions Constante de Tensor unidimensionnelle de type si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions Constante de Tensor unidimensionnelle de type si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config nombre variadique d'énumérations de DEFAULT, HIGH et HIGHEST (C11), (C21)
(I8) lhs_precision_type FloatType ou TensorFloat32 (C21)
(I9) rhs_precision_type FloatType ou TensorFloat32 (C21)
(I10) accumulation_type FloatType ou TensorFloat32 (C21)
(I11). lhs_component_count constante de type si32 (C21), (C22)
(I12). rhs_component_count constante de type si32 (C21), (C23)
(I13). num_primitive_operations constante de type si32 (C21), (C24)
(I14). allow_imprecise_accumulation constante de type bool (C21)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C12), (C14), (C18-C20)

Contraintes

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Si l'opération utilise des Tensors non quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • (C13) element_type(lhs) = element_type(rhs).
  • Si l'opération utilise des Tensors quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) pas dans rhs_contracting_dimensions.
    • Si la valeur est is_quantized(lhs):
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) Si is_per_tensor_quantized(rhs), alors is_per_tensor_quantized(result)
    • Si la valeur est !is_quantized(lhs):
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • Si la valeur est !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation): <ph type="x-smartling-placeholder">
      </ph>
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

Exemples

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Autres exemples

dynamic_broadcast_in_dim

Sémantique

Cette opération est fonctionnellement identique à broadcast_in_dim op, mais la forme du résultat est spécifiée dynamiquement via output_dimensions.

L'opération accepte également les attributs facultatifs known_expanding_dimensions et known_non_expanding_dimensions. pour exprimer des connaissances statiques sur le comportement d'expansion des dimensions. Si aucune valeur n'est spécifiée, toutes les dimensions sont supposées pouvoir s'étendre.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C2), (C5-C6), (C9)
(I2) output_dimensions Tensor unidimensionnel de type entier (C7)
(I3) broadcast_dimensions Tensor constant unidimensionnel de type entier (C2-C6)
(I4) known_expanding_dimensions Tensor constant unidimensionnel de type entier (C8-C9)
(I5) known_non_expanding_dimensions Tensor constant unidimensionnel de type entier (C8-C9)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1), (C3), (C5-C7)

Contraintes

  • (C1) La valeur element_type(result) est donnée par: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), à l'exception de quantization_dimension(operand), scales(operand) et zero_points(operand) peuvent être différents de quantization_dimension(result), scales(result) et zero_points(result) resp., sinon.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Pour tous les d de axes(operand): <ph type="x-smartling-placeholder">
      </ph>
    • dim(operand, d) = 1 ou
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Si is_per_axis_quantized(result): <ph type="x-smartling-placeholder">
      </ph>
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Si la valeur est dim(operand, quantization_dimension(operand)) = 1, alors scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_non_expanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_non_expanding_dimensions < rank(operand).

Exemples

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Autres exemples

dynamic_conv

Sémantique

Cette opération est fonctionnellement identique à Convolution op, mais la marge intérieure est spécifiée dynamiquement via padding.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs Tensor ou Tensor quantifié (C1), (C14-C16), (C26-C28), (C30-C33)
(I3) padding Tensor bidimensionnel de type entier (C4)
(I4) window_strides Constante de Tensor unidimensionnelle de type si64 (C2-C3)
(I5) lhs_dilation Constante de Tensor unidimensionnelle de type si64 (C5-C6)
(I6) rhs_dilation Constante de Tensor unidimensionnelle de type si64 (C7-C8)
(I7) window_reversal Constante de Tensor unidimensionnelle de type i1 (C9)
(I8) input_batch_dimension constante de type si64 (C10), (C13)
(I9) input_feature_dimension constante de type si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C12), (C13)
(I11). kernel_input_feature_dimension constante de type si64 (C14), (C18)
(I12). kernel_output_feature_dimension constante de type si64 (C15-C16), (C18), (C28)
(I13). kernel_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C17-C18)
(I14). output_batch_dimension constante de type si64 (C20)
(I15). output_feature_dimension constante de type si64 (C20), (C29)
(I16). output_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C19-C20)
(I17) feature_group_count constante de type si64 (C11), (C14), (C16), (C21), (C23)
(I18). batch_group_count constante de type si64 (C10), (C15), (C22), (C23)
(I19). precision_config nombre variadique d'énumérations de DEFAULT, HIGH et HIGHEST (C24)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C25-C27), (C29), (C31-C33)

Contraintes

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Avec input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Avec kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Donnée output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]: <ph type="x-smartling-placeholder">
      </ph>
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • dim(lhs, input_batch_dimension) / batch_group_count si result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) si result_dim = output_feature_dimension.
    • num_windows dans les autres cas, où:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Si l'opération utilise des Tensors non quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Si l'opération utilise des Tensors quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Si is_per_axis_quantized(rhs), puis sur quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Si is_per_axis_quantized(result), alors quantization_dimension(result) = output_feature_dimension
    • Si la valeur est is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Si is_per_tensor_quantized(rhs), alors is_per_tensor_quantized(result)
    • Si la valeur est !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Exemples

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

Autres exemples

dynamic_gather

Sémantique

Cette opération est fonctionnellement identique à rassembler op, avec le slice_sizes spécifié de manière dynamique en tant que valeur.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C7), (C10-C12), (C14)
(I2) start_indices Tensor de type entier (C2), (C3), (C13)
(I3) slice_sizes Tensor unidimensionnel de type entier (C8), (C11-C13)
(I4) offset_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C6-C8), (C13)
(I6) start_index_map Constante de Tensor unidimensionnelle de type si64 (C3), (C9), (C10)
(I7) index_vector_dim constante de type si64 (C2), (C3), (C13)
(I8) indices_are_sorted constante de type i1

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C5), (C13-C14)

Contraintes

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) où: <ph type="x-smartling-placeholder">
      </ph>
    • batch_dim_sizes = shape(start_indices), sauf que la dimension de start_indices correspondant à index_vector_dim n'est pas inclus.
    • offset_dim_sizes = shape(slice_sizes), sauf que les dimensions dans slice_sizes correspondant à collapsed_slice_dims ne sont pas incluses.
    • combine place batch_dim_sizes sur les axes correspondant à batch_dims et offset_dim_sizes sur les axes correspondant à offset_dims.
  • (C14) element_type(operand) = element_type(result).

Exemples

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Autres exemples

dynamic_iota

Sémantique

Cette opération est fonctionnellement identique à iota op, mais la forme du résultat est spécifiée dynamiquement via output_shape.

Entrées

Libellé Nom Type Contraintes
(I1) output_shape Tensor unidimensionnel de type entier (C1), (C2)
(I2) iota_dimension si64 (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C2)

Contraintes

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Exemples

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

Autres exemples

dynamic_pad

Sémantique

Cette opération est fonctionnellement identique à pad op, mais avec edge_padding_low, edge_padding_high et interior_padding spécifiées dynamiquement en tant que valeurs.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C2), (C4)
(I2) padding_value Tensor à 0 dimensions ou Tensor quantifié par Tensor (C1)
(I3) edge_padding_low Tensor unidimensionnel de type entier (C1), (C4)
(I4) edge_padding_high Tensor unidimensionnel de type entier (C1), (C4)
(I5) interior_padding Tensor unidimensionnel de type entier (C2-C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C3-C6)

Contraintes

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Exemples

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Autres exemples

dynamic_reshape

Sémantique

Cette opération est fonctionnellement identique à remodeler op, mais la forme du résultat est spécifiée dynamiquement via output_shape.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C3)
(I2) output_shape Tensor unidimensionnel de type entier (C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1-C4)

Contraintes

  • (C1) La valeur element_type(result) est donnée par: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), à l'exception de quantization_dimension(operand) et Sinon, quantization_dimension(result) peut être différent.
  • (C2) size(operand) = size(result).
  • (C3) Si is_per_axis_quantized(operand): <ph type="x-smartling-placeholder">
      </ph>
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

Autres exemples

dynamic_slice

Sémantique

Extrait une tranche de operand à l'aide d'index de départ calculés dynamiquement. et génère un Tensor result. start_indices contiennent les index de départ de la part de chaque dimension susceptible d'être ajustée, et slice_sizes contenir les tailles du secteur pour chaque dimension. Plus formellement, result[result_index] = operand[operand_index] où:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C2), (C4)
(I2) start_indices nombre variadique de Tensors à 0 dimensions de type entier (C2), (C3)
(I3) slice_sizes Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C5)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1), (C5)

Contraintes

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Exemples

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Autres exemples

dynamic_update_slice

Sémantique

Génère un Tensor result égal au Tensor operand, sauf que la tranche commençant à start_indices est mise à jour avec les valeurs de update. Plus formellement, result[result_index] est défini comme suit:

  • update[update_index] si 0 <= update_index < shape(update) où: <ph type="x-smartling-placeholder">
      </ph>
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1-C4), (C6)
(I2) update Tensor ou Tensor quantifié par Tensor (C2), (C3), (C6)
(I3) start_indices nombre variadique de Tensors à 0 dimensions de type entier (C4), (C5)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Exemples

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Autres exemples

exponentiel

Sémantique

Effectue une opération exponentielle au niveau des éléments sur le Tensor operand et génère une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: exp d'IEEE-754.
  • Pour les nombres complexes: exponentiel complexe
  • Pour les types quantifiés: dequantize_op_quantize(exponential, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Autres exemples

exponential_minus_one

Sémantique

Effectue une opération exponentielle par élément moins une sur le Tensor operand et et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: expm1 d'IEEE-754.
  • Pour les nombres complexes: exponentiel complexe moins un.
  • Pour les types quantifiés: dequantize_op_quantize(exponential_minus_one, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

Autres exemples

fft

Sémantique

Effectue les transformations de Fourier avant et inverses pour des valeurs réelles et complexes d'entrées et de sorties.

fft_type est l'un des éléments suivants :

  • FFT: transfert de données FFT complexe à complexe.
  • IFFT: FFT complexe à complexe inverse.
  • RFFT: transfert FFT réel vers complexe.
  • IRFFT: FFT réel-complexe inverse (il s'agit d'une méthode complexe qui renvoie un résultat réel).

Plus formellement, étant donné la fonction fft, qui accepte des Tensors unidimensionnels de les types complexes en entrée, produit des Tensors unidimensionnels des mêmes types que puis calcule la transformation de Fourier discrète:

Pour fft_type = FFT, result est défini comme le résultat final d'une série de L où L = size(fft_length). Par exemple, pour L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

De plus, étant donné la fonction ifft, qui a la même signature de type et calcule l'inverse de fft:

Pour fft_type = IFFT, result est défini comme l'inverse des calculs. pour fft_type = FFT. Par exemple, pour L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

De plus, étant donné la fonction rfft, qui accepte des Tensors unidimensionnels de types à virgule flottante produit des Tensors unidimensionnels de types complexes des la même sémantique à virgule flottante et fonctionne comme suit:

  • rfft(real_operand) = truncated_result
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand)
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Lorsque la transformation de Fourier discrète est calculée pour des opérandes réels, le premier Les éléments N/2 + 1 du résultat définissent sans ambiguïté le reste du résultat, Le résultat de rfft est donc tronqué pour éviter le calcul d'éléments redondants).

Pour fft_type = RFFT, result est défini comme le résultat final d'une série de L où L = size(fft_length). Par exemple, pour L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Enfin, pour la fonction irfft, qui a le même type de signature et calcule l'inverse de rfft:

Pour fft_type = IRFFT, result est défini comme l'inverse des calculs. pour fft_type = RFFT. Par exemple, pour L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type complexe ou à virgule flottante (C1), (C2), (C4), (C5)
(I2) fft_type énumération de FFT, IFFT, RFFT et IRFFT (C2), (C5)
(I3) fft_length Constante de Tensor unidimensionnelle de type si64 (C1), (C3), (C4)

Sorties

Nom Type Contraintes
result Tensor de type complexe ou à virgule flottante (C2), (C4), (C5)

Contraintes

  • (C1) size(fft_length) <= rank(operand).
  • (C2) La relation entre les types d'éléments operand et result varie: <ph type="x-smartling-placeholder">
      </ph>
    • Si fft_type = FFT, element_type(operand) et element_type(result) sont du même type complexe.
    • Si fft_type = IFFT, element_type(operand) et element_type(result) sont du même type complexe.
    • Si la valeur est fft_type = RFFT, element_type(operand) est un type à virgule flottante et element_type(result) est un type complexe de la même valeur à virgule flottante la sémantique.
    • Si la valeur est fft_type = IRFFT, element_type(operand) est un type complexe et element_type(result) est un type à virgule flottante de la même valeur la sémantique.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Si operand et result comportent le Tensor real d'un type à virgule flottante, puis shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand), sauf: <ph type="x-smartling-placeholder">
      </ph>
    • Si la valeur est fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
    • Si la valeur est fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1

Exemples

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

sol

Sémantique

Effectue un prix plancher par élément du Tensor operand et produit un Tensor result. Elle met en œuvre l'opération roundToIntegralTowardNegative de la norme IEEE-754. spécifique. Pour les types quantifiés, exécute dequantize_op_quantize(floor, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Autres exemples

rassembler

Sémantique

Regroupe les tranches du Tensor operand à partir des décalages spécifiés dans start_indices. et génère un Tensor result.

Le schéma suivant montre comment les éléments de result sont mappés sur les éléments de operand à l'aide d'un exemple concret. Le schéma sélectionne quelques exemples de result index et explique en détail les indices operand auxquels ils correspondent.

rassembler

Plus formellement, result[result_index] = operand[operand_index], où:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • start_indices[bi0, ..., :, ..., biN]bi correspond à des éléments individuels dans batch_index et : sont insérés à l'index index_vector_dim, si index_vector_dim < rank(start_indices)
    • [start_indices[batch_index]] dans les autres cas.
  • Pour d_operand dans axes(operand), <ph type="x-smartling-placeholder">
      </ph>
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) si d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 dans les autres cas.
  • Pour d_operand dans axes(operand), <ph type="x-smartling-placeholder">
      </ph>
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] si d_operand = operand_batching_dims[i_batching] et d_start = start_indices_batching_dims[i_batching]
    • full_batching_index[d_operand] = 0 dans les autres cas.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN]oi sont des individus dans offset_index, et 0 est inséré au niveau des index de collapsed_slice_dims et operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index

Si indices_are_sorted est défini sur true, l'implémentation peut supposer que Les éléments start_indices sont triés par rapport à start_index_map. Dans le cas contraire, les valeurs ce comportement n'est pas défini. Plus formellement, pour toutes les i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2)

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices Tensor de type entier (C2-C3), (C14), (C17), (C22)
(I3) offset_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C6-C9), (C22)
(I5) operand_batching_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C6), (C10-C12), (C16-C18), (C22)
(I6) start_indices_batching_dims Constante de Tensor unidimensionnelle de type si64 (C13-C17)
(I7) start_index_map Constante de Tensor unidimensionnelle de type si64 (C3), (C18-C19)
(I8) index_vector_dim constante de type si64 (C2-C3), (C15), (C22)
(I9) slice_sizes Constante de Tensor unidimensionnelle de type si64 (C9), (C12), (C20-C22)
(I10) indices_are_sorted constante de type i1

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C5), (C22-C23)

Contraintes

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) où: <ph type="x-smartling-placeholder">
      </ph>
    • batch_dim_sizes = shape(start_indices), sauf que la dimension de start_indices correspondant à index_vector_dim n'est pas inclus.
    • offset_dim_sizes = slice_sizes, sauf que les dimensions de slice_sizes correspondant à collapsed_slice_dims et Les operand_batching_dims ne sont pas inclus.
    • combine place batch_dim_sizes sur les axes correspondant à batch_dims et offset_dim_sizes sur les axes correspondant à offset_dims.
  • (C23) element_type(operand) = element_type(result).

Exemples

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

Autres exemples

get_dimension_size

Sémantique

Génère la taille de l'élément dimension donné de l'élément operand. Plus formellement, result = dim(operand, dimension) La sémantique ne concerne que la forme composant du type. Le type d'élément peut être n'importe quel élément.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1)
(I2) dimension constante de type si64 (C1)

Sorties

Nom Type
result Tensor à 0 dimensions de type si32

Contraintes

  • (C1) 0 <= dimension < rank(operand).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Autres exemples

get_tuple_element

<ph type="x-smartling-placeholder">
</ph>

Sémantique

Extrait l'élément à la position index du tuple operand et génère une result Plus formellement, result = operand[index].

Entrées

Libellé Nom Type Contraintes
(I1) operand tuple (C1), (C2)
(I2) index constante de type si32 (C1), (C2)

Sorties

Nom Type Contraintes
result tout type compatible (C2)

Contraintes

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Exemples

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

Autres exemples

si

Sémantique

Génère le résultat de l'exécution d'une seule fonction à partir de true_branch ou false_branch en fonction de la valeur de pred. Plus formellement, result = pred ? true_branch() : false_branch().

Entrées

Libellé Nom Type Contraintes
(I1) pred Tensor à 0 dimensions de type i1
(I2) true_branch fonction (C1-C3)
(I3) false_branch fonction (C1), (C2)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors, Tensors quantifiés ou jetons (C3)

Contraintes

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Exemples

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Autres exemples

image

Sémantique

Extrait la partie imaginaire, élément par élément, de operand et génère une Tensor result. Plus formellement, pour chaque élément x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type complexe ou à virgule flottante (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • element_type(operand) dans les autres cas.

Exemples

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Autres exemples

flux d'alimentation

Sémantique

Lit les données du flux d'entrée et génère results.

La sémantique de infeed_config est définie par l'implémentation.

results est constitué de valeurs de charge utile qui apparaissent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux des sorties distinctes pour plus de clarté (#670)

Entrées

Libellé Nom Type
(I1) token token
(I2) infeed_config constante de type string

Sorties

Nom Type Contraintes
results nombre variadique de Tensors, Tensors quantifiés ou jetons (C1-C3)

Contraintes

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Exemples

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Autres exemples

IoTa

Sémantique

Remplit un Tensor output avec des valeurs par ordre croissant à partir de zéro. avec la dimension iota_dimension. Plus formellement,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Entrées

Libellé Nom Type Contraintes
(I1) iota_dimension si64 (C1)

Sorties

Nom Type Contraintes
output Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) 0 <= iota_dimension < rank(output).

Exemples

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Autres exemples

is_finite

Sémantique

Effectue une vérification au niveau des éléments si la valeur de x est finie (c'est-à-dire si ni +Inf, -Inf, or NaN) et génère un Tensor y. Implémentation de isFinite de la spécification IEEE-754. Pour les types quantifiés, le résultat est toujours true.

Entrées

Libellé Nom Type Contraintes
(I1) x Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
y Tensor de type booléen (C1)

Contraintes

  • (C1) shape(x) = shape(y).

Exemples

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

Autres exemples

log

Sémantique

Effectue une opération logarithme par élément sur le Tensor operand et produit une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: log d'IEEE-754.
  • Pour les nombres complexes: logarithme complexe.
  • Pour les types quantifiés: dequantize_op_quantize(log, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Autres exemples

log_plus_one

Sémantique

Effectue un logarithme par élément plus une opération sur le Tensor operand. et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: logp1 d'IEEE-754.
  • Pour les nombres complexes: logarithme complexe plus un.
  • Pour les types quantifiés: dequantize_op_quantize(log_plus_one, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Autres exemples

logistique

Sémantique

Effectue des opérations logistiques par élément sur le Tensor operand et génère une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: division(1, addition(1, exp(-x))) d'IEEE-754.
  • Pour les nombres complexes: logistique complexe.
  • Pour les types quantifiés: dequantize_op_quantize(logistic, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Autres exemples

carte

<ph type="x-smartling-placeholder">
</ph>

Sémantique

Applique une fonction de carte computation à inputs avec dimensions et génère un Tensor result.

Plus formellement, result[result_index] = computation(inputs...[result_index]).

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1-C4)
(I2) dimensions Constante de Tensor unidimensionnelle de type si64 (C3)
(I3) computation fonction (C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1), (C4)

Contraintes

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation est de type (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>Ei = element_type(inputs[i]) et E' = element_type(result).

Exemples

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Autres exemples

maximum

Sémantique

Effectue une opération maximale au niveau des éléments sur les Tensors lhs et rhs et génère une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: opérateur logique OU.
  • Pour les entiers: entier maximum.
  • Pour les nombres à virgule flottante: maximum d'IEEE-754.
  • Pour les nombres complexes: maximum lexicographique pour la paire (real, imaginary). Ordonner des nombres complexes implique une sémantique surprenante, C'est pourquoi nous prévoyons de ne plus prendre en charge les nombres complexes à l'avenir. pour cette opération (#560).
  • Pour les types quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

Autres exemples

minimum

Sémantique

Effectue une opération minimale par élément sur les Tensors lhs et rhs et génère une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: l'opérateur logique AND.
  • Pour les entiers: entier minimum.
  • Pour les nombres à virgule flottante: minimum d'IEEE-754.
  • Pour les nombres complexes: minimum lexicographique pour la paire (real, imaginary). Ordonner des nombres complexes implique une sémantique surprenante, C'est pourquoi nous prévoyons de ne plus prendre en charge les nombres complexes à l'avenir. pour cette opération (#560).
  • Pour les types quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Autres exemples

multiplier

Sémantique

Effectue le produit élément par élément de deux Tensors lhs et rhs, et produit une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: l'opérateur logique AND.
  • Pour les entiers: multiplication d'entiers.
  • Pour les nombres à virgule flottante: multiplication d'IEEE-754.
  • Pour les nombres complexes: multiplication complexe.
  • Pour les types quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

Autres exemples

negate

Sémantique

Effectue une négation par élément du Tensor operand et produit une result Tensor. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les entiers signés: négation des entiers.
  • Pour les entiers non signés: conversion de bits en entier signé, négation d'entier, conversion de bits à un entier non signé.
  • Pour les nombres à virgule flottante: negate d'IEEE-754.
  • Pour les nombres complexes: négation complexe.
  • Pour les types quantifiés: dequantize_op_quantize(negate, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Autres exemples

not

Sémantique

Effectue l'opérateur NOT au niveau des éléments du Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: l'opérateur logique NOT.
  • Pour les entiers: l'opérateur NOT au niveau du bit.

Arguments

Nom Type Contraintes
operand Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

Autres exemples

optimization_barrier

Sémantique

Assurez-vous que les opérations qui produisent l'operand sont exécutées avant toute opérations qui dépendent de result et empêchent les transformations de compilation de déplacer les opérations au-delà de la barrière. En dehors de cela, l'opération est une identité, par exemple result = operand.

Arguments

Nom Type Contraintes
operand nombre variadique de Tensors, Tensors ou jetons quantifiés par Tensor (C1)

Sorties

Nom Type Contraintes
result nombre variadique de Tensors, Tensors ou jetons quantifiés par Tensor (C1)

Contraintes

  • (C1) type(operand...) = type(result...).

Exemples

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Autres exemples

ou

Sémantique

Effectue l'opération OR par élément de deux Tensors lhs et rhs et génère une result Tensor. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: opérateur logique OU.
  • Pour les entiers: OR bit à bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier ou booléen (C1)
(I2) rhs Tensor de type entier ou booléen (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier ou booléen (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

Autres exemples

sortie

Sémantique

Écrit inputs dans le flux de sortie et génère un jeton result.

La sémantique de outfeed_config est définie par l'implémentation.

Entrées

Libellé Nom Type
(I1) inputs nombre variadique de Tensors ou de Tensors quantifiés
(I2) token token
(I3) outfeed_config constante de type string

Sorties

Nom Type
result token

Exemples

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Autres exemples

clavier

Sémantique

Développe operand avec une marge intérieure autour du Tensor et entre les éléments. du Tensor avec l'élément padding_value donné.

edge_padding_low et edge_padding_high spécifient la quantité de marge intérieure ajoutée. à la valeur basse (à côté de l'indice 0) et à la valeur élevée (à côté de l'indice le plus élevé) chaque dimension respectivement. La quantité de marge intérieure peut être négative, la valeur absolue de la marge intérieure négative indique le nombre d'éléments à supprimer de la dimension spécifiée.

interior_padding spécifie la quantité de marge intérieure ajoutée entre deux dans chaque dimension, qui ne peuvent pas être négatifs. Une marge intérieure intérieure avant le remplissage des bords, de sorte que le remplissage négatif du bord supprime les éléments de l'opérande avec remplissage en intérieur.

Plus formellement, result[result_index] est défini comme suit:

  • operand[operand_index] si result_index = edge_padding_low + operand_index * (interior_padding + 1)
  • padding_value dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C2), (C4)
(I2) padding_value Tensor à 0 dimensions ou Tensor quantifié par Tensor (C1)
(I3) edge_padding_low Constante de Tensor unidimensionnelle de type si64 (C1), (C4)
(I4) edge_padding_high Constante de Tensor unidimensionnelle de type si64 (C1), (C4)
(I5) interior_padding Constante de Tensor unidimensionnelle de type si64 (C2-C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C3-C6)

Contraintes

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Exemples

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Autres exemples

partition_id

Sémantique

Génère partition_id du processus actuel.

Sorties

Nom Type
result Tensor à 0 dimensions de type ui32

Exemples

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

Autres exemples

Popcnt

Sémantique

Effectue un décompte par élément du nombre de bits défini dans le Tensor operand. et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Autres exemples

puissance

Sémantique

Effectue une interpolation par élément du Tensor lhs en fonction du Tensor rhs et et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les entiers: exponentielle d'entiers.
  • Pour les nombres à virgule flottante: pow d'IEEE-754.
  • Pour les nombres complexes: exponentielle complexe.
  • Pour les types quantifiés: dequantize_op_quantize(power, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Autres exemples

real

Sémantique

Elle extrait la partie réelle, au niveau des éléments, de operand et génère une result. Tensor. Plus formellement, pour chaque élément x: real(x) = is_complex(x) ? real_part(x) : x

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type complexe ou à virgule flottante (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • element_type(operand) dans les autres cas.

Exemples

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

Autres exemples

réception

Sémantique

Reçoit les données d'un canal avec channel_id et génère results.

Si is_host_transfer est défini sur true, l'opération transfère les données du hôte. Sinon, il transfère les données d'un autre appareil. Cela signifie que définies par l'implémentation. Cet indicateur duplique les informations fournies dans channel_type. Nous prévoyons donc de n'en conserver qu'un seul à l'avenir. (#666)

results est constitué de valeurs de charge utile qui apparaissent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux des sorties distinctes pour plus de clarté (#670)

Entrées

Libellé Nom Type Contraintes
(I1) token token (C4)
(I2) channel_id constante de type si64
(I3) channel_type énumération de DEVICE_TO_DEVICE et HOST_TO_DEVICE (C1)
(I4) is_host_transfer constante de type i1 (C1)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors, Tensors quantifiés ou jetons (C2-C4)

Contraintes

  • (C1) channel_type est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • HOST_TO_DEVICE si is_host_transfer = true,
    • DEVICE_TO_DEVICE dans les autres cas.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Exemples

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Autres exemples

reduce

Sémantique

Applique une fonction de réduction body à inputs et init_values le long de dimensions et produit des Tensors results.

L'ordre des réductions est défini par l'implémentation, ce qui signifie que body et init_values doit former un monoid pour garantir que l'opération produit le les mêmes résultats pour toutes les entrées dans toutes les implémentations. Cependant, cette condition n'est pas valable pour de nombreuses réductions populaires. Exemple : l'addition à virgule flottante pour body et zéro pour init_values ne forment pas un monooïde, car l'addition à virgule flottante n'est pas associative.

Plus formellement, results...[j0, ..., jR-1] = reduce(input_slices_converted), où:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], où : sont insérés à dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) pour une arborescence binaire schedule où: <ph type="x-smartling-placeholder">
      </ph>
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule est une arborescence binaire complète définie par l'implémentation, dont l'ordre le balayage comprend: <ph type="x-smartling-placeholder">
      </ph>
    • input_slices_converted...[index] valeurs, pour tous les index de index_space(input_slices_converted) dans l'ordre lexicographique croissant sur index.
    • Entrepôt de quantité définie par l'implémentation init_values_converted aux positions définies par l'implémentation

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1-C4), (C6), (C7)
(I2) init_values nombre variadique de Tensors à 0 dimensions ou de Tensors quantifiés par Tensor (C2), (C3)
(I3) dimensions Constante de Tensor unidimensionnelle de type si64 (C4), (C5), (C7)
(I4) body fonction (C6)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C3), (C7), (C8)

Contraintes

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)is_promotable(element_type(inputs[i]), Ei)
  • (C7) shape(results...) = shape(inputs...), à la différence près que la dimension Les tailles de inputs... correspondant à dimensions ne sont pas incluses.
  • (C8) element_type(results[i]) = Ei pour tous les i de [0,N).

Exemples

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

Autres exemples

reduce_precision

Sémantique

Effectue une conversion de operand au niveau des éléments en un autre type à virgule flottante qui utilise exponent_bits et mantissa_bits, puis revient à la version d'origine à virgule flottante et génère un Tensor output.

Plus formellement:

  • Les bits de mantisse de la valeur d'origine sont mis à jour pour arrondir la valeur d'origine. à la valeur la plus proche représentable par mantissa_bits en utilisant roundToIntegralTiesToEven.
  • Ensuite, si les valeurs mantissa_bits sont inférieures au nombre de bits de mantisse de la valeur d'origine, les bits de mantisse sont tronqués à mantissa_bits.
  • Ensuite, si les bits d'exposant du résultat intermédiaire ne correspondent pas au plage fournie par exponent_bits, le résultat intermédiaire déborde sur l'infini à l'aide du signe d'origine ou le dépassement de zéro à zéro à l'aide de la panneau d'origine.
  • Pour les types quantifiés, exécute dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)
(I2) exponent_bits constante de type si32 (C2)
(I3) mantissa_bits constante de type si32 (C3)

Sorties

Nom Type Contraintes
output Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Exemples

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Autres exemples

reduce_scatter

Sémantique

reduce_scatter

Dans chaque groupe de processus de la grille StableHLO, effectue une réduction, en utilisant computations sur les valeurs du Tensor operand de chaque processus, divise le résultat de la réduction avec scatter_dimension en plusieurs parties, puis disperse le résultat les parties divisées entre les processus pour produire le result.

L'opération divise la grille de processus StableHLO en process_groups, qui est définis comme suit:

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] pour les sender de process_group, où receiver_index = process_group.index(receiver).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension constante de type si64 (C1), (C2), (C8)
(I3) replica_groups Constante de Tensor bidimensionnelle de type si64 (C3-C5)
(I4) channel_id constante de type si64 (C6)
(I5) use_global_device_ids constante de type i1 (C6)
(I6) computation fonction (C7)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C8-C9)

Contraintes

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Si use_global_device_ids = true, alors channel_id > 0.
  • (C7) computation est de type (tensor<E>, tensor<E>) -> (tensor<E>), où is_promotable(element_type(operand), E)
  • (C8) shape(result) = shape(operand), sauf: <ph type="x-smartling-placeholder">
      </ph>
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Autres exemples

reduce_window

Sémantique

Applique une fonction de réduction body aux fenêtres de inputs et init_values et génère results.

Le schéma suivant montre comment les éléments de results... sont calculés à partir de inputs... à l'aide d'un exemple concret.

reduce_window

Plus formellement, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (voir reduce) :

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values nombre variadique de Tensors à 0 dimensions ou de Tensors quantifiés par Tensor (C1), (C13)
(I3) window_dimensions Constante de Tensor unidimensionnelle de type si64 (C4), (C5), (C15)
(I4) window_strides Constante de Tensor unidimensionnelle de type si64 (C6), (C7), (C15)
(I5) base_dilations Constante de Tensor unidimensionnelle de type si64 (C8), (C9), (C15)
(I6) window_dilations Constante de Tensor unidimensionnelle de type si64 (C10), (C11), (C15)
(I7) padding Constante de Tensor bidimensionnelle de type si64 (C12), (C15)
(I8) body fonction (C13)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1), (C14-C16)

Contraintes

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)is_promotable(element_type(inputs[i]), Ei)
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows où: <ph type="x-smartling-placeholder">
      </ph>
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei pour l'ensemble des i de [0,N).

Exemples

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

Autres exemples

reste

Sémantique

Effectue le reste par élément des Tensors lhs et diviseur rhs et et génère un Tensor result.

Plus formellement, le signe du résultat est tiré du dividende, et le la valeur absolue du résultat est toujours inférieure à la valeur absolue du diviseur. Le reste est calculé comme suit : lhs - d * rhs, où d est calculé comme suit :

  • Pour les entiers: stablehlo.divide(lhs, rhs).
  • Pour les nombres à virgule flottante: division(lhs, rhs) conformément à la norme IEEE-754 avec attribut d'arrondi roundTowardZero
  • Nombres complexes: à déterminer (#997)
  • Pour les types quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Pour les types d'éléments à virgule flottante, cette opération contraste avec la méthode Opération remainder de la spécification IEEE-754 où d est une valeur intégrale le plus proche de la valeur exacte de lhs/rhs, avec des liens au nombre pair.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Autres exemples

replica_id

Sémantique

Génère replica_id du processus actuel.

Sorties

Nom Type
result Tensor à 0 dimensions de type ui32

Exemples

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

Autres exemples

remodeler

Sémantique

Effectue un remodelage du Tensor operand en un Tensor result. Conceptuellement, il revient à conserver la même représentation canonique, mais peut être amenée la forme (p.ex., de tensor<2x3xf32> à tensor<3x2xf32> ou tensor<6xf32>.

Plus formellement, result[result_index] = operand[operand_index]result_index et operand_index ont la même position dans le texte lexicographique l'ordre de index_space(result) et index_space(operand).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C3)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1-C3)

Contraintes

  • (C1) La valeur element_type(result) est donnée par: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), à l'exception de quantization_dimension(operand) et Sinon, quantization_dimension(result) peut être différent.
  • (C2) size(operand) = size(result).
  • (C3) Si is_per_axis_quantized(operand): <ph type="x-smartling-placeholder">
      </ph>
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

Autres exemples

inverser

Sémantique

Inverse l'ordre des éléments dans operand le long de la dimensions spécifiée. et génère un Tensor result. Plus formellement, result[result_index] = operand[operand_index] où:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 si d dans dimensions.
  • operand_index[d] = result_index[d] dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C3)
(I2) dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C3)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1), (C3)

Contraintes

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Exemples

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

Autres exemples

rng

<ph type="x-smartling-placeholder">
</ph>

Sémantique

Génère des nombres aléatoires à l'aide de l'algorithme rng_distribution et produit un Tensor result d'une forme donnée shape.

Si la valeur est rng_distribution = UNIFORM, les nombres aléatoires sont générés. suivant la distribution uniforme sur l'intervalle [a, b). Si la valeur est a >= b, le comportement n'est pas défini.

Si la valeur est rng_distribution = NORMAL, les nombres aléatoires sont générés. suivant la distribution normale avec une moyenne = a et l'écart type = b. Si la valeur est b < 0, le comportement n'est pas défini.

La manière exacte dont les nombres aléatoires sont générés est définie par l'implémentation. Pour Par exemple, ils peuvent être déterministes ou non, et utiliser ou non état caché.

Lors des conversations avec de nombreux intervenants, cette opération s'est révélée aussi efficace Nous prévoyons donc de les supprimer à l'avenir. (#597)

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor à zéro dimension de type entier, booléen ou à virgule flottante (C1), (C2)
(I2) b Tensor à zéro dimension de type entier, booléen ou à virgule flottante (C1), (C2)
(I3) shape Constante de Tensor unidimensionnelle de type si64 (C3)
(I4) rng_distribution énumération de UNIFORM et NORMAL (C2)

Sorties

Nom Type Contraintes
result Tensor de type entier, booléen ou à virgule flottante (C1-C3)

Contraintes

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Si rng_distribution = NORMAL, alors is_float(a).
  • (C3) shape(result) = shape.

Exemples

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Sémantique

Renvoie une output remplie de bits aléatoires uniformes et un état de sortie mis à jour. output_state à l'aide de l'algorithme de génération de nombres pseudo-aléatoires rng_algorithm avec un état initial initial_state. La sortie est garantie fonction déterministe de initial_state, mais sa valeur n'est pas garantie déterministe entre les implémentations.

rng_algorithm est l'un des éléments suivants :

  • DEFAULT: algorithme défini par l'implémentation.
  • THREE_FRY: variante définie par l'implémentation de l'algorithme Threefry*.
  • PHILOX: variante définie par l'implémentation de l'algorithme Philox*.

* Voir Salmon et al. SC 2011. Nombres aléatoires parallèles: aussi simples que 1, 2, 3.

Entrées

Libellé Nom Type Contraintes
(I1) rng_algorithm énumération de DEFAULT, THREE_FRY et PHILOX (C2)
(I2) initial_state Tensor unidimensionnel de type ui64 (C1), (C2)

Sorties

Nom Type Contraintes
output_state Tensor unidimensionnel de type ui64 (C1)
output Tensor de type entier ou à virgule flottante

Contraintes

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • définie par l'implémentation si rng_algorithm = DEFAULT.
    • 2 si rng_algorithm = THREE_FRY.
    • 2 ou 3 si rng_algorithm = PHILOX.

Exemples

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Sémantique

Effectue un arrondi au niveau des éléments vers l'entier le plus proche, ce qui permet de séparer les relations à partir de zéro sur le Tensor operand et produit un Tensor result. Implémentations l'opération roundToIntegralTiesToAway de la spécification IEEE-754. Pour quantifiés, effectue dequantize_op_quantize(round_nearest_afz, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Autres exemples

round_nearest_even

Sémantique

Effectue un arrondi au niveau des éléments vers le nombre entier le plus proche, ce qui permet de casser les liens vers l'entier pair sur le Tensor operand et produit une result Tensor. Elle met en œuvre l'opération roundToIntegralTiesToEven de la norme IEEE-754. spécifique. Pour les types quantifiés, exécute dequantize_op_quantize(round_nearest_even, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Autres exemples

rsqrt

Sémantique

Effectue une opération de racine carrée réciproque par élément sur le Tensor operand et et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: rSqrt d'IEEE-754.
  • Pour les nombres complexes: racine carrée réciproque complexe.
  • Pour les types quantifiés: dequantize_op_quantize(rsqrt, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Autres exemples

disperser

Sémantique

Génère des Tensors results égaux aux Tensors inputs, sauf que plusieurs tranches spécifiées par scatter_indices sont mises à jour avec les valeurs updates avec update_computation.

Le schéma suivant montre comment les éléments de updates... sont mappés sur les éléments de results... à l'aide d'un exemple concret. Le diagramme choisit quelques exemples updates... indexe et explique en détail les indices results... qu'il correspondent.

disperser

Plus formellement, pour tous les update_index de index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • scatter_indices[si0, ..., :, ..., siN]si sont des individus dans update_scatter_index et : est insérée au niveau Index index_vector_dim, si index_vector_dim < rank(scatter_indices)
    • [scatter_indices[update_scatter_index]] dans les autres cas.
  • Pour d_input dans axes(inputs[0]), <ph type="x-smartling-placeholder">
      </ph>
    • full_start_index[d_input] = start_index[d_start] si d_input = scatter_dims_to_operand_dims[d_start]
    • full_start_index[d_input] = 0 dans les autres cas.
  • Pour d_input dans axes(inputs[0]), <ph type="x-smartling-placeholder">
      </ph>
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] si d_input = input_batching_dims[i_batching] et d_start = scatter_indices_batching_dims[i_batching]
    • full_batching_index[d_input] = 0 dans les autres cas.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN]wi sont des individus dans update_window_index, et 0 est inséré au niveau des index de inserted_window_dims et input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

Par conséquent, results = exec(schedule, inputs), où:

  • schedule est une permutation définie par l'implémentation de index_space(updates[0])
  • exec([update_index, ...], results) = exec([...], updated_results) où: <ph type="x-smartling-placeholder">
      </ph>
    • Si result_index est dans les limites de shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results est une copie de results avec results...[result_index] définie sur updated_values....
    • Sinon, procédez comme suit :
    • updated_results = results.
  • exec([], results) = results.

Si indices_are_sorted est défini sur true, l'implémentation peut supposer que Les éléments scatter_indices sont triés par rapport à scatter_dims_to_operand_dims, sinon le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Si unique_indices est défini sur true, l'implémentation peut supposer que tous les Les index result_index dispersés sont uniques. Si unique_indices correspond à true, mais les index dispersés ne sont pas uniques, alors le comportement est non défini.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices Tensor de type entier (C4), (C15), (C19), (C22)
(I3) updates nombre variadique de Tensors ou Tensors quantifiés par Tensor (C3-C6), (C8)
(I4) update_window_dims Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C7-C8)
(I5) inserted_window_dims Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims Constante de Tensor unidimensionnelle de type si64 (C14-C18).
(I8) scatter_dims_to_operand_dims Constante de Tensor unidimensionnelle de type si64 (C19-C21)
(I9) index_vector_dim constante de type si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted constante de type i1
(I11). unique_indices constante de type i1
(I12). update_computation fonction (C23)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C24-C25)

Contraintes

  • (C1) same(shape(inputs...)).
  • (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) <ph type="x-smartling-placeholder">
      </ph>
    • size(input_batching_dims)`.
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) où: <ph type="x-smartling-placeholder">
      </ph>
    • update_scatter_dim_sizes = shape(scatter_indices) sauf que la dimension de scatter_indices correspondant à index_vector_dim n'est pas inclus.
    • update_window_dim_sizes <= shape(inputs[0]) sauf que les dimensions de inputs[0] correspondant à inserted_window_dims et input_batching_dims ne sont pas inclus.
    • combine place update_scatter_dim_sizes sur les axes correspondant à update_scatter_dims et update_window_dim_sizes au niveau des axes correspondant à update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), où is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei pour l'ensemble des i de [0,N).

Exemples

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

Autres exemples

select

Sémantique

Elle génère un Tensor result où chaque élément est sélectionné à partir de on_true ou Tensor on_false basé sur la valeur de l'élément correspondant de pred. Plus formellement, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], où pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Pour les types quantifiés, exécute dequantize_select_quantize(pred, on_true, on_false, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) pred Tensor de type i1 (C1)
(I2) on_true Tensor ou Tensor quantifié par Tensor (C1-C2)
(I3) on_false Tensor ou Tensor quantifié par Tensor (C2)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C2)

Contraintes

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Exemples

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

Autres exemples

select_and_scatter

Sémantique

Distribue les valeurs du Tensor source à l'aide de scatter en fonction du résultat de la fonction reduce_window du Tensor input à l'aide de select et produit un Tensor result.

Le schéma suivant montre comment les éléments de result sont calculés à partir de operand et source à l'aide d'un exemple concret.

select_and_scatter

Plus formellement:

  • selected_values = reduce_window_without_init(...) avec les entrées suivantes:

    • inputs = [operand].
    • window_dimensions, window_strides et padding, qui sont utilisés tels quels.
    • base_dilations = windows_dilations = 1.
    • body est défini comme suit:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    E = element_type(operand) et reduce_window_without_init fonctionnent exactement comme reduce_window, sauf que le schedule de l'instance reduce (voir reduce) n'inclut pas les valeurs init. Il est actuellement non spécifié, que se passe-t-il si la fenêtre correspondante ne comporte pas de valeurs ? (#731)

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) où:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index si selected_values[source_index] comporte l'élément operand à partir de operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1-C4), (C6), (C8-C11)
(I2) source Tensor ou Tensor quantifié par Tensor (C1), (C2)
(I3) init_value Tensor à 0 dimensions ou Tensor quantifié par Tensor (C3)
(I4) window_dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C5)
(I5) window_strides Constante de Tensor unidimensionnelle de type si64 (C2), (C6), (C7)
(I6) padding Constante de Tensor bidimensionnelle de type si64 (C2), (C8)
(I7) select fonction (C9)
(I8) scatter fonction (C10)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C11-C12)

Contraintes

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows où: <ph type="x-smartling-placeholder">
      </ph>
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select est de type (tensor<E>, tensor<E>) -> tensor<i1>, où E = element_type(operand)
  • (C10) scatter est de type (tensor<E>, tensor<E>) -> tensor<E>, où is_promotable(element_type(operand), E)
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Exemples

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Autres exemples

envoyer

Sémantique

Envoie inputs à un canal channel_id et génère un jeton result.

Si is_host_transfer est défini sur true, l'opération transfère les données vers hôte. Sinon, il transfère les données vers un autre appareil. Cela signifie que définies par l'implémentation. Cet indicateur duplique les informations fournies dans channel_type. Nous prévoyons donc de n'en conserver qu'un seul à l'avenir. (#666)

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variadique de Tensors ou de Tensors quantifiés
(I2) token token
(I3) channel_id constante de type si64
(I4) channel_type énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST (C1)
(I5) is_host_transfer constante de type i1 (C1)

Sorties

Nom Type
result token

Contraintes

  • (C1) channel_type est défini comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • DEVICE_TO_HOST si is_host_transfer = true,
    • DEVICE_TO_DEVICE dans les autres cas.

Exemples

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

Autres exemples

shift_left

Sémantique

Effectue un décalage vers la gauche par élément sur le Tensor lhs en fonction du nombre de rhs de bits et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier (C1)
(I2) rhs Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

Autres exemples

shift_right_arithmetic

Sémantique

Effectue un décalage arithmétique vers la droite au niveau des éléments sur le Tensor lhs en rhs de bits et produit un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier (C1)
(I2) rhs Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

Autres exemples

shift_right_logical

Sémantique

Effectue un décalage logique vers la droite au niveau des éléments sur le Tensor lhs par rhs de bits et produit un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier (C1)
(I2) rhs Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Autres exemples

signe "=".

Sémantique

Renvoie le signe de l'élément operand au niveau des éléments et produit un Tensor result. Plus formellement, pour chaque élément x, la sémantique peut être exprimée à l'aide de Syntaxe Python comme suit:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Pour les types quantifiés, exécute dequantize_op_quantize(sign, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Autres exemples

sinus

Sémantique

Effectue une opération sinus par élément sur le Tensor operand et génère une result. Tensor. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: sin d'IEEE-754.
  • Pour les nombres complexes: sinus complexe.
  • Pour les types quantifiés: dequantize_op_quantize(sine, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

Autres exemples

slice

Sémantique

Extrait une tranche de operand à l'aide d'index de départ calculés de manière statique. et génère un Tensor result. start_indices contiennent les index de départ de la tranche de chaque dimension, limit_indices, contient les index de fin ; (exclusif) pour le secteur de chaque dimension et strides contiennent les pas pour chaque dimension.

Plus formellement, result[result_index] = operand[operand_index]operand_index = start_indices + result_index * strides

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1-C3), (C5)
(I2) start_indices Constante de Tensor unidimensionnelle de type si64 (C2), (C3), (C5)
(I3) limit_indices Constante de Tensor unidimensionnelle de type si64 (C2), (C3), (C5)
(I4) strides Constante de Tensor unidimensionnelle de type si64 (C2), (C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1), (C5)

Contraintes

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Exemples

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Autres exemples

trier

Sémantique

Trie les tranches unidimensionnelles de inputs le long de la dimension dimension. selon un comparator et produit results.

Contrairement aux entrées similaires dans d'autres opérations, dimension autorise les valeurs négatives, à l'aide de la sémantique décrite ci-dessous. À l'avenir, cela pourra être refusé pour des raisons de cohérence (#1377)

Si is_stable est "true", le tri est stable, c'est-à-dire l'ordre relatif des éléments considérés comme égaux par le comparateur est conservé. Pour l'étui où il n'y a qu'une seule entrée, deux éléments e1 et e2 sont considérés comme par le comparateur si et seulement si comparator(e1, e2) = comparator(e2, e1) = false Consultez la formalisation ci-dessous. sur la généralisation à plusieurs entrées.

Plus formellement, pour tous les result_index de index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1]riN sont des individus éléments dans result_index, et : est inséré au niveau de adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • sort trie un segment unidimensionnel dans l'ordre non décroissant, que comparator_together renvoie true si l'argument de gauche est inférieur au second argument de droite.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1-C5)
(I2) dimension constante de type si64 (C4)
(I3) is_stable constante de type i1
(I4) comparator fonction (C5)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C2), (C3)

Contraintes

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, où R = rank(inputs[0]).
  • (C5) comparator a un type (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, où Ei = element_type(inputs[i]).

Exemples

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Autres exemples

carré

Sémantique

Effectue une opération de racine carrée par élément sur le Tensor operand et produit une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: squareRoot d'IEEE-754.
  • Pour les nombres complexes: racine carrée complexe.
  • Pour les types quantifiés: dequantize_op_quantize(sqrt, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Autres exemples

subtract

Sémantique

Effectue la soustraction par élément de deux Tensors lhs et rhs, et produit une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les entiers: soustraction d'entiers.
  • Pour les nombres à virgule flottante: subtraction d'IEEE-754.
  • Pour les nombres complexes: soustraction complexe.
  • Pour les types quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Autres exemples

tan

Sémantique

Effectue une opération de tangente par élément sur le Tensor operand et produit une Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: tan d'IEEE-754.
  • Pour les nombres complexes: tangente complexe.
  • Pour les types quantifiés: dequantize_op_quantize(tan, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

Autres exemples

Tanh

Sémantique

Effectue une opération de tangente hyperbolique par élément sur le Tensor operand et et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante: tanh d'IEEE-754.
  • Pour les nombres complexes: tangente hyperbolique complexe.
  • Pour les types quantifiés: <ph type="x-smartling-placeholder">
      </ph>
    • dequantize_op_quantize(tanh, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

Autres exemples

transposer

Sémantique

Permute les dimensions du Tensor operand à l'aide de permutation et génère une Tensor result. Plus formellement, result[result_index] = operand[operand_index]result_index[d] = operand_index[permutation[d]].

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C4)
(I2) permutation Constante de Tensor unidimensionnelle de type si64 (C2-C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1), (C3-C4)

Contraintes

  • (C1) La valeur element_type(result) est donnée par: <ph type="x-smartling-placeholder">
      </ph>
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), à l'exception de quantization_dimension(operand) et Sinon, quantization_dimension(result) peut être différent.
  • (C2) permutation est une permutation de range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Si is_per_axis_quantized(result), alors quantization_dimension(operand) = permutation(quantization_dimension(result))

Exemples

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Autres exemples

triangular_solve

Sémantique

Résoudre des lots de systèmes d'équations linéaires avec des triangles inférieurs ou supérieurs matrices de coefficients.

Plus formellement, étant donné a et b, result[i0, ..., iR-3, :, :] est la solution. vers op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] quand left_side est true ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] lorsque left_side correspond à false, ce qui résout la variable xop(a) est déterminé par transpose_a, qui peut prendre l'une des valeurs suivantes:

  • NO_TRANSPOSE: effectue l'opération à l'aide de a tel quel.
  • TRANSPOSE: effectue une opération sur la transposition de a.
  • ADJOINT: effectue l'opération sur la transposition conjuguée de a.

Les données d'entrée sont lues uniquement à partir du triangle inférieur de a, si lower est true ou triangle supérieur de a, sinon. Les données de sortie sont renvoyées dans le même triangle. les valeurs de l'autre triangle sont définies par l'implémentation.

Si unit_diagonal est "true", l'implémentation peut supposer que la diagonale éléments de a sont égaux à 1. Sinon, le comportement n'est pas défini.

Pour les types quantifiés, exécute dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1-C3)
(I2) b Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1-C4)
(I3) left_side constante de type i1 (C3)
(I4) lower constante de type i1
(I5) unit_diagonal constante de type i1
(I6) transpose_a énumération de NO_TRANSPOSE, TRANSPOSE et ADJOINT

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) La relation entre shape(a) et shape(b) est définie comme suit: <ph type="x-smartling-placeholder">
      </ph>
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Exemples

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

<ph type="x-smartling-placeholder">
</ph>

Sémantique

Génère un tuple result à partir des valeurs val.

Entrées

Libellé Nom Type Contraintes
(I1) val nombre variadique de valeurs (C1)

Sorties

Nom Type Contraintes
result tuple (C1)

Contraintes

  • (C1) result est de type tuple<E0, ..., EN-1>, où Ei = type(val[i]).

Exemples

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Autres exemples

uniform_dequantize

Sémantique

Effectue la conversion élément par élément du Tensor quantifié operand en Tensor à virgule flottante result en fonction des paramètres de quantification définis par le type operand.

Plus formellement, result = dequantize(operand).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor quantifié (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Exemples

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Sémantique

Effectue une conversion élément par élément du Tensor à virgule flottante ou quantifié operand sur un Tensor quantifié result en fonction de la quantification les paramètres définis par le type result.

Plus formellement,

  • Si la valeur est is_float(operand): <ph type="x-smartling-placeholder">
      </ph>
    • result = quantize(operand, type(result)).
  • Si la valeur est is_quantized(operand): <ph type="x-smartling-placeholder">
      </ph>
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou quantifié (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor quantifié (C1), (C2)

Contraintes

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Exemples

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

tandis que

Sémantique

Génère le résultat de l'exécution de la fonction body zéro fois ou plus, tandis que la La fonction cond génère true. Plus formellement, la sémantique peut être exprimée à l'aide de la syntaxe Python suivante:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Le comportement d'une boucle infinie est à déterminer (#383)

Entrées

Libellé Nom Type Contraintes
(I1) operand nombre variadique de Tensors, Tensors quantifiés ou jetons (C1-C3)
(I2) cond fonction (C1)
(I3) body fonction (C2)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors, Tensors quantifiés ou jetons (C3)

Contraintes

  • (C1) cond est de type (T0, ..., TN-1) -> tensor<i1>, où Ti = type(operand[i])
  • (C2) body est de type (T0, ..., TN-1) -> (T0, ..., TN-1), où Ti = type(operand[i])
  • (C3) type(results...) = type(operand...).

Exemples

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Autres exemples

XOR

Sémantique

Effectue une opération XOR par élément de deux Tensors lhs et rhs, et produit une result Tensor. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: opérateur logique "XOR".
  • Pour les entiers: opération XOR (OU exclusif) bit à bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type booléen ou entier (C1)
(I2) rhs Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Autres exemples

Dialect Interop

À l'heure actuelle, les programmes StableHLO contenant parfois des opérations qui ne sont pas définis par StableHLO.

Module, fonction, appel et retour

StableHLO utilise les opérations MLIR en amont pour ModuleOp, FuncOp, CallOp et RenvoiOp. Cela a permis d'améliorer l'interopérabilité avec les machines MLIR existantes, car de nombreux des passes utiles sont écrites ciblant FuncOp et ModuleOp, et de nombreuses compilations les pipelines s'attendent à ce que ces opérations soient présentes. Les garanties de compatibilité totale sont appliquée à ces opérations. En cas de changement dans ces opérations incompatible (par exemple, suppression), des équivalents StableHLO seront ajoutés pour préserver et la compatibilité avec d'autres appareils.

CHLO

L'ensemble d'opérations CHLO contient des opérations de niveau supérieur qui se décomposent en StableHLO. Il n'existe actuellement aucune garantie de compatibilité pour CHLO. Compatibilité les garanties chlo-legalize-to-stablehlo doit être utilisée avant la sérialisation.

Opérations de tracé

Il est courant au sein de la communauté d'avoir recours à certaines opérations Dialectes MLIR utilisés dans les programmes StableHLO dynamiques pour effectuer des calculs de forme. Le plus souvent, il s'agit du dialecte shape. Opérations telles que shape_of ou num_elements, dialecte tensor comme dim ou from_elements, et le type index intégré.

Le document Dynamism RFC > O2 les indique comme étant hors du champ d'application, mais la prise en charge des types index est incluses à des fins d'interopérabilité. Il n'existe aucune garantie de compatibilité opérations ou types. La commande shape-legalize-to-stablehlo peut servir à convertir ces opérations en opérations StableHLO entièrement compatibles.

Opérations obsolètes

Plusieurs opérations StableHLO ont été héritées MHLO qui sont obsolètes et seront bientôt supprimés de StableHLO. Les détails complets de ces sont indiquées dans le fichier StableHLO v1.0 Cleanup #2283. Le problème de l'outil de suivi pour ces abandons est le n° 2340.

Ces opérations appartiennent à plusieurs catégories:

  • "Pas dans HLO" des opérations StableHLO, car celles-ci faisaient initialement partie l'opset StableHLO, mais on a par la suite considéré qu'il ne s'adaptait pas correctement: broadcast, create_token, cross-replica-sum, dot, einsum torch_index_select, unary_einsum (n° 3)
  • Opérations inutilisées : ces opérations ont peut-être été utiles à un moment donné, mais les opérations étaient soit sous-développées, soit les pipelines utilisant ces opérations ont été refactorisées pour ne plus en avoir besoin. Cela inclut map, tuple (#598), comparaisons get_tuple_element, rng, complex #560, et la convolution window_reversal (#1181).

Certaines de ces opérations peuvent être facilement supprimées car elles peuvent être exprimées à l’aide de opérations existantes (broadcast, create_token, cross-replica-sum, dot, unary_einsum) et seront supprimées après le délai de compatibilité existant (6 mois). D'autres sont toujours en cours de suppression (einsum, get_tuple_element, map, torch_index_select rng, tuple et complex comparaisons, window_reversal). En attente des commentaires de la communauté, ces opérations seront soit supprimées, soit ajoutées à la spécification avec une prise en charge totale. Jusqu'au si ces contrats à terme sont connus, ils ne sont garantis que 6 mois de compatibilité.

Exécution

Exécution séquentielle

Un programme StableHLO est exécuté en fournissant des valeurs d'entrée à la fonction main et le calcul des valeurs de sortie. Les valeurs de sortie d'une fonction sont calculées Exécuter le graphe des opérations en mode root dans l'opération return correspondante

L'ordre d'exécution est défini par l'implémentation, tant qu'il est aligné avec Dataflow, c'est-à-dire si les opérations sont exécutées avant leur utilisation. Dans StableHLO, les opérations à effets secondaires consomment un seul jeton et en produisent un (plusieurs jetons être multiplexé en un seul jeton via after_all), de sorte que l'ordre d'exécution du côté est également aligné sur Dataflow. Par exemple, dans le programme ci-dessous, il existe deux ordres d'exécution possibles: %0%1%2return et %1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Plus formellement, un processus StableHLO est une combinaison des éléments suivants: 1) un programme StableHLO, 2) des états des opérations (pas encore exécuté), (déjà exécuté)) et 3) les valeurs intermédiaires sur lesquelles le processus travaille. Le processus commence par les valeurs d'entrée de la fonction main, le graphique des opérations mettant à jour les états et les valeurs intermédiaires, se termine par des valeurs de sortie. Plus de formalisation à déterminer (#484)

Exécution parallèle

Les programmes StableHLO peuvent être exécutés en parallèle et organisés dans une grille de processus 2D. de num_replicas par num_partitions, qui sont tous deux de type ui32.

Dans la grille de processus StableHLO, num_replicas * num_partitions de StableHLO processus s'exécutent en même temps. Chaque processus a une process_id = (replica_id, partition_id), où replica_id dans replica_ids = range(num_replicas) et partition_id dans partition_ids = range(num_partitions), qui ont toutes les deux saisissez ui32.

La taille de la grille de processus est connue de manière statique pour chaque programme (dans le à l'avenir, nous prévoyons d'en faire une partie explicite des programmes StableHLO. #650) et la position au sein de la grille de processus est connue de manière statique pour chaque processus. Chaque processus comporte accès à sa position dans la grille de processus via replica_id et partition_id opérations

Au sein de la grille des processus, les programmes peuvent tous être identiques (dans la section Programme, Données multiples" style), peuvent tous être différents (dans la section "Programmes, Données multiples" ou un autre style d'annonce. À l'avenir, nous prévoyons pour permettre la définition de programmes StableHLO parallèles dans d'autres idiomes, y compris GSPMD (#619).

Dans la grille de processus, les processus sont pour la plupart indépendants les uns des autres : ont des états d'opération distincts et des valeurs d'entrée/intermédiaire/sortie distinctes. La plupart des opérations sont exécutées séparément entre les processus, à l'exception d'un petit nombre d'opérations collectives décrites ci-dessous.

Étant donné que l'exécution de la plupart des opérations n'utilise que des valeurs de la même processus, il est généralement sans ambiguïté de désigner ces valeurs par leur nom. Toutefois, cette définition n'est pas suffisante pour décrire la sémantique des opérations collectives. qui donne lieu à la notation name@process_id pour faire référence à la valeur name au cours d'un processus particulier. De ce point de vue, un name non qualifié peut être est considéré comme un raccourci pour name@(replica_id(), partition_id())).

L'ordre d'exécution entre les processus est défini par l'implémentation, à l'exception de synchronisation introduite par la communication point à point et les opérations collectives comme décrit ci-dessous.

Communication point à point

Les processus StableHLO peuvent communiquer entre eux via Canaux StableHLO. Une chaîne est représentée par un identifiant positif de type. si64 Grâce à diverses opérations, il est possible d'envoyer des valeurs aux canaux les recevoir des chaînes.

Structuration plus poussée, par exemple d'où viennent ces identifiants de critères, les processus et les programmes en prennent conscience et le type de synchronisation introduites par eux, est à déterminer (#484)

Communication en flux continu

Chaque processus StableHLO a accès à deux interfaces de streaming:

  • InFeed, qui peut être lu.
  • Flux de sortie dans lequel des opérations d'écriture peuvent être effectuées.

Contrairement aux canaux, qui sont utilisés pour communiquer entre les processus et donc des processus, les flux d'entrée et de sortie définies par l'implémentation.

Structuration plus poussée, par exemple comment la communication en flux influence l'exécution et le type de synchronisation qu'il introduit, est à déterminer. (#484)

Opérations collectives

StableHLO comporte six opérations collectives: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute et reduce_scatter Toutes ces opérations divisent les processus dans le processus StableHLO en groupes de processus StableHLO et exécuter un calcul commun dans chaque groupe de processus, indépendamment des autres groupes de processus.

Au sein de chaque groupe de processus, des opérations collectives peuvent introduire une synchronisation de sécurité. Structuration plus poussée, par exemple à savoir quand exactement la synchronisation se produit, comment exactement les processus parviennent à cet obstacle, et ce qui se passe si ce n'est pas le cas, à déterminer (#484)

Si le groupe de processus implique une communication entre partitions, c'est-à-dire qu'il existe processus du groupe de processus dont les ID de partition sont différents, puis l'exécution de l'opération collective a besoin d'une chaîne, et l'opération collective doit fournir une channel_id à inclure de type si64. La communication entre instances répliquées n'a pas besoin canaux de distribution.

Les calculs effectués par les opérations collectives sont spécifiques à des opérations individuelles. et sont décrites dans les sections dédiées aux opérations individuelles ci-dessus. Cependant, les stratégies dans laquelle la grille de processus est divisée en groupes de processus. Ces groupes sont partagés entre ces opérations. et sont décrits dans cette section. Plus formellement, StableHLO prend en charge en suivant quatre stratégies.

cross_replica

Seules les communications multi-instances répliquées ont lieu au sein de chaque groupe de processus. Ce utilise replica_groups (une liste de listes d'ID d'instances répliquées) et calcule un produit cartésien de replica_groups par partition_ids. replica_groups doit comporter des éléments uniques et couvrir tous les replica_ids. Plus formellement, en utilisant Syntaxe Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica produira [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]

cross_partition

Seules les communications entre partitions ont lieu au sein de chaque groupe de processus. Ce utilise partition_groups, une liste de listes d'ID de partition, et calcule un produit cartésien de partition_groups par replica_ids. partition_groups doit comporter des éléments uniques et couvrir tous les partition_ids. Plus formellement, en utilisant la syntaxe Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Par exemple, pour partition_groups = [[0, 1]] et num_replicas = 4, cross_partition produira [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]

cross_replica_and_partition

Des communications entre instances répliquées et entre partitions peuvent avoir lieu au sein de chaque de processus. Cette stratégie utilise replica_groups, une liste de listes d'instances répliquées - et calcule les produits cartésiens de chaque replica_group par partition_ids replica_groups doit comporter des éléments uniques et tous les couvrir replica_ids Plus formellement, en utilisant la syntaxe Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica_and_partition produira [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]

flattened_ids

Cette stratégie utilise flattened_id_groups, une liste de listes "aplaties" les identifiants de processus sous la forme replica_id * num_partitions + partition_id ; les transforme en identifiants de processus. flattened_id_groups doit comporter des éléments uniques et couvrent tous les process_ids. Plus formellement, en utilisant la syntaxe Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Par exemple, pour flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 et num_partitions = 2, flattened_ids produira [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]

Précision

Pour le moment, StableHLO n'offre aucune garantie de précision numérique, mais cela pourrait changer à l'avenir (#1156)

Sémantique d'exécution d'une opération quantifiée

L'interprétation des opérations StableHLO quantifiées peut varier en fonction du la configuration requise et les capacités matérielles. Par exemple, certains matériels peuvent choisir interpréter les opérations quantifiées à l'aide d'une fonction "dequantize, effectuer des opérations à virgule flottante et enfin l'opération de quantification ». stratégie. D'autres peuvent effectuer l'intégralité avec des calculs arithmétiques entiers. Par conséquent, l'interprétation de des opérations StableHLO quantifiées est exclusivement déterminé par la clé la mise en œuvre. Interprétation de la quantification hybride (#1575) doit être basée sur sa sémantique, telle qu'indiquée dans la spécification (via 1792).

Erreurs

Les programmes StableHLO sont validés par un vaste ensemble de contraintes pour des opérations individuelles, ce qui exclut de nombreuses classes d'erreurs avant l'exécution. Toutefois, des conditions d'erreur restent possibles, par exemple via des dépassements d'entiers, accès hors limites, etc. Sauf indication contraire explicite, toutes ces erreurs peut entraîner un comportement défini par l'implémentation, mais cela peut changer au niveau (#1157)

Exceptions à virgule flottante

À titre d'exception à cette règle, les exceptions à virgule flottante dans les programmes StableHLO ont un comportement bien défini. Les opérations entraînant des exceptions définies par Norme IEEE-754 (opération non valide, division par zéro, dépassement, dépassement, dépassement ou exceptions inexactes) génèrent des résultats par défaut (tels que définis dans la norme) et poursuivre l'exécution sans générer l'indicateur d'état correspondant ; similaire à Traitement des exceptions raiseNoFlag par rapport à la norme. Exceptions pour les images non standards (par exemple, des fonctions arithmétiques complexes et certaines fonctions transcendantes) définies par l'implémentation.

Incohérences au niveau des formes

StableHLO accepte les Tensors de forme dynamique. Cependant, les formes doivent s'accorder l'environnement d'exécution. Sinon, le comportement n'est pas défini. StableHLO ne fait pas explicitement fournissent une opération qui peut affirmer qu'un Tensor a une forme donnée au moment de l'exécution. La génération du code correct relève de la responsabilité du producteur.

À titre d'exemple, le programme ci-dessous est correct. Cependant, au moment de l'exécution, les formes exactes de %arg0 et %arg1 doivent être identiques. Sinon, la le comportement du programme n'est pas défini:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notation

Pour décrire la syntaxe, ce document utilise le type de produit ISO modifié d'EBNF (ISO/IEC 14977:1996, Wikipédia), avec deux modifications: 1) les règles sont définies à l'aide de ::= au lieu de =,

2) La concaténation est exprimée à l'aide de la juxtaposition plutôt que de ,.

Pour décrire la sémantique (c'est-à-dire dans les sections "Types", "Constantes" et "Opérations") nous utilisons des formules basées sur la syntaxe Python, avec prise en charge pour exprimer de manière concise les opérations de tableau, comme décrit ci-dessous. Cela fonctionne bien pour les petits extraits de code, mais dans de rares cas, lorsque de plus grands extraits de code nous utilisons la syntaxe Python vanilla qui est toujours introduite explicitement.

Formules

Examinons le fonctionnement des formules à partir d'un exemple tiré de dot_general. spécifique. L'une des contraintes de cette opération se présente comme suit: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)

Les noms utilisés dans cette formule proviennent de deux sources: 1) des fonctions globales, Exemple : dim, 2) définitions des membres de l'élément du programme correspondant : Entrées lhs, lhs_batching_dimensions, rhs et rhs_batching_dimensions défini dans la section "Entrées" de dot_general.

Comme indiqué ci-dessus, la syntaxe de cette formule est basée sur Python avec certaines extensions axées sur la concision. Pour donner un sens à la formule, transformons en syntaxe Python vanilla.

A) Dans ces formules, nous utilisons = pour représenter l'égalité. La première étape pour obtenir la syntaxe Python, remplacez = par ==, comme suit: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)

B) De plus, ces formules prennent en charge les points de suspension (...) qui transforment les expressions scalaires en expressions de Tensor. En résumé, f(xs...) signifie plus ou moins "pour chaque x scalaire dans le Tensor xs, calculer une valeur f(x) scalaire, puis renvoyer toutes les valeurs ces résultats scalaires ensemble sous la forme d'un résultat de Tensor". Dans la syntaxe Python vanilla, notre exemple de formule devient: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions]

Grâce aux points de suspension, il est souvent possible d'éviter de travailler au niveau des scalaires individuels. Toutefois, dans certains cas délicats, les données la syntaxe peut être utilisée comme dans la formule start_indices[bi0, ..., :, ..., biN] de la spécification gather. Au service de la concision, nous ne faisons pas fournir un formalisme exact pour traduire cette syntaxe en Python vanilla, dans espère qu'elle sera toujours compréhensible de manière intuitive au cas par cas. Veuillez nous indiquer si certaines formules spécifiques semblent opaques et nous essaierons de les améliorer.

De plus, vous remarquerez que les formules utilisent des points de suspension pour développer toutes sortes de listes, des Tensors, des listes de Tensors (p.ex., qui peuvent provenir d'un variadique (nombre de Tensors), etc. Il s'agit d'un autre domaine dans lequel nous ne fournissons pas le formalisme (par exemple, les listes ne font même pas partie du système de types StableHLO) et s’appuient plutôt sur une compréhension intuitive.

C) La dernière notation remarquable que nous utilisons est implicite la diffusion d'annonces. Bien que l'opset StableHLO ne prenne pas en charge la diffusion implicite, des formules, également au service de la concision. En résumé, si une valeur scalaire est utilisée dans un contexte où un Tensor est attendu, le scalaire est diffusé à la forme attendue.

Pour poursuivre l'exemple dot_general, voici une autre contrainte: 0 <= lhs_batching_dimensions < rank(lhs) Tel que défini dans les dot_general spécification, lhs_batching_dimensions est un Tensor, mais 0 et Les rank(lhs) sont des scalaires. Après avoir appliqué la diffusion implicite, la formule devient [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Lorsqu'elle est appliquée à une opération dot_general spécifique, cette formule correspond à un Tensor de valeurs booléennes. Lorsque des formules sont utilisées comme contraintes, le la contrainte indique si la formule renvoie la valeur true ou un Tensor qui ne contient que des éléments true.

Noms

Dans les formules, la portée lexicale inclut: 1) les fonctions globales, 2) les définitions de membres,

3) définitions locales. Vous trouverez ci-dessous la liste des fonctions globales. La liste de définitions d'éléments dépend de l'élément du programme auquel la notation est appliqué à:

  • Pour les opérations, les définitions des membres incluent les noms introduits dans "Entrées" et "Sorties" .
  • Pour tout le reste, les définitions des membres incluent les parties structurelles du élément du programme, portant le nom des non-terminaux EBNF correspondants. La plupart des le nom de ces parties structurelles est obtenu en convertissant Noms des non-terminaux à utiliser avec snake case (par exemple, IntegerLiteral => integer_literal), mais il arrive que les noms soient abrégés (par exemple, QuantizationStorageType => storage_type), auquel cas les noms sont introduit explicitement de la même manière que "Entrées" / "Sorties" sections en activité caractéristiques techniques.
  • De plus, les définitions des membres incluent toujours self pour faire référence au l'élément de programme correspondant.

Valeurs

Lorsque les formules sont évaluées, elles fonctionnent avec les types de valeurs suivants: 1) Value (valeurs réelles, par exemple dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> ; ils connaissent toujours leurs types), 2) Placeholder (valeurs futures, par exemple lhs, rhs ou result ; leurs valeurs réelles) les valeurs ne sont pas encore connues, seuls leurs types sont connus), 3) Type (types définis dans la section "Types"), 4) Function (fonctions globales telles que définies dans la section "Fonctions").

Selon le contexte, les noms peuvent faire référence à différentes valeurs. Plus en particulier la "sémantique" pour les opérations (et ses équivalents pour les autres programmes ) définit la logique d'exécution, de sorte que toutes les entrées sont disponibles en tant que Value. En revanche, la colonne "Contraintes" des opérations (et leurs équivalents) définit "compile-time" c'est-à-dire quelque chose qui est généralement exécuté avant l'exécution, Ainsi, seules les entrées constantes sont disponibles en tant que Value, et les autres entrées sont disponible uniquement en tant que Placeholder.

Noms Dans "Sémantique" Dans "Contraintes"
Fonctions globales Function Function
Entrées constantes Value Value
Entrées non constantes Value Placeholder
Sorties Value Placeholder
Définitions locales Dépend de la définition Dépend de la définition

Prenons un exemple d'opération transpose:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Pour cette opération, permutation est une constante. Il est donc disponible en tant que Value. tant au niveau de la sémantique que des contraintes. En revanche, operand et result sont disponible en tant que Value en sémantique, mais uniquement en tant que Placeholder dans les contraintes.

Fonctions

Construction de types

Aucune fonction ne peut être utilisée pour créer des types. Au lieu de cela, nous directement utilisez la syntaxe du type, car elle est généralement plus concise. Exemple : (tensor<E>, tensor<E>) -> (tensor<E>) au lieu de function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Fonctions sur les types

  • element_type est défini sur les types de Tensors et les types de Tensors quantifiés. renvoie respectivement TensorElementType ou QuantizedTensorElementType. une partie du TensorType ou du QuantizedTensorType correspondant.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool vérifie si le type x peut être promu pour saisir y. Lorsque x et y sont des QuantizedTensorElementType, la promotion ne s'applique qu'à storage_type. Cette version spécifique de la promotion est actuellement utilisé dans le contexte du calcul de réduction (reportez-vous à RFC pour en savoir plus).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized_tensor_element_type(x)

  • is_type_name(x: Value | Placeholder | Type) -> Value Disponible pour tous de données. Par exemple, is_float(x) renvoie true si x est de type FloatType. Si x est une valeur ou un espace réservé, cette fonction est un raccourci pour is_type_name(type(x))

  • max_value(x: Type) -> Value renvoie la valeur maximale d'un TensorElementType Si x n'est pas de type TensorElementType, la fonction renvoie None.

  • min_value(x: Type) -> Value renvoie la valeur minimale possible d'une TensorElementType Si x n'est pas de type TensorElementType, la fonction renvoie None.

  • member_name(x: Value | Placeholder | Type) -> Any Disponible pour tous les membres définitions member_name de tous types. Exemple : tensor_element_type(x) renvoie la partie TensorElementType d'un TensorType correspondant. Si x est une valeur ou un espace réservé, cette fonction est un raccourci pour member_name(type(x)) Si x n'est pas un type disposant d'un membre approprié, ou une valeur ou un espace réservé de ce type, renvoie None.

  • is_empty_algorithm(*args: Type) vérifie si tous les champs de l'algorithme à points sont définis à None. Cela est nécessaire, car l'implémentation est définie pour les algorithmes . comportements par défaut. Par conséquent, spécifier une valeur par défaut serait incorrect.

Construction des valeurs

  • operation_name(*xs: Value | Type) -> Value Disponible pour toutes les opérations. Par exemple, add(lhs, rhs) prend deux valeurs de Tensor lhs et rhs, et renvoie le résultat de l'évaluation de l'opération add avec ces entrées. Pour certaines opérations, par exemple broadcast_in_dim, les types de leurs sorties sont les suivants : "portant", c'est-à-dire nécessaire pour évaluer une opération. Dans ce cas, la fonction utilise ces types comme arguments.

Fonctions sur les valeurs

  • Tous les opérateurs et fonctions Python sont disponibles. Exemple : les deux abonnement et le segmentage les notations de Python peuvent être indexées sous forme de Tensors quantifiés, et les tuples.

  • to_destination_type(x: Value, destination_type: Type) -> Value est défini le et renvoie la valeur convertie de x en fonction des type(x) et destination_type comme suit:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

La fusion de convert, uniform_quantize et Opérations uniform_dequantize (#1576). Après la fusion, nous n'avons plus besoin de la fonction ci-dessus et pouvons utiliser le nom de l'opération pour convert à la place.

  • is_nan(x: Value) -> Value est défini sur les Tensors et renvoie true si tous les éléments de x sont NaN ou false dans les autres cas. Si x n'est pas un Tensor, renvoie None.

  • is_sorted(x: Value) -> Value est défini sur les Tensors et renvoie true si les éléments de x sont triés par ordre croissant l'ordre lexicographique de leurs index, ou false dans les autres cas. Si x n'est pas un renvoie None.

  • is_unique(x: Value) -> Value est défini sur les Tensors et renvoie true si x. ne contient pas d'éléments en double, ni false dans le cas contraire. Si x n'est pas un Tensor, renvoie None.

  • member_name(x: Value) -> Any est défini pour toutes les définitions de membre. member_name de l'ensemble des valeurs. Par exemple, real_part(x) renvoie RealPart partie d'un ComplexConstant correspondant. Si x n'est pas une valeur ayant un membre approprié, renvoie None.

  • same(x: Value) -> Value est défini sur les Tensors et renvoie true si les éléments de x sont tous égaux les uns par rapport aux autres, ou false dans le cas contraire. Si le Tensor ne comporte aucun élément, c'est-à-dire "tous égaux les uns aux autres". renvoie true. Si x n'est pas un Tensor, la fonction renvoie None.

  • split(x: Value, num_results: Value, axis: Value) -> Value est défini le et renvoie num_results tranches de x le long de l'axe axis. Si x n'est pas un Tensor ou dim(x, axis) % num_results != 0, renvoie None.

  • is_defined_in_parent_scope(x: Value) -> Value est défini sur des chaînes et renvoie true si x est le nom d'une fonction définie dans le même champ d'application. en tant que fonction parente de l'opération concernée.

  • is_namespaced_op_name(x: Value) -> Value est défini sur des chaînes et renvoie true si x est un nom d'opération valide, c'est-à-dire qu'il respecte le code expression: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

Calculs de formes

  • axes(x: Value | Placeholder | Type) -> Value est un raccourci pour range(rank(x))

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value est un raccourci pour shape(x)[axis]

  • dims(x: Value | Placeholder | Type, axes: List) -> List est un raccourci pour list(map(lambda axis: dim(x, axis), axes))

  • index_space(x: Value | Placeholder | Type) -> Value est défini sur les Tensors. et renvoie les index size(x) pour la colonne TensorType correspondante, triée de la façon suivante : ordre lexicographique croissant (par exemple, [0, ..., 0], [0, ..., 1], ..., shape(x) - 1 Si x n'est pas un type de Tensor, un type de Tensor quantifié ou une valeur ou un espace réservé de l'un de ces types, renvoie None.

  • rank(x: Value | Placeholder | Type) -> Value est un raccourci pour size(shape(x))

  • shape(x: Value | Placeholder | Type) -> Value est défini dans la colonne "Functions" sur les types" via member_name.

  • size(x: Value | Placeholder | Type) -> Value est un raccourci pour reduce(lambda x, y: x * y, shape(x))

Calculs de quantification

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type est un raccourci pour element_type(baseline_type(x)).

  • baseline_type est défini sur les types de Tensors et les types de Tensors quantifiés. les transforme en "référence", c'est-à-dire un type ayant la même forme, mais ayant les valeurs par défaut des paramètres de quantification du type d'élément sont rétablis. C'est est une astuce pratique pour comparer les types de Tensors et quantifiés de manière uniforme, ce qui est assez souvent nécessaire. Pour les types quantifiés, cela permet comparer des types en ignorant les paramètres de quantification, c'est-à-dire shape, storage_type, expressed_type, storage_min, storage_max et quantization_dimension (pour le type quantifié par axe) doit tous correspondre, mais Les valeurs entre scales et zero points peuvent être différentes.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize est défini sur des types de Tensor quantifiés et les transforme en types de Tensor à virgule flottante. Cela se produit via la conversion d'éléments quantifiés qui représentent des valeurs entières du type de stockage valeurs à virgule flottante du type exprimé en utilisant le point zéro et l'échelle associé au type d'élément quantifié.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize est défini sur les types de Tensors à virgule flottante et les transforme en types de Tensor quantifiés. Cela se produit via la conversion des valeurs à virgule flottante du type exprimé en valeurs entières correspondantes du type de stockage. en utilisant le point zéro et l'échelle associées au type d'élément quantifié.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize permet de spécifier des calculs au niveau des éléments sur les Tensors quantifiés. Elle déquantifie, c'est-à-dire transforme les éléments quantifiés en des types exprimés, puis effectue une opération, puis effectue une quantification les résultats dans leurs types de stockage. Pour le moment, cette fonction ne fonctionne pour la quantification par Tensor. La quantification par axe est en cours de développement (#1574)
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op permet de spécifier une quantification pondérée uniquement pour Opération hybride qui accepte lhs en virgule flottante et rh en types quantifiés. Il déquantifie les entrées quantifiées en types exprimés et effectue des calculs en float. Type d'élément du Tensor lhs flottant et type exprimé de rh quantifiée doit être identique.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Calculs en grille

  • cross_partition(replica_groups: Value) -> Value Consultez l'instance "cross_replica" ci-dessus.

  • cross_replica(replica_groups: Value) -> Value Consultez l'instance "cross_replica" ci-dessus.

  • cross_replica_and_partition(replica_groups: Value) -> Value Consultez le &quot;cross_replica_and_partition&quot; ci-dessus.

  • flattened_ids(replica_groups: Value) -> Value Voir la colonne "flattened_ids" ci-dessus.

Dynamique

Les valeurs StableHLO peuvent avoir des tailles de dimension dynamiques (par exemple, tensor<?xi64> Toutefois, les valeurs StableHLO ne peuvent pas comporter de nombres dynamiques de dimensions (non classées dynamisme, par exemple tensor<*xi64>). Les opérandes et les résultats sont autorisés à utiliser des des dimensions, même s'il existe des contraintes sur les tailles. Les contraintes seront vérifiées statiquement si possible, sinon ils sont différés à l'exécution et incohérences entraîneront un comportement non défini. Vous trouverez des exemples ci-dessous.

Incohérences de forme pour les opérations unaires par élément

Prenons l'exemple du programme de jouets suivant:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Un tel programme est inhabituel, car il n'est pas courant de connaître la forme du mais pas la forme de l'entrée. Il s'agit toutefois d'un StableHLO valide programme. Il n'est pas possible de valider de manière statique l'opération abs dans cette programme, car la forme exacte de l'opérande est inconnue. Cependant, les formes sont certainement compatibles, et vous pouvez le vérifier de manière statique: ? pourrait s'avérer serait 2 au moment de l'exécution, et il n'y aurait aucun problème. Toutefois, ? pourrait s'avèrent également être un autre nombre entier, auquel cas le comportement n'est pas défini.

Notez que si une taille de dimension est dynamique dans le résultat, il ne peut pas un comportement indéfini. En effet, il n'y a pas de "présence" Il ne peut donc pas y avoir ne correspondent pas.

Incohérences de forme pour les opérations binaires par élément

Prenons l'exemple du programme de jouets suivant:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

Dans le cas d'opérations binaires par élément, la forme des entrées et le le résultat doit concorder au moment de l'exécution. Au moment de la compilation, les dimensions statiques doivent être égales. sinon ils ont juste besoin d'être compatibles. Si n'importe quelle dimension est dynamique dans les entrées, il est possible qu'il n'y ait pas de définition au moment de l'exécution, car il est possible que la taille dynamique ne corresponde pas taille dans l'autre opérande (statique ou dynamique). Si toutes les entrées sont statique, le fait que le résultat soit dynamique ou non n'a pas d'importance: statiquement, les dimensions connues sont vérifiées de manière statique, contrairement aux dimensions dynamiques imposent des contraintes.

Incohérences de formes pour les opérations dont la forme de sortie est un opérande

Prenons l'exemple du programme de jouets suivant:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

Les valeurs de l'opérande de forme au moment de l'exécution doivent correspondre à la forme du résultat. sinon le comportement n'est pas défini. Autrement dit, au moment de l'exécution, %arg0 doit avoir une la valeur de dense<[3, 4]> : tensor<2xi32>. Si l'opérande de forme est constant, cette peut être vérifiée de manière statique. Si la forme du résultat est entièrement dynamique, ne peut pas être une incohérence.