Spécification StableHLO

StableHLO est un ensemble d'opérations pour les opérations de haut niveau (HLO) dans les modèles de machine learning (ML). StableHLO fonctionne comme une couche de portabilité entre différents frameworks de ML et compilateurs de ML : les frameworks de ML qui produisent des programmes StableHLO sont compatibles avec les compilateurs de ML qui consomment des programmes StableHLO.

Notre objectif est de simplifier et d'accélérer le développement du ML en créant une interopérabilité plus poussée entre les différents frameworks de ML (tels que TensorFlow, JAX et PyTorch) et les compilateurs de ML (tels que XLA et IREE). À cette fin, ce document fournit une spécification pour le langage de programmation StableHLO.

Cette spécification comporte trois sections principales. Tout d'abord, la section Programmes décrit la structure des programmes StableHLO, qui consistent en des fonctions StableHLO qui, à leur tour, consistent en des opérations StableHLO. Au sein de cette structure, la section Ops spécifie la sémantique des opérations individuelles. La section Execution fournit une sémantique pour toutes ces opérations exécutées ensemble dans un programme. Enfin, la section Notation décrit la notation utilisée dans l'ensemble de la spécification.

Pour afficher les spécifications d'une version précédente de StableHLO, ouvrez le dépôt à la version taguée de votre choix. Par exemple, la spécification StableHLO v0.19.0. Pour afficher les modifications apportées à chaque version mineure de StableHLO, consultez le journal des versions dans VhloDialect.td.

Programmes

Program ::= {Func}

Les programmes StableHLO se composent d'un nombre arbitraire de fonctions StableHLO. Vous trouverez ci-dessous un exemple de programme avec une fonction @main qui comporte trois entrées (%image, %weights et %bias) et une sortie. Le corps de la fonction comporte 6 opérations.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Fonctions

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Les fonctions StableHLO (également appelées fonctions nommées) possèdent un identifiant, des entrées/sorties et un corps. À l'avenir, nous prévoyons d'introduire des métadonnées supplémentaires pour les fonctions afin d'améliorer la compatibilité avec HLO (#425, #626, #740, #744).

Identifiants

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Les identifiants StableHLO sont semblables aux identifiants de nombreux langages de programmation, avec deux particularités : 1) tous les identifiants comportent des sigils qui distinguent les différents types d'identifiants ; 2) les identifiants de valeur peuvent être entièrement numériques pour simplifier la génération de programmes StableHLO.

Types

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Les types StableHLO sont classés en types de valeurs (également appelés types de première classe) qui représentent les valeurs StableHLO et en types non de valeur qui décrivent d'autres éléments de programme. Les types StableHLO sont semblables à ceux de nombreux langages de programmation, la principale particularité étant la nature spécifique au domaine de StableHLO, qui entraîne des résultats inhabituels (par exemple, les types scalaires ne sont pas des types de valeurs).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

Les types de Tensor représentent des Tensors, c'est-à-dire des tableaux multidimensionnels. Ils ont une forme et un type d'élément, où une forme représente des tailles de dimension non négatives ou inconnues dans l'ordre croissant des dimensions correspondantes (également appelées axes) numérotées de 0 à R-1. Le nombre de dimensions R est appelé classement. Par exemple, tensor<2x3xf32> est un type de tenseur avec la forme 2x3 et le type d'élément f32. Il comporte deux dimensions (ou, en d'autres termes, deux axes) : la dimension 0 et la dimension 1, dont les tailles sont respectivement 2 et 3. Son classement est de 2.

Les formes peuvent être partiellement ou totalement inconnues (dynamiques), par exemple, tensor<?x2xf64> est partiellement inconnu et tensor<?x?xf64> est totalement inconnu. Les tailles de dimension dynamique sont représentées à l'aide d'un ?. Les formes ne peuvent pas être désclassées.

À l'avenir, nous prévoyons d'étendre les types de tenseurs au-delà des tailles de dimension et des types d'éléments, par exemple pour inclure les mises en page (numéro 629) et la sparsité (numéro 1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nom Type Contraintes
storage_type type entier (C1-C3), (C8)
storage_min constante entière (C1), (C3), (C7)
storage_max constante d'entier (C2), (C3), (C7)
expressed_type type à virgule flottante (C4)
quantization_dimension constante facultative (nombre entier) (C10-C12)
scales nombre variadique de constantes à virgule flottante (C4-C6), (C9), (C10), (C13)
zero_points nombre variadique de constantes entières (C7-C9)

Les types d'éléments quantifiés représentent les valeurs entières d'un type de stockage compris entre storage_min et storage_max (inclus) qui correspondent aux valeurs à virgule flottante d'un type exprimé. Pour une valeur entière donnée i, la valeur à virgule flottante correspondante f peut être calculée comme f = (i - zero_point) * scale, où scale et zero_point sont appelés paramètres de quantification. Les éléments storage_min et storage_max sont facultatifs dans la grammaire, mais leurs valeurs par défaut sont respectivement min_value(storage_type) et max_value(storage_type). Les types d'éléments quantifiés présentent les contraintes suivantes:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) Si is_empty(quantization_dimension), alors size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Pour le moment, QuantizationScale est une constante à virgule flottante, mais les échelles basées sur des entiers, représentées par des multiplicateurs et des décalages, suscitent un vif intérêt. Nous prévoyons d'explorer cette question prochainement (#1404).

La sémantique de QuantizationZeroPoint est en cours de discussion, y compris concernant le type et les valeurs, et sur la possibilité de n'avoir qu'un seul ou plusieurs points zéro dans un type de Tensor quantifié. D'après les résultats de cette discussion, les spécifications concernant les points nuls peuvent changer à l'avenir (numéro 1405).

Une autre discussion en cours concerne la sémantique de QuantizationStorageMin et QuantizationStorageMax pour déterminer si des contraintes doivent être imposées à ces valeurs et aux valeurs des tenseurs quantifiés (numéro 1406).

Enfin, nous prévoyons d'explorer la représentation d'échelles et de points zéro inconnus, comme nous prévoyons de le faire pour la représentation de tailles de dimension inconnues (numéro 1407).

Les types de tenseurs quantifiés représentent des tenseurs dont les éléments sont quantifiés. Ces Tensors sont exactement les mêmes que les Tensors standards, sauf que leurs éléments ont des types d'éléments quantifiés, et non des types d'éléments standards.

Dans les tenseurs quantifiés, la quantification peut être par Tensor, ce qui signifie qu'elle peut avoir une valeur scale et zero_point pour l'ensemble du Tensor ou par axe, ce qui signifie qu'elle peut avoir plusieurs scales et zero_points, une paire par tranche d'une dimension quantization_dimension particulière. Plus formellement, dans un tenseur t avec quantification par axe, il existe dim(t, quantization_dimension) tranches de quantization_dimension : t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], etc. Tous les éléments de la ie tranche utilisent scales[i] et zero_points[i] comme paramètres de quantification. Les types de Tensors quantifiés présentent les contraintes suivantes:

  • Pour la quantification par tenseur :
    • Aucune autre contrainte.
  • Pour la quantification par axe :
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Les types de jetons représentent des jetons, c'est-à-dire des valeurs opaques produites et consommées par certaines opérations. Les jetons sont utilisés pour imposer l'ordre d'exécution aux opérations, comme décrit dans la section Exécution.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Les types de tuple représentent des tupels, c'est-à-dire des listes hétérogènes. Les tupels sont une ancienne fonctionnalité qui n'existe que pour la compatibilité avec HLO. Dans HLO, les tupels sont utilisés pour représenter les entrées et les sorties variadiques. Dans StableHLO, les entrées et sorties variadiques sont prises en charge en mode natif, et le seul usage des tupels dans StableHLO est de représenter de manière exhaustive l'ABI HLO, par exemple, T, tuple<T> et tuple<tuple<T>> peuvent être sensiblement différents en fonction d'une implémentation particulière. À l'avenir, nous prévoyons d'apporter des modifications à l'ABI HLO, ce qui nous permettra peut-être de supprimer les types de tuple de StableHLO (numéro 598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Les types d'éléments représentent les éléments des types de Tensor. Contrairement à de nombreux langages de programmation, ces types ne sont pas de première classe dans StableHLO. Cela signifie que les programmes StableHLO ne peuvent pas représenter directement les valeurs de ces types (par conséquent, il est courant de représenter les valeurs scalaires de type T avec des valeurs de tenseur à dimension 0 de type tensor<T>).

  • Le type booléen représente les valeurs booléennes true et false.
  • Les types d'entiers peuvent être signés (si) ou non signés (ui) et avoir l'une des largeurs de bits acceptées (2, 4, 8, 16, 32 ou 64). Les types siN signés représentent les valeurs d'entiers de -2^(N-1) à 2^(N-1)-1 inclus, et les types uiN non signés représentent les valeurs d'entiers de 0 à 2^N-1 inclus.
  • Les types à virgule flottante peuvent être l'un des éléments suivants :
  • Les types complexes représentent des valeurs complexes ayant une partie réelle et une partie imaginaire du même type d'élément. Les types complexes compatibles sont complex<f32> (les deux parties sont de type f32) et complex<f64> (les deux parties sont de type f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Les types de fonctions représentent à la fois les fonctions nommées et anonymes. Ils ont des types d'entrée (la liste des types à gauche de ->) et des types de sortie (la liste des types à droite de ->). Dans de nombreux langages de programmation, les types de fonction sont de première classe, mais pas dans StableHLO.

StringType ::= 'string'

Le type de chaîne représente des séquences d'octets. Contrairement à de nombreux langages de programmation, le type de chaîne n'est pas de première classe dans StableHLO et n'est utilisé que pour spécifier des métadonnées statiques pour les éléments de programme.

Opérations

Les opérations StableHLO (également appelées opérations) représentent un ensemble fermé d'opérations de haut niveau dans les modèles de machine learning. Comme indiqué ci-dessus, la syntaxe StableHLO est fortement inspirée de MLIR, qui n'est pas nécessairement l'alternative la plus ergonomique, mais qui est sans doute la plus adaptée à l'objectif de StableHLO, qui est de créer une interopérabilité accrue entre les frameworks ML et les compilateurs ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Les opérations StableHLO (également appelées opérations) ont un nom, des entrées/sorties et une signature. Le nom se compose du préfixe stablehlo. et d'un mnémonique qui identifie de manière unique l'une des opérations compatibles. Vous trouverez ci-dessous la liste complète de toutes les opérations prises en charge.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Les opérations consomment des entrées et produisent des sorties. Les entrées sont classées en valeurs d'entrée (calculées lors de l'exécution), fonctions d'entrée (fournies de manière statique, car dans StableHLO, les fonctions ne sont pas des valeurs de première classe) et attributs d'entrée (également fournis de manière statique). Le type d'entrées et de sorties consommées et produites par une opération dépend de son mnémonique. Par exemple, l'opération add consomme deux valeurs d'entrée et produit une valeur de sortie. En comparaison, l'opération select_and_scatter consomme 3 valeurs d'entrée, 2 fonctions d'entrée et 3 attributs d'entrée.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Les fonctions d'entrée (également appelées fonctions anonymes) sont très similaires aux fonctions nommées, à l'exception des points suivants : 1) elles n'ont pas d'identifiant (d'où le nom "anonyme"), 2) elles ne déclarent pas de types de sortie (les types de sortie sont inférés à partir de l'opération return dans la fonction).

La syntaxe des fonctions d'entrée inclut une partie actuellement inutilisée (voir la production Unused ci-dessus) qui est là pour la compatibilité avec MLIR. Dans MLIR, il existe un concept plus général de "régions" pouvant comporter plusieurs "blocs" d'opérations reliés entre eux via des jump ops. Ces blocs ont des ID qui correspondent à la production Unused afin de pouvoir les distinguer les uns des autres. StableHLO ne comporte pas d'opérations de saut. Par conséquent, la partie correspondante de la syntaxe MLIR n'est pas utilisée (mais elle est toujours là).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Les attributs d'entrée ont un nom et une valeur qui correspond à l'une des constantes acceptées. Il s'agit du principal moyen de spécifier des métadonnées statiques pour les éléments de programme. Par exemple, l'opération concatenate utilise l'attribut dimension pour spécifier la dimension via laquelle ses valeurs d'entrée sont concaténées. De même, l'opération slice utilise plusieurs attributs tels que start_indices et limit_indices pour spécifier les limites utilisées pour découper la valeur d'entrée.

Pour le moment, les programmes StableHLO dans la nature contiennent parfois des attributs qui ne sont pas décrits dans ce document. À l'avenir, nous prévoyons d'intégrer ces attributs dans l'ensemble d'opérations StableHLO ou de les interdire dans les programmes StableHLO. En attendant, voici la liste de ces attributs:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Métadonnées de lieu (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

La signature d'opération se compose des types de toutes les valeurs d'entrée (la liste des types à gauche de ->) et des types de toutes les valeurs de sortie (la liste des types à droite de ->). À strictement parler, les types d'entrée sont redondants, et les types de sortie sont presque toujours redondants également (car pour la plupart des opérations StableHLO, les types de sortie peuvent être déduits des entrées). Néanmoins, la signature d'opération fait délibérément partie de la syntaxe StableHLO pour assurer la compatibilité avec MLIR.

Vous trouverez ci-dessous un exemple d'opération dont l'expression mnémotechnique est select_and_scatter. Elle consomme trois valeurs d'entrée (%operand, %source et %init_value), deux fonctions d'entrée et trois attributs d'entrée (window_dimensions, window_strides et padding). Notez que la signature de l'opération n'inclut que les types de ses valeurs d'entrée (mais pas les types de fonctions et d'attributs d'entrée qui sont fournis en ligne).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Constantes

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Les constantes StableHLO comportent un littéral et un type qui, ensemble, représentent une valeur StableHLO. En général, le type fait partie de la syntaxe de la constante, sauf lorsqu'il est sans ambiguïté (par exemple, une constante booléenne a un type i1 sans ambiguïté, tandis qu'une constante entière peut avoir plusieurs types possibles).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Les constantes booléennes représentent les valeurs booléennes true et false. Les constantes booléennes sont de type i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Les constantes entières représentent des valeurs entières via des chaînes qui utilisent la notation décimale ou hexadécimale. Les autres bases, telles que les bases binaires ou octales, ne sont pas acceptées. Les constantes entières présentent les contraintes suivantes:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Les constantes à virgule flottante représentent des valeurs à virgule flottante via des chaînes qui utilisent une notation décimale ou scientifique. De plus, la notation hexadécimale peut être utilisée pour spécifier directement les bits sous-jacents au format à virgule flottante du type correspondant. Les constantes à virgule flottante présentent les contraintes suivantes:

  • (C1) Si une notation non hexadécimale est utilisée, is_wellformed(float_literal, float_type).
  • (C2) Si vous utilisez la notation hexadécimale, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Les constantes complexes représentent des valeurs complexes à l'aide de listes d'une partie réelle (qui vient en premier) et d'une partie imaginaire (qui vient en deuxième). Par exemple, (1.0, 0.0) : complex<f32> représente 1.0 + 0.0i et (0.0, 1.0) : complex<f32> représente 0.0 + 1.0i. L&#39;ordre dans lequel ces parties sont ensuite stockées en mémoire est défini par l&#39;implémentation. Les constantes complexes présentent les contraintes suivantes:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Les constantes de tenseur représentent les valeurs de tenseur à l'aide de listes imbriquées spécifiées via la notation NumPy. Par exemple, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> représente une valeur de Tensor avec le mappage suivant entre les index et les éléments : {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. L'ordre dans lequel ces éléments sont ensuite stockés en mémoire est défini par l'implémentation. Les constantes de tenseur sont soumises aux contraintes suivantes:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), où :
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), où :
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • Sinon, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Les constantes de tenseur quantifié représentent les valeurs de tenseur quantifié à l'aide de la même notation que les constantes de tenseur, avec des éléments spécifiés en tant que constantes de leur type de stockage. Les constantes de tenseur quantifiées présentent les contraintes suivantes :

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Les chaînes littérales sont constituées d'octets spécifiés à l'aide de caractères ASCII et de séquences d'échappement. Ils sont indépendants de l'encodage, de sorte que l'interprétation de ces octets est définie par mise en œuvre. Les littéraux de chaîne sont de type string.

Opérations

abs

Sémantique

Effectue une opération abs élément par élément sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les entiers signés : module d'entier.
  • Pour les nombres à virgule flottante : abs d'IEEE-754.
  • Pour les nombres complexes: module complexe.
  • Pour les types quantifiés: dequantize_op_quantize(abs, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1-C2)

Sorties

Nom Type Contraintes
result Tenseur de type entier signé ou à virgule flottante, ou tenseur quantifié par tenseur (C1-C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • baseline_element_type(operand) dans les autres cas.

Exemples

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

 Autres exemples

add

Sémantique

Effectue l'addition par élément de deux Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les valeurs booléennes : opérateur logique "OR".
  • Pour les entiers: addition d'entiers.
  • Pour les nombres à virgule flottante : addition d'IEEE-754.
  • Pour les nombres complexes: addition complexe.
  • Pour les types quantifiés : dequantize_op_quantize(add, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur ou un tenseur quantifié ; (C1-C6)
(I2) rhs Tensor ou Tensor quantifié (C1-C5), (C7)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié ; (C1-C7)

Contraintes

  • Si l'opération utilise des tenseurs non linéarisables :
    • (C1) type(lhs) = type(rhs) = type(result).
  • Si l'opération utilise des tenseurs quantifiés :
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Si is_per_axis_quantized(lhs), alors quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) = quantization_dimension(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Autres exemples

after_all

Sémantique

Assure que les opérations produisant le inputs sont exécutées avant toutes les opérations qui dépendent de result. L'exécution de cette opération n'a aucun effet. Elle n'existe que pour établir des dépendances de données entre result et inputs.

Entrées

Libellé Nom Type
(I1) inputs Nombre variable de token

Sorties

Nom Type
result token

Exemples

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

 Autres exemples

all_gather

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, concatène les valeurs des tenseurs operands de chaque processus le long de all_gather_dim et produit des tenseurs results.

L'opération divise la grille de processus StableHLO en process_groups, qui est définie comme suit :

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group :

  • operands...@receiver = [operand@sender for sender in process_group] pour tous les receiver dans process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) pour l'ensemble des process de process_group.

Entrées

Libellé Nom Type Contraintes
(I1) operands nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C1), (C6)
(I2) all_gather_dim constante de type si64 (C1), (C6)
(I3) replica_groups Constante de tenseur bidimensionnel de type si64 (C2-C4)
(I4) channel_id constante de type si64 (C5)
(I5) use_global_device_ids constante de type i1 (C5)

Sorties

Nom Type Contraintes
results nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C6)

Contraintes

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Si use_global_device_ids = true, alors channel_id > 0.
  • (C6) type(results...) = type(operands...), sauf :
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

 Autres exemples

all_reduce

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, applique une fonction de réduction computation aux valeurs des tenseurs operands de chaque processus et produit des tenseurs results.

L'opération divise la grille de processus StableHLO en process_groups, qui est définie comme suit :

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group :

  • results...@process[result_index] = exec(schedule) pour un arbre binaire schedule où :
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule est une arborescence binaire définie par l'implémentation dont le balayage dans l'ordre est to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

Entrées

Libellé Nom Type Contraintes
(I1) operands nombre variadique de Tensors ou Tensors quantifiés par Tensor (C5), (C6)
(I2) replica_groups nombre variadique de constantes de Tensor unidimensionnelles de type si64 (C1-C3)
(I3) channel_id constante de type si64 (C4)
(I4) use_global_device_ids constante de type i1 (C4)
(I5) computation fonction (C5)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C6-C7)

Contraintes

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Si use_global_device_ids = true, alors channel_id > 0.
  • (C5) computation est de type (tensor<E>, tensor<E>) -> (tensor<E>), où is_promotable(element_type(operand), E).
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

 Autres exemples

all_to_all

Sémantique

all_to_all

Dans chaque groupe de processus de la grille de processus StableHLO, il divise les valeurs des Tensors operands le long de split_dimension en plusieurs parties, répartit les parties divisées entre les processus, concatène les parties dispersées le long de concat_dimension et produit des Tensors results. L'opération divise la grille de processus StableHLO en process_groups, qui est définie comme suit :

  • cross_replica(replica_groups) si channel_id <= 0.
  • cross_partition(replica_groups) si channel_id > 0.

Ensuite, dans chaque process_group:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) pour tous les sender dans process_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]receiver_index = process_group.index(receiver).
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

Entrées

Libellé Nom Type Contraintes
(I1) operands nombre variadique de Tensors ou Tensors quantifiés par Tensor (C1-C3), (C9)
(I2) split_dimension constante de type si64 (C1), (C2), (C9)
(I3) concat_dimension constante de type si64 (C3), (C9)
(I4) split_count constante de type si64 (C2), (C4), (C8), (C9)
(I5) replica_groups Constante de Tensor bidimensionnelle de type si64 (C5-C8)
(I6) channel_id constante de type si64

Sorties

Nom Type Contraintes
results nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C9)

Contraintes

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...), sauf si split_dimension != concat_dimension :
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

 Autres exemples

et

Sémantique

Effectue une opération AND élément par élément sur deux tenseurs lhs et rhs, et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les valeurs booléennes : "AND" logique
  • Pour les entiers: AND au niveau du bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur de type booléen ou entier ; (C1)
(I2) rhs un tenseur de type booléen ou entier ; (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

Autres exemples

atan2

Sémantique

Effectue une opération atan2 élément par élément sur le tenseur lhs et rhs, et génère un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : atan2 d'IEEE-754.
  • Pour les nombres complexes: atan2 complexe.
  • Pour les types quantifiés : dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)
(I2) rhs un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

 Autres exemples

batch_norm_grad

Sémantique

Calcule les gradients de plusieurs entrées de batch_norm_training en rétropropagation à partir de grad_output et produit les Tensors grad_operand, grad_scale et grad_offset. Plus formellement, cette opération peut être exprimée comme une décomposition des opérations StableHLO existantes à l'aide de la syntaxe Python comme suit:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Pour les types quantifiés, exécute dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1-C3), (C5)
(I2) scale Tensor 1D de type à virgule flottante ou quantifié par tenseur (C2), (C4), (C5)
(I3) mean Tensor 1D de type à virgule flottante ou quantifié par tenseur (C2), (C4)
(I4) variance Tensor 1D de type à virgule flottante ou quantifié par tenseur (C2), (C4)
(I5) grad_output un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C2), (C3)
(I6) epsilon constante de type f32
(I7) feature_index constante de type si64 (C1), (C5)

Sorties

Nom Type Contraintes
grad_operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C2), (C3)
grad_scale Tensor 1D de type à virgule flottante ou quantifié par tenseur (C2), (C4)
grad_offset Tensor 1D de type à virgule flottante ou quantifié par tenseur (C2), (C4)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale et grad_offset ont le même baseline_element_type.
  • (C3) operand, grad_output et grad_operand ont la même forme.
  • (C4) scale, mean, variance, grad_scale et grad_offset ont la même forme.
  • (C5) size(scale) = dim(operand, feature_index).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Sémantique

Normalise le tenseur operand dans toutes les dimensions, à l'exception de la dimension feature_index, et produit un tenseur result. Plus formellement, cette opération peut être exprimée comme une décomposition des opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Pour les types quantifiés, effectue dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1-C7)
(I2) scale Tensor 1D de type à virgule flottante ou quantifié par tenseur (C2), (C3)
(I3) offset Tensor 1D de type à virgule flottante ou quantifié par tenseur (C2), (C4)
(I4) mean Tensor 1D de type à virgule flottante ou quantifié par tenseur (C5)
(I5) variance Tensor 1D de type à virgule flottante ou quantifié par tenseur (C2), (C6)
(I6) epsilon constante de type f32
(I7) feature_index constante de type si64 (C1), (C3-C6)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C2), (C7)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance et result ont le même baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Sémantique

Calcule la moyenne et la variance pour toutes les dimensions, à l'exception de la dimension feature_index, et normalise le tenseur operand pour produire les tenseurs output, batch_mean et batch_var. Plus formellement, cette opération peut être exprimée comme une décomposition des opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Pour les types quantifiés, effectue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1)
(I2) scale Tensor 1D à virgule flottante ou quantifié par tenseur (C2), (C3)
(I3) offset Tensor 1D à virgule flottante ou quantifié par tenseur (C2), (C4)
(I4) epsilon constante de type f32 (C1), (C3-C6)
(I5) feature_index constante de type si64 (C1), (C3-C6)

Sorties

Nom Type Contraintes
output un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C7)
batch_mean Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor (C2), (C5)
batch_var Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor (C2), (C6)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var et output ont le même baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Sémantique

Effectue une opération de cast de bits sur le tenseur operand et produit un tenseur result où les bits de l'ensemble du tenseur operand sont réinterprétés à l'aide du type du tenseur result.

Plus formellement, étant donné E = element_type(operand), E' = element_type(result) et R = rank(operand) :

  • Si num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Si num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Si num_bits(E') = num_bits(E), alors bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits renvoie la représentation en mémoire d'une valeur donnée, et son comportement est défini par l'implémentation, car la représentation exacte des tenseurs est définie par l'implémentation, et la représentation exacte des types d'éléments est également définie par l'implémentation.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié ; (C1-C2)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié ; (C1-C2)

Contraintes

  • (C1) Étant donnés E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) et R = rank(operand) :
    • Si num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Si num_bits(E') < num_bits(E) :
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) pour tous les 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Si num_bits(E') > num_bits(E) :
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) pour tous les 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Si is_complex(operand) or is_complex(result), alors is_complex(operand) and is_complex(result).

Exemples

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

 Autres exemples

broadcast_in_dim

Sémantique

Développe les dimensions et/ou le rang d'un Tensor d'entrée en dupliquant les données du Tensor operand, et produit un Tensor result. Plus formellement, result[result_index] = operand[operand_index], où pour tous les d dans axes(operand) :

  • operand_index[d] = 0 si dim(operand, d) = 1.
  • Sinon, operand_index[d] = result_index[broadcast_dimensions[d]].

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié ; (C1-C2), (C5-C6)
(I2) broadcast_dimensions Constante de Tensor unidimensionnelle de type si64 (C2-C6)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1), (C3), (C5-C6)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), à l'exception de quantization_dimension(operand), scales(operand) et zero_points(operand), qui peuvent différer de quantization_dimension(result), scales(result) et zero_points(result), respectivement, dans le cas contraire.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Pour tous les d de axes(operand) :
    • dim(operand, d) = 1 ou
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Si is_per_axis_quantized(result) :
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Si dim(operand, quantization_dimension(operand)) = 1, alors scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Exemples

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Autres exemples

coque

Sémantique

Génère le résultat de l'exécution d'une seule fonction de branches en fonction de la valeur de index. Plus formellement, result = selected_branch(), où :

  • selected_branch = branches[index] si 0 <= index < size(branches).
  • Sinon, selected_branch = branches[-1].

Entrées

Libellé Nom Type Contraintes
(I1) index Tensor de dimension 0 de type si32
(I2) branches nombre variable de fonctions (C1-C4)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors, Tensors quantifiés ou jetons (C4)

Contraintes

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Exemples

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

 Autres exemples

cbrt

Sémantique

Effectue une racine cubique élément par élément sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : rootn(x, 3) d'IEEE-754.
  • Pour les nombres complexes: racine cubique complexe.
  • Pour les types quantifiés : dequantize_op_quantize(cbrt, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

 Autres exemples

ceil

Sémantique

Effectue le plafond par élément du tenseur operand et produit un tenseur result. Implémente l'opération roundToIntegralTowardPositive de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(ceil, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

 Autres exemples

cholesky

Sémantique

Calcule la décomposition de Cholesky d'un lot de matrices.

Plus formellement, pour tous les i dans index_space(result), result[i0, ..., iR-3, :, :] est une décomposition Cholesky de a[i0, ..., iR-3, :, :], sous la forme d'une matrice triangulaire inférieure (si lower est true) ou triangulaire supérieure (si lower est false). Les valeurs de sortie dans le triangle opposé, c'est-à-dire le triangle supérieur strict ou le triangle inférieur strict, sont définies par l'implémentation.

Si i existe et que la matrice d'entrée n'est pas une matrice hermitienne définie positive, le comportement est indéfini.

Pour les types quantifiés, effectue dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1-C3)
(I2) lower Constante de tenseur de dimension 0 de type i1

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Exemples

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

limiter

Sémantique

Limite chaque élément du tenseur operand entre une valeur minimale et maximale, et génère un tenseur result. Plus formellement, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), où min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Pour les types quantifiés, exécute dequantize_op_quantize(clamp, min, operand, max, type(result)).

L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (numéro 560).

Entrées

Libellé Nom Type Contraintes
(I1) min un tenseur ou un tenseur quantifié par tenseur ; (C1), (C3)
(I2) operand un tenseur ou un tenseur quantifié par tenseur ; (C1-C4)
(I3) max un tenseur ou un tenseur quantifié par tenseur ; (C2), (C3)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C4)

Contraintes

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Exemples

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Autres exemples

collective_broadcast

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, envoyez la valeur du tenseur operand du processus source aux processus cibles et générez un tenseur result.

L'opération divise la grille de processus StableHLO en process_groups, qui est définie comme suit :

  • cross_replica(replica_groups) si channel_id <= 0.
  • cross_partition(replica_groups) si channel_id > 0.

Ensuite, result@process est donné par :

  • operand@process_groups[i, 0] si un i existe de sorte que le processus se trouve dans process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C3)
(I2) replica_groups Nombre variable de constantes de tenseur unidimensionnelles de type si64 (C1), (C2)
(I3) channel_id constante de type si64

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C3)

Contraintes

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, où N est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C3) type(result) = type(operand).

Exemples

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, envoie la valeur du tenseur operand du processus source au processus cible et produit un tenseur result.

L'opération divise la grille de processus StableHLO en process_groups, qui est définie comme suit :

  • cross_replica(source_target_pairs) si channel_id <= 0.
  • cross_partition(source_target_pairs) si channel_id > 0.

Ensuite, result@process est donné par :

  • operand@process_groups[i, 0], s'il existe un i tel que process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C5)
(I2) source_target_pairs Constante de tenseur bidimensionnel de type si64 (C1-C4)
(I3) channel_id constante de type si64

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1)

Contraintes

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, où N est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C5) type(result) = type(operand).

Exemples

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Autres exemples

compare

Sémantique

Effectue une comparaison par élément des tenseurs lhs et rhs selon comparison_direction et compare_type, et produit un tenseur result.

Les valeurs de comparison_direction et compare_type ont la sémantique suivante :

Pour les types d'éléments booléens et entiers :

  • EQ : lhs = rhs.
  • NE : lhs != rhs.
  • GE : lhs >= rhs.
  • GT : lhs > rhs.
  • LE : lhs <= rhs.
  • LT : lhs < rhs.

Pour les types d'éléments à virgule flottante avec compare_type = FLOAT, l'opération implémente les opérations IEEE-754 suivantes :

  • EQ : compareQuietEqual.
  • NE : compareQuietNotEqual.
  • GE : compareQuietGreaterEqual.
  • GT : compareQuietGreater.
  • LE : compareQuietLessEqual.
  • LT : compareQuietLess.

Pour les types d'éléments à virgule flottante avec compare_type = TOTALORDER, l'opération utilise la combinaison des opérations totalOrder et compareQuietEqual de la norme IEEE-754.

Pour les types d'éléments complexes, la comparaison lexicographique des paires (real, imag) est effectuée à l'aide des méthodes comparison_direction et compare_type fournies. L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes lorsque comparison_direction est GE, GT, LE ou LT (numéro 560).

Pour les types quantifiés, effectue dequantize_compare(lhs, rhs, comparison_direction).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1-C3)
(I2) rhs un tenseur ou un tenseur quantifié par tenseur ; (C1-C2)
(I3) comparison_direction énumération de EQ, NE, GE, GT, LE et LT
(I4) compare_type énumération de FLOAT, TOTALORDER, SIGNED et UNSIGNED (C3)

Sorties

Nom Type Contraintes
result Tensor de type booléen (C2)

Contraintes

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type est défini comme suit :
    • SIGNED si is_signed_integer(element_type(lhs)).
    • UNSIGNED si is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT ou TOTALORDER si is_float(element_type(lhs)).
    • FLOAT si is_complex(element_type(lhs)).

Exemples

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

 Autres exemples

complexe

Sémantique

Effectue une conversion par élément en une valeur complexe à partir d'une paire de valeurs réelles et imaginaires, lhs et rhs, et produit un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur de type f32 ou f64 (C1-C3)
(I2) rhs tenseur de type f32 ou f64 (C1)

Sorties

Nom Type Contraintes
result Tensor de type complexe (C2), (C3)

Contraintes

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) est de type complex<E>E = element_type(lhs).

Exemples

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

 Autres exemples

composite

Sémantique

Encapsule une opération composée d'autres opérations StableHLO, prenant inputs et composite_attributes, générant ainsi results. La sémantique de l'opération est implémentée par l'attribut decomposition. L'opération composite peut être remplacée par sa décomposition sans modifier la sémantique du programme. Si l'intégration de la décomposition ne fournit pas la même sémantique d'opération, privilégiez custom_call.

Le champ version (par défaut 0) permet d'indiquer quand la sémantique d'un composite change.

Entrées

Libellé Nom Type
(I1) inputs nombre de valeurs variable
(I2) name constante de type string
(I3) composite_attributes dictionnaire d'attributs
(I4) decomposition constante de type string
(I5) version constante de type si32

Sorties

Nom Type
results nombre variadique de valeurs

Contraintes

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Exemples

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

 Autres exemples

concatenate

Sémantique

Chaîne inputs le long de la dimension dimension dans le même ordre que les arguments donnés et produit un tenseur result. Plus formellement, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], où:

  1. id = d0 + ... + dk-1 + kd.
  2. d est égal à dimension, et d0, etc. sont les tailles de la de dimension de inputs.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C1-C6)
(I2) dimension constante de type si64 (C2), (C4), (C6)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C5-C6)

Contraintes

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)), sauf dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]), sauf :
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Exemples

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Autres exemples

constante

Sémantique

Génère un Tensor output à partir d'une constante value.

Entrées

Libellé Nom Type Contraintes
(I1) value constante (C1)

Sorties

Nom Type Contraintes
output un tenseur ou un tenseur quantifié ; (C1)

Contraintes

  • (C1) type(value) = type(output).

Exemples

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Autres exemples

d'effectuer une conversion

Sémantique

Effectue une conversion par élément d'un type d'élément à un autre sur le tenseur operand et produit un tenseur result.

Pour les conversions booléen vers n'importe quel type compatible, la valeur false est convertie en zéro, et la valeur true en un. Pour les conversions any-supported-type-to-boolean, une valeur nulle est convertie en false, et les valeurs non nulles sont converties en true. Pour en savoir plus sur le fonctionnement de cette fonctionnalité pour les types complexes, consultez la section ci-dessous.

Pour les conversions impliquant des nombres entier en entier, entier en virgule flottante ou de type virgule flottante en virgule flottante, si la valeur source peut être exactement représentée dans le type de destination, la valeur du résultat correspond à cette représentation exacte. Sinon, le comportement est à déterminer (numéro 180).

Pour les conversions impliquant une floating-point-to-integer, la partie fractionnaire est tronquée. Si la valeur tronquée ne peut pas être représentée dans le type de destination, le comportement est à déterminer (numéro 180).

Les conversions complexe-complexe suivent le même comportement que les conversions virgule flottante-virgule flottante pour convertir les parties réelles et imaginaires.

Pour les conversions complex-to-any-other-type et any-other-type-to-complex, la valeur imaginaire source est ignorée ou la valeur imaginaire de destination est mise à zéro, respectivement. La conversion de la partie réelle suit les conversions à virgule flottante.

En principe, cette opération pourrait exprimer la déquantisation (conversion des tenseurs quantiques en tenseurs réguliers), la quantification (conversion des tenseurs réguliers en tenseurs quantiques) et la requantisation (conversion entre tenseurs quantiques), mais pour le moment, nous disposons d'opérations dédiées à cet effet : uniform_dequantize pour le premier cas d'utilisation et uniform_quantize pour le deuxième et le troisième cas d'utilisation. À l'avenir, ces deux opérations pourront être fusionnées dans convert (#1576).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur (C1)

Sorties

Nom Type Contraintes
result tenseur (C1)

Contraintes

  • (C1) shape(operand) = shape(result).

Exemples

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

 Autres exemples

convolution

Sémantique

Calcule les produits scalaires entre les fenêtres de lhs et les tranches de rhs, et produit result. Le diagramme suivant montre comment les éléments de result sont calculés à partir de lhs et rhs à l'aide d'un exemple concret.

convolution

Plus formellement, considérez le recadrage suivant des entrées en termes de lhs afin de pouvoir exprimer des fenêtres de lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1)
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0])
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Ce recadrage utilise les fonctions d'assistance suivantes:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]j[d] = i[permutation[d]].

Si feature_group_count = 1 et batch_group_count = 1, alors pour tous les output_spatial_index dans index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product, où :

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Cette fonctionnalité semble inutilisée. Nous prévoyons donc de la supprimer à l'avenir (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Si feature_group_count > 1 :

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension).

Si batch_group_count > 1 :

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Pour les types quantifiés, effectue dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Pour les types hybrides quantifiés, exécute hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur ou un tenseur quantifié par tenseur ; (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs un tenseur ou un tenseur quantifié ; (C1), (C14-C16), (C25), (C27-C29), (C31-C34)
(I3) window_strides Constante de Tensor unidimensionnelle de type si64 (C2-C3), (C25)
(I4) padding Constante de tenseur bidimensionnel de type si64 (C4), (C25)
(I5) lhs_dilation Constante de tenseur à dimension 1 de type si64 (C5-C6), (C25)
(I6) rhs_dilation Constante de tenseur à dimension 1 de type si64 (C7-C8), (C25)
(I7) window_reversal Constante de Tensor unidimensionnelle de type i1 (C9)
(I8) input_batch_dimension constante de type si64 (C10), (C13), (C25)
(I9) input_feature_dimension constante de type si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Constante de tenseur à dimension 1 de type si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension constante de type si64 (C14), (C18)
(I12) kernel_output_feature_dimension constante de type si64 (C15-C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C17-C18), (C25)
(I14). output_batch_dimension constante de type si64 (C20), (C25)
(I15) output_feature_dimension constante de type si64 (C20), (C25), (C30)
(I16). output_spatial_dimensions Constante de tenseur à dimension 1 de type si64 (C19-C20), (C25)
(I17) feature_group_count constante de type si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count constante de type si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config Nombre variable d'énumérations de DEFAULT, HIGH et HIGHEST (C24)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié ; (C25-C28), (C30), (C32-34)

Contraintes

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Étant donné input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] :
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Compte kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] donné :
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Avec output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] :
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) est défini comme suit :
    • dim(lhs, input_batch_dimension) / batch_group_count si result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) si result_dim = output_feature_dimension.
    • num_windows sinon, où :
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Si l'opération utilise des tenseurs non linéarisables :
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Si l'opération utilise des tenseurs quantifiés :
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Si is_per_axis_quantized(result), alors quantization_dimension(result) = output_feature_dimension.
    • Si is_quantized(lhs) :
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Si is_per_tensor_quantized(rhs), alors is_per_tensor_quantized(result).
    • Si !is_quantized(lhs) :
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Exemples

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

 Autres exemples

cosinus

Sémantique

Effectue une opération cosinus élément par élément sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : cos d'IEEE-754.
  • Pour les nombres complexes : cosinus complexe.
  • Pour les types quantifiés : dequantize_op_quantize(cosine, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

 Autres exemples

count_leading_zeros

Sémantique

Effectue un comptage par élément du nombre de bits à zéro au début du tenseur operand et produit un tenseur result.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

 Autres exemples

custom_call

Sémantique

Encapsule une opération call_target_name définie par l'implémentation qui prend inputs et called_computations et produit results. has_side_effect, backend_config et api_version peuvent être utilisés pour fournir des métadonnées supplémentaires définies par l'implémentation.

Pour le moment, cette opération contient une collection de métadonnées assez désorganisée qui reflète l'évolution organique de son opération équivalente dans le compilateur XLA. À l'avenir, nous prévoyons d'unifier ces métadonnées (numéro 741).

Entrées

Libellé Nom Type
(I1) inputs nombre de valeurs variable
(I2) call_target_name constante de type string
(I3) has_side_effect constante de type i1
(I4) backend_config constante de type string ou dictionnaire d'attributs
(I5) api_version constante de type si32
(I6) called_computations nombre variable de constantes de type string

Sorties

Nom Type
results nombre de valeurs variable

Exemples

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

diviser

Sémantique

Effectue une division élément par élément des tenseurs lhs et rhs du dividende et du diviseur, et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les entiers: division entière qui produit le quotient algébrique, en ignorant toute partie fractionnaire.
  • Pour les nombres à virgule flottante : division d'IEEE-754.
  • Pour les nombres complexes : division complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

 Autres exemples

dot_general

Sémantique

Calcule les produits scalaires entre des tranches de lhs et des tranches de rhs, et produit un tenseur result.

Plus formellement, result[result_index] = dot_product, où:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_indexsize(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) et size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Pour les types quantifiés, effectue dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Pour les types de quantification hybride, effectue hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config contrôle le compromis entre vitesse et précision pour les calculs sur les backends d'accélérateur. Il peut s'agir de l'un des éléments suivants (pour le moment, la sémantique de ces valeurs d'énumération est sous-spécifiée, mais nous prévoyons de résoudre ce problème dans la version #755):

  • DEFAULT: calcul le plus rapide, mais approximation la moins précise du nombre d'origine.
  • HIGH : calcul plus lent, mais approximation plus précise du nombre d'origine.
  • HIGHEST: calcul le plus lent, mais approximation la plus précise du nombre d'origine.

Un DotAlgorithm définit les principales propriétés de l'algorithme utilisé pour implémenter l'opération de point, qui définit également la précision. Si les champs d'attribut de l'algorithme sont définis, precision_config doit être DEFAULT. DotAlgorithms n'a pas de valeur par défaut, car les paramètres par défaut sont définis par l'implémentation. Par conséquent, tous les champs d'algorithme de point peuvent être définis sur None pour spécifier un algorithme de point vide, qui utilisera plutôt la valeur precision_config.

Les champs DotAlgorithm incluent les suivants :

  • lhs_precision_type et rhs_precision_type, précisions auxquelles le côté gauche et le côté droit de l'opération sont arrondis. Les types de précision sont indépendants des types de stockage des entrées et des sorties.
  • accumulation_type : précision utilisée pour l'accumulation.
  • lhs_component_count, rhs_component_count et num_primitive_operations s'appliquent lorsque nous effectuons un algorithme qui décompose le membre gauche et/ou le membre droit en plusieurs composants et effectue plusieurs opérations de produit scalaire "primaires" sur ces valeurs, généralement pour émuler une précision plus élevée (par exemple, Utiliser le type de données d'IA bfloat16 pour les calculs à plus haute précision : bf16_6x tf32_3x, etc.). Pour les algorithmes sans décomposition, ces valeurs doivent être définies sur 1.
  • allow_imprecise_accumulation pour spécifier si l'accumulation avec une précision inférieure est autorisée pour certaines étapes (par exemple, CUBLASLT_MATMUL_DESC_FAST_ACCUM).

Exemples d'attributs DotAlgorithm :

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

C'est aux implémentations de décider des combinaisons acceptées. En général, il n'est pas garanti que chaque algorithme soit compatible avec chaque type d'accélérateur par le consommateur de StableHLO. Si un algorithme donné n'est pas compatible, une erreur doit être générée au lieu de passer à une autre option. La validation StableHLO fournit une validation du meilleur effort, ce qui empêche les algorithmes qui ne sont pas connus pour être compatibles avec aucun matériel.

Consultez xla_data.proto > Algorithm pour connaître certaines valeurs d'algorithme acceptées. La demande n° 2483 décrit le plan visant à créer un document centralisé sur les algorithmes compatibles par backend.

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur ou un tenseur quantifié par tenseur ; (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs un tenseur ou un tenseur quantifié ; (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions Constante de Tensor unidimensionnelle de type si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions Constante de Tensor unidimensionnelle de type si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions Constante de tenseur à dimension 1 de type si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config nombre variadique d'énumérations de DEFAULT, HIGH et HIGHEST (C11), (C21)
(I8) lhs_precision_type FloatType ou TensorFloat32 (C21)
(I9) rhs_precision_type FloatType ou TensorFloat32 (C21)
(I10) accumulation_type FloatType ou TensorFloat32 (C21)
(I11) lhs_component_count constante de type si32 (C21), (C22)
(I12) rhs_component_count constante de type si32 (C21), (C23)
(I13). num_primitive_operations constante de type si32 (C21), (C24)
(I14) allow_imprecise_accumulation constante de type bool (C21)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié ; (C12), (C14), (C18-C20)

Contraintes

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Si l'opération utilise des tenseurs non linéarisables :
    • (C13) element_type(lhs) = element_type(rhs).
  • Si l'opération utilise des tenseurs quantifiés :
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) n'est pas dans rhs_contracting_dimensions.
    • Si is_quantized(lhs) :
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) Si is_per_tensor_quantized(rhs), alors is_per_tensor_quantized(result).
    • Si !is_quantized(lhs) :
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • Si !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation) :
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

Exemples

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

 Autres exemples

dynamic_broadcast_in_dim

Sémantique

Cette opération est fonctionnellement identique à l'opération broadcast_in_dim, mais la forme du résultat est spécifiée de manière dynamique via output_dimensions.

L'opération accepte également les attributs facultatifs known_expanding_dimensions et known_nonexpanding_dimensions pour exprimer des connaissances statiques sur le comportement d'expansion des dimensions. Si aucune valeur n'est spécifiée, toutes les dimensions sont supposées pouvoir s'étendre.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C2), (C5-C6), (C9)
(I2) output_dimensions Tensor unidimensionnel de type entier (C7)
(I3) broadcast_dimensions Tensor constante à dimension 1 de type entier (C2-C6)
(I4) known_expanding_dimensions Tensor constante à dimension 1 de type entier (C8-C9)
(I5) known_nonexpanding_dimensions Tensor constante à dimension 1 de type entier (C8-C9)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié ; (C1), (C3), (C5-C7)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), à l'exception de quantization_dimension(operand), scales(operand) et zero_points(operand), qui peuvent différer de quantization_dimension(result), scales(result) et zero_points(result), respectivement, dans le cas contraire.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Pour tous les d de axes(operand) :
    • dim(operand, d) = 1 ou
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Si is_per_axis_quantized(result) :
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Si la valeur est dim(operand, quantization_dimension(operand)) = 1, alors scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_nonexpanding_dimensions < rank(operand).

Exemples

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 Autres exemples

dynamic_conv

Sémantique

Cette opération est fonctionnellement identique à l'opération de convolution, mais le remplissage est spécifié de manière dynamique via padding.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs un tenseur ou un tenseur quantifié ; (C1), (C14-C16), (C26-C28), (C30-C33)
(I3) padding Tensor bidimensionnel de type entier (C4)
(I4) window_strides Constante de tenseur à dimension 1 de type si64 (C2-C3)
(I5) lhs_dilation Constante de tenseur à dimension 1 de type si64 (C5-C6)
(I6) rhs_dilation Constante de tenseur à dimension 1 de type si64 (C7-C8)
(I7) window_reversal Constante de Tensor unidimensionnelle de type i1 (C9)
(I8) input_batch_dimension constante de type si64 (C10), (C13)
(I9) input_feature_dimension constante de type si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Constante de tenseur à dimension 1 de type si64 (C12), (C13)
(I11) kernel_input_feature_dimension constante de type si64 (C14), (C18)
(I12) kernel_output_feature_dimension constante de type si64 (C15-C16), (C18), (C28)
(I13) kernel_spatial_dimensions Constante de tenseur à dimension 1 de type si64 (C17-C18)
(I14) output_batch_dimension constante de type si64 (C20)
(I15) output_feature_dimension constante de type si64 (C20), (C29)
(I16) output_spatial_dimensions Constante de tenseur à dimension 1 de type si64 (C19-C20)
(I17) feature_group_count constante de type si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count constante de type si64 (C10), (C15), (C22), (C23)
(I19). precision_config Nombre variable d'énumérations de DEFAULT, HIGH et HIGHEST (C24)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié ; (C25-C27), (C29), (C31-C33)

Contraintes

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Étant donné input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] :
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Compte kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] donné :
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Avec output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] :
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) est défini comme suit :
    • dim(lhs, input_batch_dimension) / batch_group_count si result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) si result_dim = output_feature_dimension.
    • num_windows sinon, où :
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Si l'opération utilise des tenseurs non linéarisables :
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Si l'opération utilise des tenseurs quantifiés :
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Si is_per_axis_quantized(result), alors quantization_dimension(result) = output_feature_dimension.
    • Si is_quantized(lhs) :
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Si is_per_tensor_quantized(rhs), alors is_per_tensor_quantized(result).
    • Si !is_quantized(lhs) :
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Exemples

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

 Autres exemples

dynamic_gather

Sémantique

Cette opération est fonctionnellement identique à l'opération gather, avec l'slice_sizes spécifié dynamiquement en tant que valeur.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C1), (C7), (C10-C12), (C14)
(I2) start_indices Tensor de type entier (C2), (C3), (C13)
(I3) slice_sizes Tensor 1D de type entier (C8), (C11-C13)
(I4) offset_dims Constante de tenseur à dimension 1 de type si64 (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C6-C8), (C13)
(I6) start_index_map Constante de tenseur à dimension 1 de type si64 (C3), (C9), (C10)
(I7) index_vector_dim constante de type si64 (C2), (C3), (C13)
(I8) indices_are_sorted constante de type i1

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C5), (C13-C14)

Contraintes

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) où :
    • batch_dim_sizes = shape(start_indices), sauf que la taille de dimension de start_indices correspondant à index_vector_dim n'est pas incluse.
    • offset_dim_sizes = shape(slice_sizes), sauf que les tailles des dimensions dans slice_sizes correspondant à collapsed_slice_dims ne sont pas incluses.
    • combine place batch_dim_sizes sur les axes correspondant à batch_dims et offset_dim_sizes sur les axes correspondant à offset_dims.
  • (C14) element_type(operand) = element_type(result).

Exemples

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

 Autres exemples

dynamic_iota

Sémantique

Cette opération est fonctionnellement identique à l'opération iota, mais la forme du résultat est spécifiée de manière dynamique via output_shape.

Entrées

Libellé Nom Type Contraintes
(I1) output_shape Tensor unidimensionnel de type entier (C1), (C2)
(I2) iota_dimension si64 (C1)

Sorties

Nom Type Contraintes
result un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C2)

Contraintes

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Exemples

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

 Autres exemples

dynamic_pad

Sémantique

Cette opération est fonctionnellement identique à l'opération pad, mais avec edge_padding_low, edge_padding_high et interior_padding spécifiés dynamiquement en tant que valeurs.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C1), (C2), (C4)
(I2) padding_value Tensor à dimension 0 ou tensor quantifié par tensor (C1)
(I3) edge_padding_low Tensor 1D de type entier (C1), (C4)
(I4) edge_padding_high Tensor 1D de type entier (C1), (C4)
(I5) interior_padding Tensor unidimensionnel de type entier (C2-C4)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C3-C6)

Contraintes

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Exemples

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 Autres exemples

dynamic_reshape

Sémantique

Cette opération est fonctionnellement identique à l'opération reshape, mais la forme du résultat est spécifiée de manière dynamique via output_shape.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié ; (C1-C3)
(I2) output_shape Tensor 1D de type entier (C4)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié ; (C1-C4)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), sauf que quantization_dimension(operand) et quantization_dimension(result) peuvent être différents.
  • (C2) size(operand) = size(result).
  • (C3) Si is_per_axis_quantized(operand) :
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

 Autres exemples

dynamic_slice

Sémantique

Extrait une tranche du operand à l'aide d'index de départ calculés dynamiquement et produit un Tensor result. start_indices contient les indices de début de la tranche pour chaque dimension pouvant être ajustée, et slice_sizes contient les tailles de la tranche pour chaque dimension. Plus formellement, result[result_index] = operand[operand_index], où :

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C1), (C2), (C4)
(I2) start_indices nombre variadique de Tensors à 0 dimensions de type entier (C2), (C3)
(I3) slice_sizes Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C5)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1), (C5)

Contraintes

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Exemples

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

 Autres exemples

dynamic_update_slice

Sémantique

Génère un tenseur result égal au tenseur operand, sauf que la tranche commençant à start_indices est mise à jour avec les valeurs de update. Plus formellement, result[result_index] est défini comme suit :

  • update[update_index] si 0 <= update_index < shape(update) où :
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • Sinon, operand[result_index].

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1-C4), (C6)
(I2) update un tenseur ou un tenseur quantifié par tenseur ; (C2), (C3), (C6)
(I3) start_indices nombre variadique de Tensors à 0 dimensions de type entier (C4), (C5)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1)

Contraintes

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Exemples

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

 Autres exemples

exponentiel

Sémantique

Effectue une opération exponentielle au niveau des éléments sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : exp d'IEEE-754.
  • Pour les nombres complexes: exponentiel complexe
  • Pour les types quantifiés : dequantize_op_quantize(exponential, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

 Autres exemples

exponential_minus_one

Sémantique

Effectue une opération exponentielle moins un par élément sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante: expm1 d'IEEE-754.
  • Pour les nombres complexes : l'exponentielle complexe moins un.
  • Pour les types quantifiés : dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

 Autres exemples

fft

Sémantique

Effectue les transformations de Fourier directe et inverse pour les entrées/sorties réelles et complexes.

fft_type est l'un des éléments suivants :

  • FFT : FFT de type "forward" complexe à complexe.
  • IFFT : FFT inverse complexe à complexe.
  • RFFT : FFT directe réelle-complexe.
  • IRFFT : FFT inverse réel-complexe (c'est-à-dire, prend complexe, renvoie réel).

Plus formellement, compte tenu de la fonction fft, qui prend en entrée des Tensors unidimensionnels de types complexes, produit des Tensors unidimensionnels de mêmes types que la sortie et calcule la transformation de Fourier discrète:

Pour fft_type = FFT, result est défini comme le résultat final d'une série de calculs L où L = size(fft_length). Par exemple, pour L = 3 :

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

De plus, étant donné la fonction ifft qui a la même signature de type et calcule l'inverse de fft :

Pour fft_type = IFFT, result est défini comme l'inverse des calculs pour fft_type = FFT. Par exemple, pour L = 3 :

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

De plus, étant donné la fonction rfft qui prend des tenseurs unidimensionnels de types à virgule flottante, produit des tenseurs unidimensionnels de types complexes de la même sémantique à virgule flottante et fonctionne comme suit :

  • rfft(real_operand) = truncated_result
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand)
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Lorsque la transformée de Fourier discrète est calculée pour des opérandes réels, les premiers éléments N/2 + 1 du résultat définissent de manière non ambiguë le reste du résultat. Le résultat de rfft est donc tronqué pour éviter de calculer des éléments redondants.)

Pour fft_type = RFFT, result est défini comme le résultat final d'une série de calculs L où L = size(fft_length). Par exemple, pour L = 3 :

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Enfin, étant donné la fonction irfft qui a la même signature de type et calcule l'inverse de rfft :

Pour fft_type = IRFFT, result est défini comme l'inverse des calculs pour fft_type = RFFT. Par exemple, pour L = 3 :

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur de type à virgule flottante ou complexe (C1), (C2), (C4), (C5)
(I2) fft_type énumération de FFT, IFFT, RFFT et IRFFT (C2), (C5)
(I3) fft_length Constante de Tensor unidimensionnelle de type si64 (C1), (C3), (C4)

Sorties

Nom Type Contraintes
result Tensor de type complexe ou à virgule flottante (C2), (C4), (C5)

Contraintes

  • (C1) size(fft_length) <= rank(operand).
  • (C2) La relation entre les types d'éléments operand et result varie :
    • Si fft_type = FFT, element_type(operand) et element_type(result) ont le même type complexe.
    • Si fft_type = IFFT, element_type(operand) et element_type(result) ont le même type complexe.
    • Si fft_type = RFFT, element_type(operand) est un type à virgule flottante et element_type(result) est un type complexe de la même sémantique à virgule flottante.
    • Si fft_type = IRFFT, element_type(operand) est un type complexe et element_type(result) est un type à virgule flottante de la même sémantique à virgule flottante.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Si parmi operand et result, il existe un tenseur real d'un type à virgule flottante, alors shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand), sauf :
    • Si fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Si fft_type = IRFFT, alors dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Exemples

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

étage

Sémantique

Effectue un prix plancher par élément du Tensor operand et produit un Tensor result. Met en œuvre l'opération roundToIntegralTowardNegative de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(floor, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

 Autres exemples

rassembler

Sémantique

Regroupe les tranches du Tensor operand à partir des décalages spécifiés dans start_indices et génère un Tensor result.

Le diagramme suivant montre comment les éléments de result sont mappés sur les éléments de operand à l'aide d'un exemple concret. Le diagramme sélectionne quelques exemples d'indices result et explique en détail à quels indices operand ils correspondent.

rassembler

Plus formellement, result[result_index] = operand[operand_index], où:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index est défini comme suit :
    • start_indices[bi0, ..., :, ..., biN], où bi sont des éléments individuels dans batch_index et : est inséré à l'indice index_vector_dim, si index_vector_dim < rank(start_indices).
    • Sinon, [start_indices[batch_index]].
  • Pour d_operand dans axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) si d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 dans les autres cas.
  • Pour d_operand dans axes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] si d_operand = operand_batching_dims[i_batching] et d_start = start_indices_batching_dims[i_batching].
    • Sinon, full_batching_index[d_operand] = 0.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], où oi sont des éléments individuels dans offset_index, et 0 est inséré aux indices de collapsed_slice_dims et operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

Si indices_are_sorted est true, l'implémentation peut supposer que les start_indices sont triées par rapport à start_index_map. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2 à partir de indices(result), full_start_index(i1) <= full_start_index(i2).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices Tensor de type entier (C2-C3), (C14), (C17), (C22)
(I3) offset_dims Constante de tenseur à dimension 1 de type si64 (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims Constante de tenseur à dimension 1 de type si64 (C1), (C6-C9), (C22)
(I5) operand_batching_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C6), (C10-C12), (C16-C18), (C22)
(I6) start_indices_batching_dims Constante de tenseur à dimension 1 de type si64 (C13-C17)
(I7) start_index_map Constante de tenseur à dimension 1 de type si64 (C3), (C18-C19)
(I8) index_vector_dim constante de type si64 (C2-C3), (C15), (C22)
(I9) slice_sizes Constante de tenseur à dimension 1 de type si64 (C9), (C12), (C20-C22)
(I10) indices_are_sorted constante de type i1

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C5), (C22-C23)

Contraintes

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) où :
    • batch_dim_sizes = shape(start_indices), sauf que la taille de dimension de start_indices correspondant à index_vector_dim n'est pas incluse.
    • offset_dim_sizes = slice_sizes, sauf que les tailles de dimension dans slice_sizes correspondant à collapsed_slice_dims et operand_batching_dims ne sont pas incluses.
    • combine place batch_dim_sizes sur les axes correspondant à batch_dims et offset_dim_sizes sur les axes correspondant à offset_dims.
  • (C23) element_type(operand) = element_type(result).

Exemples

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

 Autres exemples

get_dimension_size

Sémantique

Renvoie la taille de l'dimension donnée de l'operand. Plus formellement, result = dim(operand, dimension). La sémantique ne concerne que le composant de forme du type. Le type d'élément peut être n'importe quoi.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié ; (C1)
(I2) dimension constante de type si64 (C1)

Sorties

Nom Type
result Tensor de dimension 0 de type si32

Contraintes

  • (C1) 0 <= dimension < rank(operand).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Autres exemples

get_tuple_element

Sémantique

Extraction de l'élément à la position index du tuple operand et production d'un result. Plus formellement, result = operand[index].

Entrées

Libellé Nom Type Contraintes
(I1) operand tuple (C1), (C2)
(I2) index constante de type si32 (C1), (C2)

Sorties

Nom Type Contraintes
result tout type compatible (C2)

Contraintes

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Exemples

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

 Autres exemples

si

Sémantique

Génère la sortie de l'exécution d'une seule fonction à partir de true_branch ou false_branch, en fonction de la valeur de pred. Plus formellement, result = pred ? true_branch() : false_branch().

Entrées

Libellé Nom Type Contraintes
(I1) pred Tensor de dimension 0 de type i1
(I2) true_branch fonction (C1-C3)
(I3) false_branch fonction (C1), (C2)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors, Tensors quantifiés ou jetons (C3)

Contraintes

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Exemples

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

 Autres exemples

imag

Sémantique

Extrait la partie imaginaire, par élément, de operand et produit un tenseur result. Plus formellement, pour chaque élément x : imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type complexe ou à virgule flottante (C1), (C2)

Sorties

Nom Type Contraintes
result tenseur de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • Sinon, element_type(operand).

Exemples

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

 Autres exemples

infeed

Sémantique

Lit les données du flux d'entrée et génère results.

La sémantique de infeed_config est définie par l'implémentation.

results se compose de valeurs de charge utile qui viennent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour améliorer la clarté (numéro 670).

Entrées

Libellé Nom Type
(I1) token token
(I2) infeed_config constante de type string

Sorties

Nom Type Contraintes
results Nombre variable de tenseurs, de tenseurs quantifiés ou de jetons (C1-C3)

Contraintes

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Exemples

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

 Autres exemples

iota

Sémantique

Remplit un Tensor output avec des valeurs dans l'ordre croissant à partir de zéro le long de la dimension iota_dimension. Plus formellement,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Entrées

Libellé Nom Type Contraintes
(I1) iota_dimension si64 (C1)

Sorties

Nom Type Contraintes
output un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) 0 <= iota_dimension < rank(output).

Exemples

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

 Autres exemples

is_finite

Sémantique

Effectue une vérification par élément pour déterminer si la valeur dans x est finie (c'est-à-dire qu'elle n'est ni +Inf, ni -Inf, ni NaN) et produit un tenseur y. Implémente l'opération isFinite de la spécification IEEE-754. Pour les types quantifiés, le résultat est toujours true.

Entrées

Libellé Nom Type Contraintes
(I1) x un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
y tenseur de type booléen (C1)

Contraintes

  • (C1) shape(x) = shape(y).

Exemples

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

 Autres exemples

log

Sémantique

Effectue une opération logarithmique élément par élément sur le tenseur operand et génère un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : log d'IEEE-754.
  • Pour les nombres complexes: logarithme complexe.
  • Pour les types quantifiés : dequantize_op_quantize(log, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

 Autres exemples

log_plus_one

Sémantique

Effectue le logarithme élément par élément plus une opération sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : logp1 d'IEEE-754.
  • Pour les nombres complexes : logarithme complexe plus un.
  • Pour les types quantifiés : dequantize_op_quantize(log_plus_one, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Autres exemples

logistique

Sémantique

Effectue une opération logistique au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante : division(1, addition(1, exp(-x))) d'IEEE-754.
  • Pour les nombres complexes: logistique complexe.
  • Pour les types quantifiés : dequantize_op_quantize(logistic, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

 Autres exemples

carte

Sémantique

Applique une fonction de mappage computation à inputs le long de dimensions et produit un tenseur result.

Plus formellement, result[result_index] = computation(inputs...[result_index]).

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C1-C4)
(I2) dimensions Constante de tenseur à dimension 1 de type si64 (C3)
(I3) computation fonction (C4)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1), (C4)

Contraintes

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation est de type (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, où Ei = element_type(inputs[i]) et E' = element_type(result).

Exemples

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Autres exemples

maximum

Sémantique

Effectue une opération de valeur maximale au niveau des éléments sur les tensors lhs et rhs, et génère un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les valeurs booléennes : opérateur logique "OR".
  • Pour les entiers : valeur maximale de l'entier.
  • Pour les nombres à virgule flottante : maximum d'IEEE-754.
  • Pour les nombres complexes : valeur maximale lexicographique pour la paire (real, imaginary). L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (numéro 560).
  • Pour les types quantifiés :
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur ou un tenseur quantifié par tenseur ; (C1)
(I2) rhs un tenseur ou un tenseur quantifié par tenseur ; (C1)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

 Autres exemples

minimum

Sémantique

Effectue une opération min par élément sur les tenseurs lhs et rhs, et génère un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les valeurs booléennes : "AND" logique
  • Pour les entiers : valeur minimale de l'entier.
  • Pour les nombres à virgule flottante : minimum d'IEEE-754.
  • Pour les nombres complexes : valeur minimale lexicographique pour la paire (real, imaginary). L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (numéro 560).
  • Pour les types quantifiés :
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur ou un tenseur quantifié par tenseur ; (C1)
(I2) rhs un tenseur ou un tenseur quantifié par tenseur ; (C1)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

 Autres exemples

multiplier

Sémantique

Effectue le produit élément par élément de deux tenseurs lhs et rhs, et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes : "AND" logique
  • Pour les entiers: multiplication d'entiers.
  • Pour les nombres à virgule flottante: multiplication d'IEEE-754.
  • Pour les nombres complexes : multiplication complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur ou un tenseur quantifié par tenseur ; (C1)
(I2) rhs un tenseur ou un tenseur quantifié par tenseur ; (C1)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

 Autres exemples

negate

Sémantique

Effectue une négation par élément du Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les entiers signés: négation des entiers.
  • Pour les entiers non signés : bitcast vers un entier signé, négation d'entier, bitcast vers un entier non signé.
  • Pour les nombres à virgule flottante: negate d'IEEE-754.
  • Pour les nombres complexes : négation complexe.
  • Pour les types quantifiés : dequantize_op_quantize(negate, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

 Autres exemples

not

Sémantique

Effectue une opération NOT élément par élément du tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les valeurs booléennes : négation logique.
  • Pour les entiers: NOT (PAS) au niveau du bit.

Arguments

Nom Type Contraintes
operand Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

 Autres exemples

optimization_barrier

Sémantique

Assurez-vous que les opérations qui génèrent le operand sont exécutées avant toute opération qui dépend de result et empêche les transformations de compilation de déplacer les opérations au-delà de la barrière. En dehors de cela, l'opération est une identité, c'est-à-dire result = operand.

Arguments

Nom Type Contraintes
operand nombre variable de tenseurs, de tenseurs quantifiés par tenseur ou de jetons (C1)

Sorties

Nom Type Contraintes
result nombre variable de tenseurs, de tenseurs quantifiés par tenseur ou de jetons (C1)

Contraintes

  • (C1) type(operand...) = type(result...).

Exemples

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

 Autres exemples

ou

Sémantique

Effectue une opération OU élément par élément sur deux tenseurs lhs et rhs, et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les valeurs booléennes: opérateur logique OU.
  • Pour les entiers: opération OR au niveau du bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier ou booléen (C1)
(I2) rhs un tenseur de type entier ou booléen ; (C1)

Sorties

Nom Type Contraintes
result un tenseur de type entier ou booléen ; (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

 Autres exemples

flux de sortie

Sémantique

Écrit inputs dans le flux de sortie et génère un jeton result.

La sémantique de outfeed_config est définie par l'implémentation.

Entrées

Libellé Nom Type
(I1) inputs Nombre variable de tenseurs ou de tenseurs quantifiés
(I2) token token
(I3) outfeed_config constante de type string

Sorties

Nom Type
result token

Exemples

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

 Autres exemples

pad

Sémantique

Élargit operand en ajoutant des marges internes autour du tenseur et entre les éléments du tenseur avec l'padding_value donné.

edge_padding_low et edge_padding_high spécifient respectivement la quantité de marge intérieure ajoutée à l'extrémité inférieure (à côté de l'indice 0) et à l'extrémité supérieure (à côté de l'indice le plus élevé) de chaque dimension. La valeur de marge intérieure peut être négative, où la valeur absolue de la marge intérieure négative indique le nombre d'éléments à supprimer de la dimension spécifiée.

interior_padding spécifie la quantité de marge intérieure ajoutée entre deux éléments de chaque dimension, qui peut ne pas être négative. La marge intérieure se produit avant la marge extérieure, de sorte que la marge extérieure négative supprime des éléments de l'opérande à marge intérieure.

Plus formellement, result[result_index] est défini comme suit :

  • operand[operand_index] si result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • Sinon, padding_value.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C1), (C2), (C4)
(I2) padding_value Tensor à dimension 0 ou tensor quantifié par tensor (C1)
(I3) edge_padding_low Constante de tenseur à dimension 1 de type si64 (C1), (C4)
(I4) edge_padding_high Constante de tenseur à dimension 1 de type si64 (C1), (C4)
(I5) interior_padding Constante de Tensor unidimensionnelle de type si64 (C2-C4)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C3-C6)

Contraintes

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Exemples

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 Autres exemples

partition_id

Sémantique

Génère partition_id du processus en cours.

Sorties

Nom Type
result Tensor de dimension 0 de type ui32

Exemples

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

 Autres exemples

popcnt

Sémantique

Effectue un comptage par élément du nombre de bits définis dans le tenseur operand et produit un tenseur result.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Autres exemples

puissance

Sémantique

Effectue l'exponentiation élément par élément du tenseur lhs par le tenseur rhs et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les entiers : exponentiation entière.
  • Pour les nombres à virgule flottante : pow d'IEEE-754.
  • Pour les nombres complexes: exponentielle complexe.
  • Pour les types quantifiés : dequantize_op_quantize(power, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

 Autres exemples

real

Sémantique

Extrait la partie réelle, par élément, de la operand et produit un tenseur result. Plus formellement, pour chaque élément x : real(x) = is_complex(x) ? real_part(x) : x.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type complexe ou à virgule flottante (C1), (C2)

Sorties

Nom Type Contraintes
result tenseur de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • Sinon, element_type(operand).

Exemples

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

 Autres exemples

recv

Sémantique

Reçoit les données d'un canal avec channel_id et produit results.

Si is_host_transfer est true, l'opération transfère des données depuis l'hôte. Sinon, il transfère des données depuis un autre appareil. Cela signifie que la valeur est définie par l'implémentation. Cet indicateur duplique les informations fournies dans channel_type. Nous prévoyons donc de n'en conserver qu'un seul (#666).

results est constitué de valeurs de charge utile qui apparaissent en premier et d'un jeton qui apparaissent en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour améliorer la clarté (numéro 670).

Entrées

Libellé Nom Type Contraintes
(I1) token token (C4)
(I2) channel_id constante de type si64
(I3) channel_type énumération de DEVICE_TO_DEVICE et HOST_TO_DEVICE (C1)
(I4) is_host_transfer constante de type i1 (C1)

Sorties

Nom Type Contraintes
results Nombre variable de tenseurs, de tenseurs quantifiés ou de jetons (C2-C4)

Contraintes

  • (C1) channel_type est défini comme suit :
    • HOST_TO_DEVICE si is_host_transfer = true,
    • Sinon, DEVICE_TO_DEVICE.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Exemples

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

 Autres exemples

reduce

Sémantique

Applique une fonction de réduction body à inputs et init_values le long de dimensions et produit des Tensors results.

L'ordre des réductions est défini par l'implémentation, ce qui signifie que body et init_values doivent former un monoide pour garantir que l'opération produit les mêmes résultats pour toutes les entrées sur toutes les implémentations. Cependant, cette condition n'est pas valable pour de nombreuses réductions populaires. Par exemple, l'addition à virgule flottante pour body et zéro pour init_values ne forment pas réellement un monoide, car l'addition à virgule flottante n'est pas associative.

Plus formellement, results...[j0, ..., jR-1] = reduce(input_slices_converted), où :

  • input_slices = inputs...[j0, ..., :, ..., jR-1], où : sont insérés à dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) pour un arbre binaire schedule où :
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule est une arborescence binaire complète définie par l'implémentation dont le balayage dans l'ordre comprend les éléments suivants :
    • Valeurs input_slices_converted...[index], pour tous les index dans index_space(input_slices_converted) dans l'ordre lexicographique croissant de index.
    • Intercalés avec une quantité de init_values_converted définie par l'implémentation aux positions définies par l'implémentation.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C1-C4), (C6), (C7)
(I2) init_values nombre variadique de Tensors à 0 dimensions ou de Tensors quantifiés par Tensor (C2), (C3)
(I3) dimensions Constante de tenseur à dimension 1 de type si64 (C4), (C5), (C7)
(I4) body fonction (C6)

Sorties

Nom Type Contraintes
results nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C3), (C7), (C8)

Contraintes

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), à l'exception des tailles de dimension de inputs... correspondant à dimensions.
  • (C8) element_type(results[i]) = Ei pour tous les i de [0,N).

Exemples

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

 Autres exemples

reduce_precision

Sémantique

Effectue une conversion élément par élément de operand vers un autre type à virgule flottante qui utilise exponent_bits et mantissa_bits, puis vers le type à virgule flottante d'origine et produit un tenseur output.

Plus formellement :

  • Les bits de mantisse de la valeur d'origine sont mis à jour pour arrondir la valeur d'origine à la valeur la plus proche pouvant être représentée avec mantissa_bits à l'aide de la sémantique roundToIntegralTiesToEven.
  • Ensuite, si mantissa_bits est inférieur au nombre de bits de mantisse de la valeur d'origine, les bits de mantisse sont tronqués à mantissa_bits.
  • Ensuite, si les bits d'exposant du résultat intermédiaire ne rentrent pas dans la plage fournie par exponent_bits, le résultat intermédiaire déborde vers l'infini à l'aide du signe d'origine ou sous-déborde vers zéro à l'aide du signe d'origine.
  • Pour les types quantifiés, effectue dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1)
(I2) exponent_bits constante de type si32 (C2)
(I3) mantissa_bits constante de type si32 (C3)

Sorties

Nom Type Contraintes
output Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Exemples

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

 Autres exemples

reduce_scatter

Sémantique

reduce_scatter

Dans chaque groupe de processus de la grille de processus StableHLO, effectue une réduction, à l'aide de computations, sur les valeurs du tenseur operand de chaque processus, divise le résultat de la réduction en parties le long de scatter_dimension et disperse les parties fractionnées entre les processus pour produire le result.

L'opération divise la grille de processus StableHLO en process_groups, qui est définie comme suit :

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group :

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] pour tous les sender dans process_group, où receiver_index = process_group.index(receiver).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension constante de type si64 (C1), (C2), (C8)
(I3) replica_groups Constante de tenseur bidimensionnel de type si64 (C3-C5)
(I4) channel_id constante de type si64 (C6)
(I5) use_global_device_ids constante de type i1 (C6)
(I6) computation fonction (C7)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C8-C9)

Contraintes

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Si use_global_device_ids = true, alors channel_id > 0.
  • (C7) computation est de type (tensor<E>, tensor<E>) -> (tensor<E>), où is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand), sauf :
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

 Autres exemples

reduce_window

Sémantique

Applique une fonction de réduction body aux fenêtres de inputs et init_values, et produit results.

Le schéma suivant montre comment les éléments de results... sont calculés à partir de inputs... à l'aide d'un exemple concret.

reduce_window

Plus formellement, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (voir reduce) où :

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values Nombre variable de tenseurs à dimension 0 ou de tenseurs quantifiés par tenseur (C1), (C13)
(I3) window_dimensions Constante de Tensor unidimensionnelle de type si64 (C4), (C5), (C15)
(I4) window_strides Constante de tenseur à dimension 1 de type si64 (C6), (C7), (C15)
(I5) base_dilations Constante de tenseur à dimension 1 de type si64 (C8), (C9), (C15)
(I6) window_dilations Constante de tenseur à dimension 1 de type si64 (C10), (C11), (C15)
(I7) padding Constante de tenseur bidimensionnel de type si64 (C12), (C15)
(I8) body fonction (C13)

Sorties

Nom Type Contraintes
results nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C1), (C14-C16)

Contraintes

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), où is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows où :
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei pour l'ensemble des i de [0,N).

Exemples

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

 Autres exemples

reste

Sémantique

Effectue le reste par élément des tensors lhs et rhs du dividende et du diviseur, et produit un tenseur result.

Plus formellement, le signe du résultat est extrait du dividende, et la valeur absolue du résultat est toujours inférieure à la valeur absolue du diviseur. Le reste est calculé comme lhs - d * rhs, où d est donné par :

  • Pour les entiers: stablehlo.divide(lhs, rhs).
  • Pour les nombres à virgule flottante : division(lhs, rhs) d'IEEE-754 avec l'attribut d'arrondi roundTowardZero.
  • Pour les nombres complexes: à déterminer (#997).
  • Pour les types quantifiés :
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Pour les types d'éléments à virgule flottante, cette opération contraste avec l'opération remainder de la spécification IEEE-754, où d est une valeur entière la plus proche de la valeur exacte de lhs/rhs avec des liens pairs.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

 Autres exemples

replica_id

Sémantique

Génère replica_id du processus actuel.

Sorties

Nom Type
result Tensor de dimension 0 de type ui32

Exemples

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

 Autres exemples

reshape

Sémantique

Effectue un remodelage du Tensor operand en un Tensor result. Conceptuellement, cela revient à conserver la même représentation canonique, mais à modifier potentiellement la forme, par exemple de tensor<2x3xf32> à tensor<3x2xf32> ou tensor<6xf32>.

Plus formellement, result[result_index] = operand[operand_index], où result_index et operand_index ont la même position dans l'ordre lexicographique de index_space(result) et index_space(operand).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié ; (C1-C3)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1-C3)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), sauf que quantization_dimension(operand) et quantization_dimension(result) peuvent être différents.
  • (C2) size(operand) = size(result).
  • (C3) Si is_per_axis_quantized(operand) :
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

 Autres exemples

inverser

Sémantique

Inverse l'ordre des éléments dans operand le long de dimensions spécifié et génère un tenseur result. Plus formellement, result[result_index] = operand[operand_index], où :

  • operand_index[d] = dim(result, d) - result_index[d] - 1 si d en dimensions.
  • Sinon, operand_index[d] = result_index[d].

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C1), (C3)
(I2) dimensions Constante de tenseur à dimension 1 de type si64 (C2), (C3)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1), (C3)

Contraintes

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Exemples

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

 Autres exemples

rng

Sémantique

Génère des nombres aléatoires à l'aide de l'algorithme rng_distribution et produit un tenseur result d'une forme shape donnée.

Si la valeur est rng_distribution = UNIFORM, les nombres aléatoires sont générés en suivant la distribution uniforme sur l'intervalle [a, b). Si la valeur est a >= b, le comportement n'est pas défini.

Si la valeur est rng_distribution = NORMAL, les nombres aléatoires sont générés selon la distribution normale, avec une moyenne = a et l'écart type = b. Si la valeur est b < 0, le comportement n'est pas défini.

La méthode exacte de génération des nombres aléatoires est définie par l'implémentation. Par exemple, ils peuvent être ou non déterministes, et ils peuvent ou non utiliser un état masqué.

Lors de discussions avec de nombreux partenaires, il est apparu que cette opération était effectivement obsolète. Nous envisageons donc de la supprimer à l'avenir (numéro 597).

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor à zéro dimension de type entier, booléen ou à virgule flottante (C1), (C2)
(I2) b Tensor 0 dimensionnel de type entier, booléen ou à virgule flottante (C1), (C2)
(I3) shape Constante de tenseur à dimension 1 de type si64 (C3)
(I4) rng_distribution énumération de UNIFORM et NORMAL (C2)

Sorties

Nom Type Contraintes
result un tenseur de type entier, booléen ou à virgule flottante ; (C1-C3)

Contraintes

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Si rng_distribution = NORMAL, alors is_float(a).
  • (C3) shape(result) = shape.

Exemples

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Sémantique

Renvoie un output rempli de bits aléatoires uniformes et un état de sortie output_state mis à jour à l'aide de l'algorithme de générateur de nombres pseudo-aléatoires rng_algorithm, étant donné un état initial initial_state. La sortie est garantie comme étant une fonction déterministe de initial_state, mais elle n'est pas garantie comme étant déterministe entre les implémentations.

rng_algorithm est l'un des éléments suivants :

  • DEFAULT: algorithme défini par l'implémentation.
  • THREE_FRY : variante de l'algorithme Threefry définie par l'implémentation*
  • PHILOX: variante de l'algorithme Philox définie par l'implémentation*

* Voir Salmon et al. SC 2011. Nombres aléatoires parallèles: c'est aussi simple que 1, 2, 3.

Entrées

Libellé Nom Type Contraintes
(I1) rng_algorithm énumération de DEFAULT, THREE_FRY et PHILOX (C2)
(I2) initial_state Tensor unidimensionnel de type ui64 (C1), (C2)

Sorties

Nom Type Contraintes
output_state Tensor 1D de type ui64 (C1)
output tenseur de type entier ou à virgule flottante

Contraintes

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) est défini comme suit :
    • définie par l'implémentation si rng_algorithm = DEFAULT.
    • 2 si rng_algorithm = THREE_FRY.
    • 2 ou 3 si rng_algorithm = PHILOX.

Exemples

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Sémantique

Effectue un arrondi au niveau des éléments vers l'entier le plus proche, en rompant les liens avec zéro sur le Tensor operand et produit un Tensor result. Implémente l'opération roundToIntegralTiesToAway de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Autres exemples

round_nearest_even

Sémantique

Effectue un arrondi par élément vers l'entier le plus proche, en cas d'égalité, vers l'entier pair, sur le tenseur operand et produit un tenseur result. Implémente l'opération roundToIntegralTiesToEven de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(round_nearest_even, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

 Autres exemples

rsqrt

Sémantique

Effectue une opération de racine carrée réciproque par élément sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : rSqrt d'IEEE-754.
  • Pour les nombres complexes : racine carrée réciproque complexe.
  • Pour les types quantifiés : dequantize_op_quantize(rsqrt, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

 Autres exemples

disperser

Sémantique

Génère des Tensors results égaux aux Tensors inputs, à la différence que plusieurs tranches spécifiées par scatter_indices sont mises à jour avec les valeurs updates à l'aide de update_computation.

Le schéma suivant montre comment les éléments de updates... sont mappés sur des éléments de results... à l'aide d'un exemple concret. Le diagramme sélectionne quelques exemples d'indices updates... et explique en détail à quels indices results... ils correspondent.

disperser

Plus formellement, pour tous les update_index de index_space(updates[0]) :

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index est défini comme suit :
    • scatter_indices[si0, ..., :, ..., siN], où si correspond à des éléments individuels dans update_scatter_index et : est inséré à l'index index_vector_dim, si index_vector_dim < rank(scatter_indices).
    • Sinon, [scatter_indices[update_scatter_index]].
  • Pour d_input dans axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] si d_input = scatter_dims_to_operand_dims[d_start].
    • Sinon, full_start_index[d_input] = 0.
  • Pour d_input dans axes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] si d_input = input_batching_dims[i_batching] et d_start = scatter_indices_batching_dims[i_batching].
    • Sinon, full_batching_index[d_input] = 0.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], où wi sont des éléments individuels dans update_window_index, et 0 est inséré aux indices de inserted_window_dims et input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

Par conséquent, results = exec(schedule, inputs), où:

  • schedule est une permutation de index_space(updates[0]) définie par l'implémentation.
  • exec([update_index, ...], results) = exec([...], updated_results) où :
    • Si result_index est dans les limites de shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results est une copie de results avec results...[result_index] défini sur updated_values....
    • Sinon, procédez comme suit :
    • updated_results = results.
  • exec([], results) = results.

Si indices_are_sorted est défini sur true, l'implémentation peut supposer que les éléments scatter_indices sont triés par rapport à scatter_dims_to_operand_dims. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Si unique_indices est true, l'implémentation peut supposer que tous les indices result_index vers lesquels la diffusion est effectuée sont uniques. Si unique_indices est true, mais que les indices vers lesquels les données sont dispersées ne sont pas uniques, le comportement est indéfini.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices tenseur de type entier (C4), (C15), (C19), (C22)
(I3) updates nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C3-C6), (C8)
(I4) update_window_dims Constante de tenseur à dimension 1 de type si64 (C2), (C4), (C7-C8)
(I5) inserted_window_dims Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims Constante de tenseur à dimension 1 de type si64 (C14-C18)
(I8) scatter_dims_to_operand_dims Constante de tenseur à dimension 1 de type si64 (C19-C21)
(I9) index_vector_dim constante de type si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted constante de type i1
(I11) unique_indices constante de type i1
(I12) update_computation fonction (C23)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors ou Tensors quantifiés par Tensor (C24-C25)

Contraintes

  • (C1) same(shape(inputs...)).
  • (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
    • size(input_batching_dims)`.
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) où :
    • update_scatter_dim_sizes = shape(scatter_indices), sauf que la taille de la dimension scatter_indices correspondant à index_vector_dim n'est pas incluse.
    • update_window_dim_sizes <= shape(inputs[0]), à l'exception des tailles de dimension dans inputs[0] correspondant à inserted_window_dims et input_batching_dims.
    • combine place update_scatter_dim_sizes sur les axes correspondant à update_scatter_dims et update_window_dim_sizes sur les axes correspondant à update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), où is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei pour tous les i dans [0,N).

Exemples

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

 Autres exemples

sélectionner

Sémantique

Génère un tenseur result dans lequel chaque élément est sélectionné à partir du tenseur on_true ou on_false en fonction de la valeur de l'élément correspondant de pred. Plus formellement, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], où pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Pour les types quantifiés, exécute dequantize_select_quantize(pred, on_true, on_false, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) pred tenseur de type i1 (C1)
(I2) on_true un tenseur ou un tenseur quantifié par tenseur ; (C1-C2)
(I3) on_false Tensor ou Tensor quantifié par Tensor (C2)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C2)

Contraintes

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Exemples

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

 Autres exemples

select_and_scatter

Sémantique

Disperse les valeurs du tenseur source à l'aide de scatter en fonction du résultat de reduce_window du tenseur input à l'aide de select et génère un tenseur result.

Le diagramme suivant montre comment les éléments de result sont calculés à partir de operand et source à l'aide d'un exemple concret.

select_and_scatter

Plus formellement:

  • selected_values = reduce_window_without_init(...) avec les entrées suivantes:

    • inputs = [operand].
    • window_dimensions, window_strides et padding, qui sont utilisés tels quels.
    • base_dilations = windows_dilations = 1.
    • body est défini comme suit:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    E = element_type(operand) et reduce_window_without_init fonctionnent exactement comme reduce_window, sauf que le schedule de l'reduce sous-jacent (voir reduce) n'inclut pas les valeurs d'initialisation. Il n'est actuellement pas spécifié ce qui se passe si la fenêtre correspondante ne comporte pas de valeurs (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) où :

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index si selected_values[source_index] contient l'élément operand de operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C1-C4), (C6), (C8-C11)
(I2) source un tenseur ou un tenseur quantifié par tenseur ; (C1), (C2)
(I3) init_value Tensor à 0 dimensions ou Tensor quantifié par Tensor (C3)
(I4) window_dimensions Constante de tenseur à dimension 1 de type si64 (C2), (C4), (C5)
(I5) window_strides Constante de tenseur à dimension 1 de type si64 (C2), (C6), (C7)
(I6) padding Constante de tenseur bidimensionnel de type si64 (C2), (C8)
(I7) select fonction (C9)
(I8) scatter fonction (C10)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C11-C12)

Contraintes

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows où :
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select est de type (tensor<E>, tensor<E>) -> tensor<i1>, où E = element_type(operand).
  • (C10) scatter est de type (tensor<E>, tensor<E>) -> tensor<E>, où is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Exemples

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Autres exemples

envoyer

Sémantique

Envoie inputs à un canal channel_id et génère un jeton result.

Si is_host_transfer est défini sur true, l'opération transfère les données à l'hôte. Sinon, il transfère les données vers un autre appareil. Cela signifie que la valeur est définie par l'implémentation. Cet indicateur duplique les informations fournies dans channel_type. Nous prévoyons donc de n'en conserver qu'un seul (#666).

Entrées

Libellé Nom Type Contraintes
(I1) inputs Nombre variable de tenseurs ou de tenseurs quantifiés
(I2) token token
(I3) channel_id constante de type si64
(I4) channel_type énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST (C1)
(I5) is_host_transfer constante de type i1 (C1)

Sorties

Nom Type
result token

Contraintes

  • (C1) channel_type est défini comme suit :
    • DEVICE_TO_HOST si is_host_transfer = true,
    • Sinon, DEVICE_TO_DEVICE.

Exemples

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

 Autres exemples

shift_left

Sémantique

Effectue une opération de décalage à gauche au niveau des éléments sur le tenseur lhs d'un nombre de bits rhs et produit un tenseur result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur de type entier (C1)
(I2) rhs tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

 Autres exemples

shift_right_arithmetic

Sémantique

Effectue une opération de décalage à droite arithmétique par élément sur le tenseur lhs d'un nombre de bits rhs et produit un tenseur result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur de type entier (C1)
(I2) rhs tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

 Autres exemples

shift_right_logical

Sémantique

Effectue une opération de décalage logique à droite par élément sur le tenseur lhs d'un nombre de bits rhs et produit un tenseur result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur de type entier (C1)
(I2) rhs tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

 Autres exemples

signer

Sémantique

Renvoie le signe de l'operand par élément et produit un tenseur result. Plus formellement, pour chaque élément x, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Pour les types quantifiés, exécute dequantize_op_quantize(sign, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type entier signé, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

 Autres exemples

sinus

Sémantique

Effectue une opération sinusoïdale au niveau des éléments sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : sin d'IEEE-754.
  • Pour les nombres complexes: sinus complexe.
  • Pour les types quantifiés: dequantize_op_quantize(sine, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

 Autres exemples

slice

Sémantique

Extraction d'une tranche de l'operand à l'aide d'indices de début calculés de manière statique et production d'un tenseur result. start_indices contient les indices de début de la tranche pour chaque dimension, limit_indices contient les indices de fin (exclusifs) de la tranche pour chaque dimension et strides contient les pas pour chaque dimension.

Plus formellement, result[result_index] = operand[operand_index], où operand_index = start_indices + result_index * strides.

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié par tenseur ; (C1-C3), (C5)
(I2) start_indices Constante de tenseur à dimension 1 de type si64 (C2), (C3), (C5)
(I3) limit_indices Constante de tenseur à dimension 1 de type si64 (C2), (C3), (C5)
(I4) strides Constante de tenseur à dimension 1 de type si64 (C2), (C4)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié par tenseur ; (C1), (C5)

Contraintes

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Exemples

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Autres exemples

trier

Sémantique

Trie les tranches unidimensionnelles de inputs le long de la dimension dimension, selon une valeur comparator, et génère results.

Contrairement aux entrées similaires d'autres opérations, dimension autorise les valeurs négatives, avec la sémantique décrite ci-dessous. À l'avenir, cela pourrait être interdit pour des raisons de cohérence (numéro 1377).

Si is_stable est défini sur "true", le tri est stable, c'est-à-dire que l'ordre relatif des éléments considérés comme égaux par le comparateur est préservé. Dans le cas où il n'y a qu'une seule entrée, deux éléments e1 et e2 sont considérés comme égaux par le comparateur si et seulement si comparator(e1, e2) = comparator(e2, e1) = false. Consultez la formalisation ci-dessous pour voir comment cela se généralise à plusieurs entrées.

Plus formellement, pour tous les result_index de index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1], où riN sont des éléments individuels dans result_index et : est inséré à adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • sort trie une tranche unidimensionnelle dans l'ordre non descendant, en supposant que comparator_together renvoie true si l'argument de gauche est inférieur au deuxième argument de droite.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C1-C5)
(I2) dimension constante de type si64 (C4)
(I3) is_stable constante de type i1
(I4) comparator fonction (C5)

Sorties

Nom Type Contraintes
results nombre variable de tenseurs ou tenseurs quantifiés par tenseur (C2), (C3)

Contraintes

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, où R = rank(inputs[0]).
  • (C5) comparator est de type (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, où Ei = element_type(inputs[i]).

Exemples

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Autres exemples

sqrt

Sémantique

Effectue une racine carrée élément par élément sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes:

  • Pour les nombres à virgule flottante : squareRoot d'IEEE-754.
  • Pour les nombres complexes : racine carrée complexe.
  • Pour les types quantifiés : dequantize_op_quantize(sqrt, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

 Autres exemples

subtract

Sémantique

Effectue la soustraction par élément de deux Tensors lhs et rhs, et produit un Tensor result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les entiers: soustraction d'entiers.
  • Pour les nombres à virgule flottante : subtraction d'IEEE-754.
  • Pour les nombres complexes : soustraction complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Autres exemples

tan

Sémantique

Effectue une opération tangente au niveau des éléments sur le tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : tan d'IEEE-754.
  • Pour les nombres complexes: tangente complexe.
  • Pour les types quantifiés : dequantize_op_quantize(tan, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

 Autres exemples

tanh

Sémantique

Effectue une opération de tangente hyperbolique par élément sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les nombres à virgule flottante : tanh d'IEEE-754.
  • Pour les nombres complexes: tangente hyperbolique complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(tanh, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

 Autres exemples

transposer

Sémantique

Permute les dimensions du tenseur operand à l'aide de permutation et génère un tenseur result. Plus formellement, result[result_index] = operand[operand_index], où result_index[d] = operand_index[permutation[d]].

Entrées

Libellé Nom Type Contraintes
(I1) operand un tenseur ou un tenseur quantifié ; (C1-C4)
(I2) permutation Constante de Tensor unidimensionnelle de type si64 (C2-C4)

Sorties

Nom Type Contraintes
result un tenseur ou un tenseur quantifié ; (C1), (C3-C4)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), sauf que quantization_dimension(operand) et quantization_dimension(result) peuvent être différents.
  • (C2) permutation est une permutation de range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Si is_per_axis_quantized(result), alors quantization_dimension(operand) = permutation(quantization_dimension(result)).

Exemples

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Autres exemples

triangular_solve

Sémantique

Résout des lots de systèmes d'équations linéaires avec des matrices de coefficients triangulaires inférieures ou supérieures.

Plus formellement, étant donné a et b, result[i0, ..., iR-3, :, :] est la solution de op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] lorsque left_side est true ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] lorsque left_side est false, en résolvant la variable xop(a) est déterminée par transpose_a, qui peut être l'une des valeurs suivantes:

  • NO_TRANSPOSE : effectuez l'opération à l'aide de a tel quel.
  • TRANSPOSE : effectue une opération sur la transposition de a.
  • ADJOINT : effectue une opération sur la transposée conjuguée de a.

Les données d'entrée ne sont lues que dans le triangle inférieur de a, si lower est true ou dans le triangle supérieur de a, dans le cas contraire. Les données de sortie sont renvoyées dans le même triangle. Les valeurs de l'autre triangle sont définies par l'implémentation.

Si unit_diagonal est vrai, l'implémentation peut supposer que les éléments de la diagonale de a sont égaux à 1, sinon le comportement est indéfini.

Pour les types quantifiés, exécute dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1-C3)
(I2) b un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1-C4)
(I3) left_side constante de type i1 (C3)
(I4) lower constante de type i1
(I5) unit_diagonal constante de type i1
(I6) transpose_a Énumération de NO_TRANSPOSE, TRANSPOSE et ADJOINT

Sorties

Nom Type Contraintes
result un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) La relation entre shape(a) et shape(b) est définie comme suit :
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Exemples

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Sémantique

Génère un tuple result à partir des valeurs val.

Entrées

Libellé Nom Type Contraintes
(I1) val nombre de valeurs variable (C1)

Sorties

Nom Type Contraintes
result tuple (C1)

Contraintes

  • (C1) result est de type tuple<E0, ..., EN-1>, où Ei = type(val[i]).

Exemples

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

 Autres exemples

uniform_dequantize

Sémantique

Effectue la conversion élément par élément du tenseur quantifié operand en tenseur à virgule flottante result, conformément aux paramètres de quantification définis par le type operand.

Plus formellement, result = dequantize(operand).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur quantifié (C1), (C2)

Sorties

Nom Type Contraintes
result tenseur de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Exemples

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Sémantique

Effectue une conversion par élément du tenseur à virgule flottante ou du tenseur quantifié operand en tenseur quantifié result, conformément aux paramètres de quantification définis par le type result.

Plus formellement,

  • Si is_float(operand) :
    • result = quantize(operand, type(result)).
  • Si is_quantized(operand) :
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur de type à virgule flottante ou quantifié (C1), (C2)

Sorties

Nom Type Contraintes
result tenseur quantifié (C1), (C2)

Contraintes

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Exemples

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

tandis que

Sémantique

Génère la sortie de l'exécution de la fonction body au moins une fois lorsque la fonction cond génère true. Plus formellement, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Le comportement d'une boucle infinie est à déterminer (numéro 383).

Entrées

Libellé Nom Type Contraintes
(I1) operand Nombre variable de tenseurs, de tenseurs quantifiés ou de jetons (C1-C3)
(I2) cond fonction (C1)
(I3) body fonction (C2)

Sorties

Nom Type Contraintes
results nombre variadique de Tensors, Tensors quantifiés ou jetons (C3)

Contraintes

  • (C1) cond est de type (T0, ..., TN-1) -> tensor<i1>, où Ti = type(operand[i]).
  • (C2) body est de type (T0, ..., TN-1) -> (T0, ..., TN-1), où Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Exemples

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

 Autres exemples

xor

Sémantique

Effectue une opération XOR élément par élément de deux tenseurs lhs et rhs, et génère un tenseur result. En fonction du type d'élément, effectue les opérations suivantes :

  • Pour les valeurs booléennes : XOR logique.
  • Pour les entiers : XOR (OU exclusif) bit à bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs un tenseur de type booléen ou entier ; (C1)
(I2) rhs un tenseur de type booléen ou entier ; (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

 Autres exemples

Interopérabilité des dialectes

Pour le moment, les programmes StableHLO dans la nature contiennent parfois des opérations qui ne sont pas définies par StableHLO.

Module, fonction, appel et retour

StableHLO utilise les opérations MLIR en amont pour ModuleOp, FuncOp, CallOp et ReturnOp. Cela a été fait pour améliorer l'interopérabilité avec le mécanisme MLIR existant, car de nombreuses passes utiles sont écrites en ciblant FuncOp et ModuleOp, et de nombreux pipelines de compilation s'attendent à ce que ces opérations soient présentes. Des garanties de compatibilité complète sont appliquées à ces opérations. Si ces opérations changent de manière incompatible (par exemple, suppression), des équivalents StableHLO seront ajoutés pour préserver la compatibilité.

CHLO

L'ensemble d'opérations CHLO contient des opérations de niveau supérieur qui se décomposent en StableHLO. Actuellement, aucune garantie de compatibilité n'est proposée pour les CHLO. Pour garantir la compatibilité, la passe clo-legalize-to-stablehlo doit être utilisée avant la sérialisation.

Opérations de forme

Dans la communauté, il est courant d'utiliser certaines opérations des dialectes MLIR de base dans des programmes StableHLO dynamiques pour effectuer des calculs de forme. Le plus souvent, il s'agit d'opérations de dialecte shape telles que shape_of ou num_elements, d'opérations de dialecte tensor telles que dim ou from_elements, et du type index intégré.

Le document Dynamism RFC > O2 les indique comme étant hors du champ d'application, mais certains types de prise en charge des types index sont inclus à des fins d'interopérabilité. Aucune garantie de compatibilité n'est fournie pour ces opérations ou types. La passe shape-legalize-to-stablehlo peut être utilisée pour convertir ces opérations en opérations StableHLO entièrement compatibles.

Opérations obsolètes

Plusieurs opérations StableHLO héritées de MHLO sont obsolètes et en cours d'abandon de StableHLO. Pour en savoir plus sur ces suppressions, consultez la page Nettoyage de StableHLO v1.0 2283. Le problème de suivi de ces abandons est le n° 2340.

Ces opérations se répartissent en plusieurs catégories:

  • Catégorie "Not in HLO" (Pas dans HLO) des opérations StableHLO : elles faisaient initialement partie de l'ensemble d'opérations StableHLO, mais ont ensuite été jugées inadaptées : broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (n° 3).
  • Opérations inutilisées : ces opérations ont peut-être été utiles à un moment donné, mais elles n'étaient pas suffisamment développées ou les pipelines qui les utilisaient ont été refactorisés pour ne plus en avoir besoin. Cela inclut map, tuple (598), get_tuple_element, rng, les comparaisons complex 560 et la convolution window_reversal (1181).

Certaines de ces opérations peuvent être facilement supprimées, car elles peuvent être exprimées à l'aide d'opérations existantes (broadcast, create_token, cross-replica-sum, dot, unary_einsum) et seront supprimées une fois la période de compatibilité existante écoulée (six mois). D'autres sont encore à l'étude pour être supprimées (einsum, get_tuple_element, map, rng, torch_index_select, tuple, complex, comparaisons, window_reversal). En attendant les commentaires de la communauté, ces opérations seront supprimées ou ajoutées à la spécification avec une prise en charge complète. Tant que ces futures opérations ne sont pas connues, leur compatibilité n'est garantie que pendant six mois.

Exécution

Exécution séquentielle

Un programme StableHLO est exécuté en fournissant des valeurs d'entrée à la fonction main et en calculant les valeurs de sortie. Les valeurs de sortie d'une fonction sont calculées en exécutant le graphique des opérations enracinées dans l'opération return correspondante.

L'ordre d'exécution est défini par l'implémentation tant qu'il est aligné sur le flux de données, c'est-à-dire si les opérations sont exécutées avant leur utilisation. Dans StableHLO, toutes les opérations à effet secondaire consomment un jeton et en produisent un (plusieurs jetons peuvent être multiplexés en un seul jeton via after_all). L'ordre d'exécution des effets secondaires est donc également aligné sur le flux de données. Par exemple, dans le programme ci-dessous, deux ordres d'exécution sont possibles: %0%1%2return et %1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Plus formellement, un processus StableHLO est une combinaison de : 1) un programme StableHLO, 2) des états d'exécution (non encore exécuté, déjà exécuté) et 3) des valeurs intermédiaires sur lesquelles le processus travaille. Le processus commence par les valeurs d'entrée de la fonction main, passe par le graphique des opérations mettant à jour les états des opérations et les valeurs intermédiaires, puis se termine par les valeurs de sortie. À déterminer pour une formalisation plus poussée (#484)

Exécution parallèle

Les programmes StableHLO peuvent être exécutés en parallèle, organisés dans une grille de processus 2D de num_replicas par num_partitions, qui ont tous deux le type ui32.

Dans la grille de processus StableHLO, num_replicas * num_partitions des processus StableHLO s'exécutent en même temps. Chaque processus possède un process_id = (replica_id, partition_id) unique, où replica_id dans replica_ids = range(num_replicas) et partition_id dans partition_ids = range(num_partitions), qui sont tous deux de type ui32.

La taille de la grille de processus est connue de manière statique pour chaque programme (à l'avenir, nous prévoyons de la rendre explicite dans les programmes StableHLO 650), et la position dans la grille de processus est connue de manière statique pour chaque processus. Chaque processus a accès à sa position dans la grille de processus via les opérations replica_id et partition_id.

Dans la grille de processus, les programmes peuvent tous être identiques (dans le style "Programme unique, données multiples"), tous différents (dans le style "Programmes multiples, données multiples") ou quelque chose entre les deux. À l'avenir, nous prévoyons de prendre en charge d'autres idiomes permettant de définir des programmes StableHLO parallèles, y compris GSPMD (#619).

Dans la grille de processus, les processus sont pour la plupart indépendants les uns des autres. Ils ont des états d'opération distincts, des valeurs d'entrée/intermédiaire/sortie distinctes, et la plupart des opérations sont exécutées séparément entre les processus, à l'exception d'un petit nombre d'opérations collectives décrites ci-dessous.

Étant donné que l'exécution de la plupart des opérations n'utilise que les valeurs du même processus, il est généralement clair de faire référence à ces valeurs par leur nom. Toutefois, lorsque vous décrivez la sémantique des opérations collectives, cela est insuffisant, et la notation name@process_id est utilisée pour faire référence à la valeur name dans un processus particulier. (Dans cette perspective, name non qualifié peut être considéré comme un raccourci pour name@(replica_id(), partition_id()).)

L'ordre d'exécution entre les processus est défini par l'implémentation, à l'exception de la synchronisation introduite par la communication point à point et les opérations collectives, comme décrit ci-dessous.

Communication point à point

Les processus StableHLO peuvent communiquer entre eux via des canaux StableHLO. Un canal est représenté par un identifiant positif de type si64. Grâce à diverses opérations, il est possible d'envoyer des valeurs aux canaux et de les recevoir de ces derniers.

Une formalisation plus poussée, par exemple sur l'origine de ces ID de canal, la façon dont les programmes de processus en prennent connaissance et le type de synchronisation qu'ils introduisent, est à définir (numéro 484).

Communication en streaming

Chaque processus StableHLO a accès à deux interfaces de streaming :

  • Infeed, qui peut être lu.
  • Flux de sortie dans lequel des opérations d'écriture peuvent être effectuées.

Contrairement aux canaux, qui sont utilisés pour communiquer entre les processus et qui ont donc des processus à leurs deux extrémités, les flux entrants et sortants ont leur autre extrémité définie par l'implémentation.

Une formalisation plus poussée, par exemple sur l'impact de la communication en streaming sur l'ordre d'exécution et le type de synchronisation qu'elle introduit, est à définir (numéro 484).

Opérations collectives

StableHLO propose six opérations collectives: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute et reduce_scatter. Toutes ces opérations divisent les processus de la grille de processus StableHLO en groupes de processus StableHLO et exécutent un calcul conjoint dans chaque groupe de processus, indépendamment des autres groupes de processus.

Dans chaque groupe de processus, les opérations collectives peuvent entraîner une barrière de synchronisation. Une formalisation plus poussée, par exemple en précisant le moment exact où cette synchronisation se produit, comment les processus arrivent exactement à cette barrière et ce qui se passe s'ils ne le font pas, est à déterminer (numéro 484).

Si le groupe de processus implique une communication interpartition, c'est-à-dire que des processus du groupe de processus ont des ID de partition différents, l'exécution de l'opération collective nécessite un canal, et l'opération collective doit fournir un channel_id positif de type si64. La communication entre les réplications n'a pas besoin de canaux.

Les calculs effectués par les opérations collectives sont spécifiques aux opérations individuelles et sont décrits dans les sections d'opérations individuelles ci-dessus. Toutefois, les stratégies selon lesquelles la grille de processus est divisée en groupes de processus sont partagées entre ces opérations et sont décrites dans cette section. Plus formellement, StableHLO prend en charge les quatre stratégies suivantes.

cross_replica

Seules les communications interrépliques ont lieu au sein de chaque groupe de processus. Cette stratégie prend replica_groups (une liste de listes d'ID de réplicas) et calcule un produit cartésien de replica_groups par partition_ids. Les replica_groups doivent comporter des éléments uniques et couvrir tous les replica_ids. Plus formellement, en utilisant la syntaxe Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica génère [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Seules les communications interpartitions ont lieu au sein de chaque groupe de processus. Cette stratégie prend partition_groups (une liste de listes d'ID de partition) et calcule un produit cartésien de partition_groups par replica_ids. Les partition_groups doivent comporter des éléments uniques et couvrir tous les partition_ids. Plus formellement, en utilisant la syntaxe Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Par exemple, pour partition_groups = [[0, 1]] et num_replicas = 4, cross_partition produit [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

Des communications inter-réplicas et inter-partitions peuvent se produire dans chaque groupe de processus. Cette stratégie prend replica_groups, une liste de listes d'ID de réplication, et calcule les produits cartésiens de chaque replica_group par partition_ids. Les replica_groups doivent comporter des éléments uniques et couvrir tous les replica_ids. Plus formellement, en utilisant la syntaxe Python :

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica_and_partition produit [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Cette stratégie prend flattened_id_groups (une liste de listes d'ID de processus "aplatis" sous la forme de replica_id * num_partitions + partition_id) et les transforme en ID de processus. flattened_id_groups doit comporter des éléments uniques et couvrir tous les process_ids. Plus formellement, en utilisant la syntaxe Python :

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Par exemple, pour flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 et num_partitions = 2, flattened_ids produit [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Précision

Pour le moment, StableHLO ne fournit aucune garantie concernant l'exactitude numérique, mais cela peut changer à l'avenir (numéro 1156).

Sémantique d'exécution de l'opération quantifiée

L'interprétation des opérations StableHLO quantifiées peut varier en fonction des capacités et de la configuration matérielle requise. Par exemple, certains matériels peuvent choisir d'interpréter les opérations quantiques à l'aide d'une stratégie de "déquantisation, exécution d'une opération à virgule flottante et, enfin, quantification". D'autres peuvent effectuer l'intégralité du calcul avec une méthode arithmétique d'entiers. Par conséquent, l'interprétation des opérations StableHLO échantillonnées est déterminée exclusivement par l'implémentation spécifique. L'interprétation de la quantification hybride (#1575) doit être basée sur sa sémantique, telle que prescrite dans la spécification (via 1792).

Erreurs

Les programmes StableHLO sont validés via un ensemble étendu de contraintes pour les opérations individuelles, ce qui exclut de nombreuses classes d'erreurs avant l'exécution. Toutefois, les conditions d'erreur sont toujours possibles, par exemple via des débordements d'entiers, des accès hors limites, etc. Sauf mention explicite, toutes ces erreurs entraînent un comportement défini par l'implémentation, mais cela peut changer à l'avenir (numéro 1157).

Exceptions à virgule flottante

À titre d'exception à cette règle, les exceptions à virgule flottante dans les programmes StableHLO ont un comportement bien défini. Les opérations qui génèrent des exceptions définies par la norme IEEE-754 (opération non valide, division par zéro, débordement, sous-dépassement ou exceptions inexactes) génèrent des résultats par défaut (tels que définis dans la norme) et poursuivent l'exécution sans lever le drapeau d'état correspondant, semblable à la gestion des exceptions raiseNoFlag de la norme. Les exceptions pour les opérations non standards (par exemple, les opérations arithmétiques complexes et certaines fonctions transcendantes) sont définies par l'implémentation.

Incohérences de forme

StableHLO accepte les Tensors de forme dynamique. Toutefois, les formes doivent être cohérentes au moment de l'exécution, sinon le comportement n'est pas défini. StableHLO ne fournit pas explicitement une opération pouvant affirmer qu'un tenseur a une forme donnée au moment de l'exécution. La génération du code correct relève de la responsabilité du producteur.

Par exemple, le programme ci-dessous est valide. Toutefois, au moment de l'exécution, les formes exactes de %arg0 et %arg1 doivent être identiques, sinon le comportement du programme n'est pas défini:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notation

Pour décrire la syntaxe, ce document utilise le type ISO modifié de la syntaxe EBNF (ISO/IEC 14977:1996, Wikipédia), avec deux modifications: 1) les règles sont définies à l'aide de ::= plutôt que de =,

2) La concaténation est exprimée à l'aide de la juxtaposition plutôt que de ,.

Pour décrire la sémantique (c'est-à-dire dans les sections "Types", "Constants" et "Ops"), nous utilisons des formules basées sur la syntaxe Python étendue avec la prise en charge de l'expression concise des opérations de tableau, comme décrit ci-dessous. Cela fonctionne bien pour les petits extraits de code, mais dans de rares cas où des extraits de code plus volumineux sont nécessaires, nous utilisons la syntaxe Python standard, qui est toujours introduite explicitement.

Formules

Découvrons comment fonctionnent les formules à l'aide d'un exemple tiré des spécifications dot_general. L'une des contraintes de cette opération se présente comme suit : dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Les noms utilisés dans cette formule proviennent de deux sources: 1) les fonctions globales, c'est-à-dire dim, 2) les définitions des membres de l'élément de programme correspondant, c'est-à-dire les entrées lhs, lhs_batching_dimensions, rhs et rhs_batching_dimensions définies dans la section "Inputs" (Entrées) de dot_general.

Comme indiqué ci-dessus, la syntaxe de cette formule est basée sur Python, avec certaines extensions axées sur la concision. Pour comprendre la formule, transformons-la en syntaxe Python standard.

A) Dans ces formules, nous utilisons = pour représenter l'égalité. La première étape pour obtenir la syntaxe Python consiste donc à remplacer = par ==, comme suit : dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) De plus, ces formules acceptent les points de suspension (...) qui transforment les expressions scalaires en expressions de Tensor. En résumé, f(xs...) signifie approximativement "pour chaque scalaire x dans le tenseur xs, calculer un scalaire f(x), puis renvoyer tous ces résultats scalaires ensemble en tant que résultat de tenseur". Dans la syntaxe Python standard, notre exemple de formule se transforme en : [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Grâce aux ellipses, il est souvent possible d'éviter de travailler au niveau des scalaires individuels. Toutefois, dans certains cas délicats, une syntaxe semi-informelle de niveau inférieur peut être utilisée, comme dans la formule start_indices[bi0, ..., :, ..., biN] de la spécification gather. Par souci de concision, nous ne fournissons pas de formalisme exact pour traduire cette syntaxe en Python standard, dans l'espoir qu'elle reste intuitivement compréhensible au cas par cas. N'hésitez pas à nous contacter si certaines formules spécifiques semblent opaques. Nous ferons de notre mieux pour les améliorer.

Vous remarquerez également que les formules utilisent des points de suspension pour développer toutes sortes de listes, y compris des tenseurs, des listes de tenseurs (qui peuvent par exemple provenir d'un nombre variadique de tenseurs), etc. Il s'agit d'un autre domaine dans lequel nous ne fournissons pas de formalisme exact (par exemple, les listes ne font même pas partie du système de types StableHLO) et nous nous appuyons plutôt sur une compréhension intuitive.

C) Le dernier mécanisme de notation que nous utilisons est la diffusion implicite. Bien que l'opset StableHLO n'accepte pas la diffusion implicite, les formules le font, également au service de la concision. En résumé, si un scalaire est utilisé dans un contexte où un tenseur est attendu, le scalaire est diffusé dans la forme attendue.

Pour poursuivre l'exemple dot_general, voici une autre contrainte : 0 <= lhs_batching_dimensions < rank(lhs). Comme défini dans la spécification dot_general, lhs_batching_dimensions est un tenseur, mais 0 et rank(lhs) sont des scalaires. Une fois la diffusion implicite appliquée, la formule devient [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Lorsqu'elle est appliquée à une opération dot_general particulière, cette formule est évaluée selon un Tensor de valeurs booléennes. Lorsque des formules sont utilisées comme contraintes, la contrainte est valide si la formule s'évalue à true ou à un tenseur qui ne comporte que des éléments true.

Noms

Dans les formules, le champ lexical inclut : 1) les fonctions globales, 2) les définitions de membres,

3) définitions locales. La liste des fonctions globales est fournie ci-dessous. La liste des définitions d'éléments dépend de l'élément du programme auquel la notation est appliquée:

  • Pour les opérations, les définitions des membres incluent les noms introduits dans les sections "Entrées" et "Sorties".
  • Pour tout le reste, les définitions de membres incluent des parties structurelles de l'élément de programme, nommées d'après les non-terminaux EBNF correspondants. La plupart du temps, les noms de ces parties structurelles sont obtenus en convertissant les noms des non-terminaux en snake case (par exemple, IntegerLiteral => integer_literal), mais parfois les noms sont abrégés dans le processus (par exemple, QuantizationStorageType => storage_type). Dans ce cas, les noms sont introduits explicitement de la même manière que les sections "Inputs" / "Outputs" dans les opérations et les sorties.
  • De plus, les définitions de membres incluent toujours self pour faire référence à l'élément de programme correspondant.

Valeurs

Lorsque les formules sont évaluées, elles fonctionnent avec les types de valeurs suivants : 1) Value (valeurs réelles, par exemple dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> ; leurs types sont toujours connus), 2) Placeholder (valeurs futures, par exemple lhs, rhs ou result ; leurs valeurs réelles ne sont pas encore connues, seuls leurs types le sont), 3) Type (types tels que définis dans la section "Types"), 4) Function (fonctions globales telles que définies dans la section "Fonctions").

Selon le contexte, les noms peuvent faire référence à différentes valeurs. Plus précisément, la section "Sémantique" des opérations (et des équivalents pour les autres éléments de programme) définit la logique d'exécution. Toutes les entrées sont donc disponibles en tant que Value. En revanche, la section "Contraintes" des opérations (et des équivalents) définit une logique "au moment de la compilation", c'est-à-dire quelque chose qui est généralement exécuté avant l'exécution. Par conséquent, seules les entrées constantes sont disponibles en tant que Value et les autres entrées ne sont disponibles qu'en tant que Placeholder.

Noms Dans "Sémantique" Dans "Contraintes"
Fonctions globales Function Function
Entrées constantes Value Value
Entrées non constantes Value Placeholder
Sorties Value Placeholder
Définitions locales Cela dépend de la définition Cela dépend de la définition

Prenons un exemple d'opération transpose :

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Pour cette opération, permutation est une constante. Il est donc disponible en tant que Value, tant au niveau de la sémantique que des contraintes. En revanche, operand et result sont disponibles en tant que Value en sémantique, mais uniquement en tant que Placeholder dans les contraintes.

Fonctions

Construction de types

Aucune fonction ne peut être utilisée pour créer des types. À la place, nous utilisons directement la syntaxe de type, car elle est généralement plus concise. Par exemple, (tensor<E>, tensor<E>) -> (tensor<E>) plutôt que function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Fonctions sur les types

  • element_type est défini sur les types de tenseurs et les types de tenseurs quantifiés, et renvoie respectivement la partie TensorElementType ou QuantizedTensorElementType de l'TensorType ou de l'QuantizedTensorType correspondant.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool vérifie si le type x peut être promu au type y. Lorsque x et y sont des QuantizedTensorElementType, la promotion ne s'applique qu'au storage_type. Cette version spécifique de la promotion est actuellement utilisée dans le contexte du calcul de la réduction (pour en savoir plus, consultez la RFC).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Disponible pour tous les types. Par exemple, is_float(x) renvoie true si x est un FloatType. Si x est une valeur ou un espace réservé, cette fonction est un raccourci pour is_type_name(type(x)).

  • max_value(x: Type) -> Value renvoie la valeur maximale d'un TensorElementType. Si x n'est pas un TensorElementType, renvoie None.

  • min_value(x: Type) -> Value renvoie la valeur minimale possible d'un TensorElementType. Si x n'est pas un TensorElementType, renvoie None.

  • member_name(x: Value | Placeholder | Type) -> Any. Disponible pour toutes les définitions de membres member_name de tous types. Par exemple, tensor_element_type(x) renvoie la partie TensorElementType d'un TensorType correspondant. Si x est une valeur ou un espace réservé, cette fonction est un raccourci pour member_name(type(x)). Si x n'est pas un type qui possède un membre approprié, ou une valeur ou un espace réservé de ce type, renvoie None.

  • is_empty_algorithm(*args: Type) vérifie si tous les champs de l'algorithme de point sont définis sur None. Cela est nécessaire, car les algorithmes de point ont des comportements par défaut définis par l'implémentation. Il serait donc incorrect de spécifier une valeur par défaut.

Construction des valeurs

  • operation_name(*xs: Value | Type) -> Value. Disponible pour toutes les opérations. Par exemple, add(lhs, rhs) prend deux valeurs de tenseur lhs et rhs, puis renvoie le résultat de l'évaluation de l'opération add avec ces entrées. Pour certaines opérations (par exemple, broadcast_in_dim), leurs sorties sont de type "chargeur", c'est-à-dire nécessaires pour évaluer une opération. Dans ce cas, la fonction utilise ces types comme arguments.

Fonctions sur les valeurs

  • Tous les opérateurs et fonctions de Python sont disponibles. Par exemple, les notations subscription (abonnement) et slicing (tranchage) de Python sont disponibles pour indexer des tenseurs, des tenseurs quantiques et des tupels.

  • to_destination_type(x: Value, destination_type: Type) -> Value est défini sur les tenseurs et renvoie la valeur convertie de x en fonction de type(x) et destination_type comme suit:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Nous discutons actuellement de la fusion des opérations convert, uniform_quantize et uniform_dequantize (numéro 1576). Après la fusion, nous n'avons plus besoin de la fonction ci-dessus et pouvons utiliser le nom de l'opération pour convert à la place.

  • is_nan(x: Value) -> Value est défini sur les Tensors et renvoie true si tous les éléments de x sont NaN ou false dans les autres cas. Si x n'est pas un tenseur, renvoie None.

  • is_sorted(x: Value) -> Value est défini sur les tenseurs et renvoie true si les éléments de x sont triés par ordre croissant par rapport à l'ordre lexicographique croissant de leurs indices, ou false dans le cas contraire. Si x n'est pas un tenseur, renvoie None.

  • is_unique(x: Value) -> Value est défini sur les tenseurs et renvoie true si x ne comporte pas d'éléments en double, ou false dans le cas contraire. Si x n'est pas un tenseur, renvoie None.

  • member_name(x: Value) -> Any est défini pour toutes les définitions de membres member_name de toutes les valeurs. Par exemple, real_part(x) renvoie la partie RealPart d'un ComplexConstant correspondant. Si x n'est pas une valeur associée à un membre approprié, renvoie None.

  • same(x: Value) -> Value est défini sur les Tensors et renvoie true si les éléments de x sont tous égaux les uns aux autres, ou false dans le cas contraire. Si le tenseur ne comporte pas d'éléments, il est considéré comme "tous égaux les uns aux autres", c'est-à-dire que la fonction renvoie true. Si x n'est pas un Tensor, la fonction renvoie None.

  • split(x: Value, num_results: Value, axis: Value) -> Value est défini sur les tenseurs et renvoie des tranches num_results de x le long de l'axe axis. Si x n'est pas un tenseur ou dim(x, axis) % num_results != 0, renvoie None.

  • is_defined_in_parent_scope(x: Value) -> Value est défini sur des chaînes et renvoie true si x est le nom d'une fonction définie dans le même champ d'application que la fonction parente de l'opération concernée.

  • is_namespaced_op_name(x: Value) -> Value est défini sur des chaînes et renvoie true si x est un nom d'opération valide, c'est-à-dire qu'il respecte l'expression régulière suivante : [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

Calculs de forme

  • axes(x: Value | Placeholder | Type) -> Value est un raccourci pour range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value est un raccourci pour shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List est un raccourci pour list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value est défini sur les Tensors et renvoie les index size(x) pour le TensorType correspondant, trié par ordre lexicographique croissant, c'est-à-dire [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Si x n'est pas un type de tenseur, un type de tenseur quantifié, une valeur ou un espace réservé de l'un de ces types, None est renvoyé.

  • rank(x: Value | Placeholder | Type) -> Value est un raccourci pour size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value est défini dans la section "Fonctions sur les types" via member_name.

  • size(x: Value | Placeholder | Type) -> Value est un raccourci pour reduce(lambda x, y: x * y, shape(x)).

Calculs de quantification

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type est un raccourci pour element_type(baseline_type(x)).

  • baseline_type est défini sur les types de tenseurs et les types de tenseurs quantifiés, et les transforme en "ligne de base", c'est-à-dire en un type de la même forme, mais avec les paramètres de quantification du type d'élément réinitialisés aux valeurs par défaut. Il s'agit d'une astuce pratique pour comparer les types de Tensors et quantifiés de manière uniforme, ce qui est assez souvent nécessaire. Pour les types quantifiés, cela permet de comparer les types en ignorant les paramètres de quantification, c'est-à-dire que shape, storage_type, expressed_type, storage_min, storage_max et quantization_dimension (pour le type quantifié par axe) doivent tous correspondre, mais scales et zero points peuvent différer.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize est défini sur les types de tenseurs quantifiés et les convertit en types de tenseurs à virgule flottante. Pour ce faire, les éléments quantifiés, qui représentent les valeurs entières du type de stockage, sont convertis en valeurs à virgule flottante correspondantes du type exprimé à l'aide du point zéro et de la mise à l'échelle associés au type d'élément quantifié.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize est défini sur les types de tenseurs à virgule flottante et les transforme en types de tenseurs quantifiés. Cela se produit en convertissant les valeurs à virgule flottante du type exprimé en valeurs entières correspondantes du type de stockage en utilisant le point zéro et l'échelle associées au type d'élément quantifié.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize permet de spécifier des calculs par élément sur des tenseurs quantifiés. Il déquantifie, c'est-à-dire qu'il transforme les éléments quantifiés en types exprimés, puis effectue une opération, puis quantifie, c'est-à-dire qu'il transforme les résultats en types de stockage. Pour le moment, cette fonction ne fonctionne que pour la quantification par tenseur. La quantification par axe est en cours (numéro 1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op permet de spécifier une quantification uniquement pour les poids pour une opération hybride qui accepte la valeur gauche en virgule flottante et la valeur droite en types quantifiés. Il déquantifie les entrées quantifiées en types exprimés et effectue des calculs avec float. Le type d'élément du tenseur de gauche à virgule flottante et le type exprimé du tenseur de droite quantifié doivent être identiques.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Calculs de grille

  • cross_partition(replica_groups: Value) -> Value. Consultez la section "cross_replica" ci-dessus.

  • cross_replica(replica_groups: Value) -> Value. Consultez la section "cross_replica" ci-dessus.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Consultez la section "cross_replica_and_partition" ci-dessus.

  • flattened_ids(replica_groups: Value) -> Value. Consultez la section "flattened_ids" ci-dessus.

Dynamique

Les valeurs StableHLO peuvent avoir des tailles de dimension dynamiques, par exemple tensor<?xi64>. Toutefois, les valeurs StableHLO ne peuvent pas avoir un nombre dynamique de dimensions (dynamisme non classé, par exemple tensor<*xi64>). Les opérandes et les résultats sont autorisés à utiliser des tailles de dimension dynamiques, même si des contraintes s'appliquent à ces tailles. Les contraintes sont vérifiées de manière statique si possible. Sinon, elles sont différées jusqu'à l'exécution, et les incohérences entraînent un comportement non défini. Vous trouverez des exemples ci-dessous.

Incohérences de forme pour les opérations unaires par élément

Prenons l'exemple du programme de jouets suivant:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Un tel programme est inhabituel, car il n'est pas courant de connaître la forme du résultat, mais pas celle de l'entrée. Néanmoins, il s'agit d'un programme StableHLO valide. Il n'est pas possible de valider de manière statique l'opération abs dans ce programme, car la forme exacte de l'opérande est inconnue. Toutefois, les formes sont certainement compatibles, et cela peut être vérifié de manière statique : ? peut s'avérer être 2 au moment de l'exécution, et il n'y aura aucun problème. Toutefois, ? peut également s'avérer être un autre entier, auquel cas le comportement est indéfini.

Notez que si la taille d'une dimension est dynamique dans le résultat, il ne peut pas y avoir de comportement indéfini. En effet, il n'existe pas de taille "attendue", il ne peut donc pas y avoir de non-concordance.

Incohérences de forme pour les opérations binaires par élément

Prenons l'exemple de programme suivant :

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

En ce qui concerne les opérations binaires par élément, les formes des entrées et du résultat doivent correspondre au moment de l'exécution. Au moment de la compilation, les dimensions statiques doivent être égales, sinon elles doivent simplement être compatibles. Si n'importe quelle dimension est dynamique dans les entrées, il peut y avoir un comportement non défini au moment de l'exécution, car la taille dynamique peut ne pas correspondre à la taille correspondante dans l'autre opérande (statique ou dynamique). Si toutes les entrées sont statiques, le fait que le résultat soit dynamique ou non n'a pas d'importance: les dimensions connues de manière statique seront vérifiées de manière statique, et les dimensions dynamiques n'imposent aucune contrainte.

Incohérences de forme pour les opérations qui utilisent leur forme de sortie comme opérande

Prenons l'exemple du programme de jouets suivant:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

Les valeurs de l'opérande de forme au moment de l'exécution doivent correspondre à la forme du résultat, sinon le comportement n'est pas défini. Autrement dit, au moment de l'exécution, %arg0 doit avoir une valeur de dense<[3, 4]> : tensor<2xi32>. Si l'opérande de forme est une constante, cela peut être vérifié de manière statique. Si la forme du résultat est entièrement dynamique, il ne peut pas y avoir de divergence.