Spécification StableHLO

StableHLO est un ensemble d'opérations pour les opérations de haut niveau (HLO) dans les modèles de machine learning (ML). StableHLO sert de couche de portabilité entre différents frameworks et compilateurs de ML : les frameworks de ML qui produisent des programmes StableHLO sont compatibles avec les compilateurs de ML qui consomment des programmes StableHLO.

Notre objectif est de simplifier et d'accélérer le développement du ML en créant une meilleure interopérabilité entre les différents frameworks de ML (tels que TensorFlow, JAX et PyTorch) et les compilateurs de ML (tels que XLA et IREE). À cette fin, ce document fournit une spécification pour le langage de programmation StableHLO.

Cette spécification comporte trois sections principales. Tout d'abord, la section Programmes décrit la structure des programmes StableHLO, qui se composent de fonctions StableHLO, elles-mêmes composées d'opérations StableHLO. Dans cette structure, la section Ops spécifie la sémantique des opérations individuelles. La section Exécution fournit la sémantique pour toutes ces opérations exécutées ensemble dans un programme. Enfin, la section Notation aborde la notation utilisée tout au long de la spécification.

Pour afficher les spécifications d'une version précédente de StableHLO, ouvrez le dépôt au niveau de la version taguée qui vous intéresse. Par exemple, la spécification StableHLO v0.19.0. Pour afficher les modifications apportées à chaque mise à jour mineure de StableHLO, consultez le journal des versions dans VhloDialect.td.

Programmes

Program ::= {Func}

Les programmes StableHLO se composent d'un nombre arbitraire de fonctions StableHLO. Vous trouverez ci-dessous un exemple de programme avec une fonction @main qui comporte trois entrées (%image, %weights et %bias) et une sortie. Le corps de la fonction comporte six opérations.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Fonctions

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Les fonctions StableHLO (également appelées fonctions nommées) comportent un identifiant, des entrées/sorties et un corps. À l'avenir, nous prévoyons d'introduire des métadonnées supplémentaires pour les fonctions afin d'améliorer la compatibilité avec HLO (#425, #626, #740, #744).

Identifiants

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Les identifiants StableHLO sont semblables aux identifiants de nombreux langages de programmation, avec deux particularités : 1) tous les identifiants ont des sigles qui distinguent les différents types d'identifiants, 2) les identifiants de valeur peuvent être entièrement numériques pour simplifier la génération de programmes StableHLO.

Types

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Les types StableHLO sont classés dans les types de valeurs (également appelés types de première classe) qui représentent les valeurs StableHLO et les types non-valeurs qui décrivent d'autres éléments du programme. Les types StableHLO sont semblables à ceux de nombreux langages de programmation, la principale particularité étant la nature spécifique au domaine de StableHLO, qui entraîne des résultats inhabituels (par exemple, les types scalaires ne sont pas des types de valeurs).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

Les types Tensor représentent des tenseurs, c'est-à-dire des tableaux multidimensionnels. Ils ont une forme et un type d'élément, où une forme représente des tailles de dimensions non négatives ou inconnues dans l'ordre croissant des dimensions correspondantes (également appelées axes) numérotées de 0 à R-1. Le nombre de dimensions R est appelé rang. Par exemple, tensor<2x3xf32> est un type de Tensor avec une forme 2x3 et un type d'élément f32. Il comporte deux dimensions (ou, en d'autres termes, deux axes) : la dimension 0 et la dimension 1, dont les tailles sont respectivement de 2 et 3. Son rang est 2.

Les formes peuvent être partiellement ou complètement inconnues (dynamiques), par exemple tensor<?x2xf64> est partiellement inconnue et tensor<?x?xf64> est complètement inconnue. Les tailles de dimensions dynamiques sont représentées à l'aide d'un ?. Il n'est pas possible de supprimer le classement des formes.

À l'avenir, nous prévoyons d'étendre les types de Tensor au-delà des tailles de dimension et des types d'éléments, par exemple pour inclure les mises en page (#629) et la parcimonie (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nom Type Contraintes
storage_type type entier (C1-C3), (C8)
storage_min constante entière (C1), (C3), (C7)
storage_max constante entière (C2), (C3), (C7)
expressed_type type à virgule flottante (C4)
quantization_dimension constante entière facultative (C10-C12)
scales Nombre variadique de constantes à virgule flottante (C4-C6), (C9), (C10), (C13)
zero_points nombre variable de constantes entières (C7-C9)

Les types d'éléments quantifiés représentent des valeurs entières d'un type de stockage dans la plage allant de storage_min à storage_max (inclus) qui correspondent à des valeurs à virgule flottante d'un type exprimé. Pour une valeur entière donnée i, la valeur à virgule flottante correspondante f peut être calculée sous la forme f = (i - zero_point) * scale, où scale et zero_point sont appelés paramètres de quantification. Les éléments storage_min et storage_max sont facultatifs dans la grammaire, mais ont respectivement les valeurs par défaut min_value(storage_type) et max_value(storage_type). Les types d'éléments quantifiés sont soumis aux contraintes suivantes :

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) Si is_empty(quantization_dimension), alors size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Pour le moment, QuantizationScale est une constante à virgule flottante, mais il existe un fort intérêt pour les échelles basées sur des nombres entiers, représentées par des multiplicateurs et des décalages. Nous prévoyons d'étudier cette possibilité dans un avenir proche (#1404).

Une discussion est en cours sur la sémantique de QuantizationZeroPoint, y compris le type, les valeurs et la possibilité d'avoir un ou plusieurs points zéro dans un type de Tensor quantifié. En fonction des résultats de cette discussion, la spécification concernant les zéro points pourra être modifiée à l'avenir (#1405).

Une autre discussion en cours porte sur la sémantique de QuantizationStorageMin et QuantizationStorageMax afin de déterminer si des contraintes doivent être imposées à ces valeurs et à celles des Tensors quantifiés (#1406).

Enfin, nous prévoyons d'explorer la représentation des échelles inconnues et des points zéro, de la même manière que nous prévoyons d'explorer la représentation des tailles de dimensions inconnues (#1407).

Les types de tenseurs quantifiés représentent des tenseurs avec des éléments quantifiés. Ces Tensors sont exactement les mêmes que les Tensors standards, sauf que leurs éléments ont des types d'éléments quantifiés au lieu de types d'éléments standards.

Dans les Tensors quantifiés, la quantification peut être par Tensor, c'est-à-dire avec un scale et un zero_point pour l'ensemble du Tensor, ou par axe, c'est-à-dire avec plusieurs scales et zero_points, une paire par tranche d'une dimension particulière quantization_dimension. Plus formellement, dans un Tensor t avec quantification par axe, il existe des tranches dim(t, quantization_dimension) de quantization_dimension : t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], etc. Tous les éléments de la tranche i utilisent scales[i] et zero_points[i] comme paramètres de quantification. Les types de Tensor quantifiés sont soumis aux contraintes suivantes :

  • Pour la quantification par Tensor :
    • Aucune autre contrainte.
  • Pour la quantification par axe :
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Les types de jetons représentent des jetons, c'est-à-dire des valeurs opaques produites et consommées par certaines opérations. Les jetons sont utilisés pour imposer un ordre d'exécution aux opérations, comme décrit dans la section Exécution.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Les types de tampon représentent des tampons. Par exemple, dans XLA, les tampons sont des tableaux multidimensionnels avec un stockage cohérent. Comme les types Tensor, les types Buffer ont une forme et un type d'élément, où une forme représente des tailles de dimension non négatives ou inconnues dans l'ordre croissant des dimensions correspondantes (également appelées axes) numérotées de 0 à R-1. Le nombre de dimensions R est appelé rang. Par exemple, memref<2x3xf32> est un type de tampon avec une forme 2x3 et un type d'élément f32. Il comporte deux dimensions (ou, en d'autres termes, deux axes) : la dimension 0 et la dimension 1, dont les tailles sont respectivement de 2 et 3. Son rang est 2.

Les tampons peuvent être alloués à l'aide d'un custom_call à CreateBuffer ou Pin, et désalloués à l'aide d'un custom_call à Unpin. Seules les opérations custom_call peuvent lire et écrire le contenu des tampons. Pour en savoir plus, consultez custom_call.

Les types de tuples représentent des tuples, c'est-à-dire des listes hétérogènes. Les tuples sont une fonctionnalité ancienne qui n'existe que pour la compatibilité avec HLO. Dans HLO, les tuples sont utilisés pour représenter les entrées et sorties variadiques. Dans StableHLO, les entrées et sorties variadiques sont prises en charge de manière native. La seule utilisation des tuples dans StableHLO est de représenter de manière exhaustive l'ABI HLO où, par exemple, T, tuple<T> et tuple<tuple<T>> peuvent être très différents selon une implémentation particulière. À l'avenir, nous prévoyons de modifier l'ABI HLO, ce qui pourrait nous permettre de supprimer les types de tuples de StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Les types d'éléments représentent les éléments des types Tensor. Contrairement à de nombreux langages de programmation, ces types ne sont pas de première classe dans StableHLO. Cela signifie que les programmes StableHLO ne peuvent pas représenter directement les valeurs de ces types (par conséquent, il est idiomatique de représenter les valeurs scalaires de type T avec des valeurs de tenseur de dimension 0 de type tensor<T>).

  • Le type booléen représente les valeurs booléennes true et false.
  • Les types entiers peuvent être signés (si) ou non signés (ui) et avoir l'une des largeurs de bits acceptées (2, 4, 8, 16, 32 ou 64). Les types siN signés représentent des valeurs entières allant de -2^(N-1) à 2^(N-1)-1 inclus, et les types uiN non signés représentent des valeurs entières allant de 0 à 2^N-1 inclus.
  • Les types à virgule flottante peuvent être l'un des suivants :
  • Les types complexes représentent des valeurs complexes qui ont une partie réelle et une partie imaginaire du même type d'élément. Les types complexes acceptés sont complex<f32> (les deux parties sont de type f32) et complex<f64> (les deux parties sont de type f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Les types de fonction représentent les fonctions nommées et anonymes. Ils ont des types d'entrée (la liste des types à gauche de ->) et des types de sortie (la liste des types à droite de ->). Dans de nombreux langages de programmation, les types de fonctions sont de première classe, mais pas dans StableHLO.

StringType ::= 'string'

Le type String représente des séquences d'octets. Contrairement à de nombreux langages de programmation, le type de chaîne n'est pas de première classe dans StableHLO et n'est utilisé que pour spécifier des métadonnées statiques pour les éléments de programme.

Opérations

Les opérations StableHLO (également appelées ops) représentent un ensemble fermé d'opérations de haut niveau dans les modèles de machine learning. Comme indiqué ci-dessus, la syntaxe StableHLO s'inspire fortement de MLIR, qui n'est pas forcément l'alternative la plus ergonomique, mais qui est sans doute la mieux adaptée à l'objectif de StableHLO, qui est de créer une meilleure interopérabilité entre les frameworks de ML et les compilateurs de ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Les opérations StableHLO (également appelées ops) ont un nom, des entrées/sorties et une signature. Le nom se compose du préfixe stablehlo. et d'un mnémonique qui identifie de manière unique l'une des opérations compatibles. Vous trouverez ci-dessous la liste complète des opérations compatibles.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Les opérations consomment des entrées et produisent des sorties. Les entrées sont classées dans les catégories suivantes : valeurs d'entrée (calculées lors de l'exécution), fonctions d'entrée (fournies de manière statique, car dans StableHLO, les fonctions ne sont pas des valeurs de première classe) et attributs d'entrée (également fournis de manière statique). Le type d'entrées et de sorties consommées et produites par une opération dépend de son mnémonique. Par exemple, l'opération add consomme deux valeurs d'entrée et produit une valeur de sortie. En comparaison, l'opération select_and_scatter consomme trois valeurs d'entrée, deux fonctions d'entrée et trois attributs d'entrée.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Les fonctions d'entrée (également appelées fonctions anonymes) sont très semblables aux fonctions nommées, sauf qu'elles n'ont pas d'identifiant (d'où le nom "anonyme") et qu'elles ne déclarent pas de types de sortie (les types de sortie sont déduits de l'opération return dans la fonction).

La syntaxe des fonctions d'entrée inclut une partie actuellement inutilisée (voir la production Unused ci-dessus), qui est là pour la compatibilité avec MLIR. Dans MLIR, il existe un concept plus général de "régions" qui peuvent comporter plusieurs "blocs" d'opérations connectés entre eux par des opérations de saut. Ces blocs ont des ID qui correspondent à la production Unused, ce qui permet de les distinguer les uns des autres. StableHLO ne comporte pas d'opérations de saut. La partie correspondante de la syntaxe MLIR n'est donc pas utilisée (mais elle est toujours présente).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Les attributs d'entrée ont un nom et une valeur qui correspond à l'une des constantes acceptées. Il s'agit du principal moyen de spécifier des métadonnées statiques pour les éléments de programme. Par exemple, l'opération concatenate utilise l'attribut dimension pour spécifier la dimension le long de laquelle ses valeurs d'entrée sont concaténées. De même, l'opération slice utilise plusieurs attributs tels que start_indices et limit_indices pour spécifier les limites utilisées pour segmenter la valeur d'entrée.

Pour le moment, les programmes StableHLO en circulation contiennent parfois des attributs qui ne sont pas décrits dans ce document. À l'avenir, nous prévoyons d'intégrer ces attributs à l'opset StableHLO ou de leur interdire d'apparaître dans les programmes StableHLO. En attendant, voici la liste de ces attributs :

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Métadonnées de localisation (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

La signature d'opération se compose des types de toutes les valeurs d'entrée (la liste des types à gauche de ->) et des types de toutes les valeurs de sortie (la liste des types à droite de ->). À proprement parler, les types d'entrée sont redondants, et les types de sortie le sont presque toujours également (car pour la plupart des opérations StableHLO, les types de sortie peuvent être déduits des entrées). Néanmoins, la signature op fait délibérément partie de la syntaxe StableHLO pour assurer la compatibilité avec MLIR.

Vous trouverez ci-dessous un exemple d'op dont le mnémonique est select_and_scatter. Il consomme trois valeurs d'entrée (%operand, %source et %init_value), deux fonctions d'entrée et trois attributs d'entrée (window_dimensions, window_strides et padding). Notez que la signature de l'opération n'inclut que les types de ses valeurs d'entrée (mais pas les types de fonctions et d'attributs d'entrée qui sont fournis en ligne).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Constantes

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Les constantes StableHLO ont un littéral et un type qui représentent ensemble une valeur StableHLO. En général, le type fait partie de la syntaxe constante, sauf lorsqu'il n'y a pas d'ambiguïté (par exemple, une constante booléenne a sans ambiguïté le type i1, tandis qu'une constante entière peut avoir plusieurs types possibles).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Les constantes booléennes représentent les valeurs booléennes true et false. Les constantes booléennes sont de type i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Les constantes entières représentent des valeurs entières sous forme de chaînes utilisant la notation décimale ou hexadécimale. Les autres bases, par exemple binaire ou octale, ne sont pas acceptées. Les constantes entières sont soumises aux contraintes suivantes :

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Les constantes à virgule flottante représentent des valeurs à virgule flottante via des chaînes qui utilisent la notation décimale ou scientifique. De plus, la notation hexadécimale peut être utilisée pour spécifier directement les bits sous-jacents dans le format à virgule flottante du type correspondant. Les constantes à virgule flottante sont soumises aux contraintes suivantes :

  • (C1) Si une notation non hexadécimale est utilisée, is_wellformed(float_literal, float_type).
  • (C2) Si la notation hexadécimale est utilisée, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Les constantes complexes représentent des valeurs complexes à l'aide de listes composées d'une partie réelle (en premier) et d'une partie imaginaire (en second). Par exemple, (1.0, 0.0) : complex<f32> représente 1.0 + 0.0i et (0.0, 1.0) : complex<f32> représente 0.0 + 1.0i. L'ordre dans lequel ces parties sont ensuite stockées en mémoire est défini par l'implémentation. Les constantes complexes sont soumises aux contraintes suivantes :

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Les constantes Tensor représentent les valeurs Tensor à l'aide de listes imbriquées spécifiées à l'aide de la notation NumPy. Par exemple, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> représente une valeur de Tensor avec le mappage suivant des index aux éléments : {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. L'ordre dans lequel ces éléments sont ensuite stockés en mémoire est défini par l'implémentation. Les constantes Tensor sont soumises aux contraintes suivantes :

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), où :
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), où :
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • Sinon, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Les constantes de tenseur quantifié représentent des valeurs de tenseur quantifié en utilisant la même notation que les constantes de tenseur, avec des éléments spécifiés comme constantes de leur type de stockage. Les constantes de Tensor quantifié sont soumises aux contraintes suivantes :

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Les littéraux de chaîne sont constitués d'octets spécifiés à l'aide de caractères ASCII et de séquences d'échappement. Elles sont indépendantes de l'encodage. L'interprétation de ces octets est donc définie par l'implémentation. Les littéraux de chaîne sont de type string.

Opérations

abs

Sémantique

Effectue une opération abs au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les entiers signés : module entier.
  • Pour les valeurs float : abs à partir de la norme IEEE-754.
  • Pour les nombres complexes : module complexe.
  • Pour les types quantifiés : dequantize_op_quantize(abs, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type entier signé, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1-C2)

Sorties

Nom Type Contraintes
result Tenseur de type entier signé ou à virgule flottante, ou tenseur quantifié par tenseur (C1-C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • Sinon, baseline_element_type(operand).

Exemples

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]

 Autres exemples

add

Sémantique

Effectue l'addition élément par élément de deux Tensors lhs et rhs, et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs booléennes : OR logique.
  • Pour les nombres entiers : addition d'entiers.
  • Pour les valeurs float : addition à partir de la norme IEEE-754.
  • Pour les nombres complexes : addition complexe.
  • Pour les types quantifiés : dequantize_op_quantize(add, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur ou tenseur quantifié (C1-C6)
(I2) rhs tenseur ou tenseur quantifié (C1-C5), (C7)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C1-C7)

Contraintes

  • Si l'opération utilise des Tensors non quantifiés :
    • (C1) type(lhs) = type(rhs) = type(result).
  • Si l'opération utilise des Tensors quantifiés :
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Si is_per_axis_quantized(lhs), alors quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) = quantization_dimension(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]

 Autres exemples

after_all

Sémantique

Garantit que les opérations produisant le inputs sont exécutées avant toute opération dépendant de result. L'exécution de cette opération n'a aucun effet. Elle n'existe que pour établir des dépendances de données entre result et inputs.

Entrées

Libellé Nom Type
(I1) inputs Nombre variadique de token

Sorties

Nom Type
result token

Exemples

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token

 Autres exemples

all_gather

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, concatène les valeurs des Tensors operands de chaque processus le long de all_gather_dim et produit des Tensors results.

L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group :

  • operands...@receiver = [operand@sender for sender in process_group] pour tous les receiver dans process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) pour tous les process dans process_group.

Entrées

Libellé Nom Type Contraintes
(I1) operands un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1), (C6)
(I2) all_gather_dim constante de type si64 (C1), (C6)
(I3) replica_groups Constante Tensor à deux dimensions de type si64 (C2-C4)
(I4) channel_id constante de type si64 (C5)
(I5) use_global_device_ids constante de type i1 (C5)

Sorties

Nom Type Contraintes
results un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C6)

Contraintes

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Si use_global_device_ids = true, alors channel_id > 0.
  • (C6) type(results...) = type(operands...) sauf :
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
  // channel_id = 0
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
  // use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

 Autres exemples

all_reduce

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, applique une fonction de réduction computation aux valeurs des Tensors operands de chaque processus et produit des Tensors results.

L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group :

  • results...@process[result_index] = exec(schedule) pour un arbre binaire donné schedule où :
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule est un arbre binaire défini par l'implémentation dont la traversée dans l'ordre est to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

Entrées

Libellé Nom Type Contraintes
(I1) operands un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C5), (C6)
(I2) replica_groups Nombre variable de constantes Tensor à une dimension de type si64 (C1-C3)
(I3) channel_id constante de type si64 (C4)
(I4) use_global_device_ids constante de type i1 (C4)
(I5) computation fonction (C5)

Sorties

Nom Type Contraintes
results un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C6-C7)

Contraintes

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Si use_global_device_ids = true, alors channel_id > 0.
  • (C5) computation est de type (tensor<E>, tensor<E>) -> (tensor<E>)is_promotable(element_type(operand), E).
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
  >replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
  // channel_id = 0
  channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
  // use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

 Autres exemples

all_to_all

Sémantique

all_to_all

Dans chaque groupe de processus de la grille de processus StableHLO, divise les valeurs des tenseurs operands le long de split_dimension en parties, répartit les parties divisées entre les processus, concatène les parties réparties le long de concat_dimension et produit des tenseurs results. L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :

  • cross_replica(replica_groups) si channel_id <= 0.
  • cross_partition(replica_groups) si channel_id > 0.

Ensuite, dans chaque process_group :

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) pour tous les sender dans process_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]receiver_index = process_group.index(receiver).
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

Entrées

Libellé Nom Type Contraintes
(I1) operands un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1-C3), (C9)
(I2) split_dimension constante de type si64 (C1), (C2), (C9)
(I3) concat_dimension constante de type si64 (C3), (C9)
(I4) split_count constante de type si64 (C2), (C4), (C8), (C9)
(I5) replica_groups Constante Tensor à deux dimensions de type si64 (C5-C8)
(I6) channel_id constante de type si64

Sorties

Nom Type Contraintes
results un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C9)

Contraintes

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...) sauf si split_dimension != concat_dimension :
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
  // channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

 Autres exemples

et

Sémantique

Effectue un AND élément par élément de deux Tensors lhs et rhs, et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs booléennes : AND logique.
  • Pour les nombres entiers : AND au niveau du bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type booléen ou entier (C1)
(I2) rhs Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]

 Autres exemples

atan2

Sémantique

Effectue une opération atan2 au niveau des éléments sur les Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : atan2 à partir de la norme IEEE-754.
  • Pour les nombres complexes : atan2 complexe.
  • Pour les types quantifiés : dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)
(I2) rhs Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

 Autres exemples

batch_norm_grad

Sémantique

Calcule les gradients de plusieurs entrées de batch_norm_training en rétropropagation à partir de grad_output et produit les Tensors grad_operand, grad_scale et grad_offset. Plus formellement, cette opération peut être exprimée comme une décomposition en opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Pour les types quantifiés, effectue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1-C3), (C5)
(I2) scale Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C2), (C4), (C5)
(I3) mean Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C2), (C4)
(I4) variance Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C2), (C4)
(I5) grad_output Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C2), (C3)
(I6) epsilon constante de type f32
(I7) feature_index constante de type si64 (C1), (C5)

Sorties

Nom Type Contraintes
grad_operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C2), (C3)
grad_scale Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C2), (C4)
grad_offset Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C2), (C4)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale et grad_offset ont le même baseline_element_type.
  • (C3) operand, grad_output et grad_operand ont la même forme.
  • (C4) scale, mean, variance, grad_scale et grad_offset ont la même forme.
  • (C5) size(scale) = dim(operand, feature_index).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
 <    tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Sémantique

Normalise le tenseur operand dans toutes les dimensions, à l'exception de la dimension feature_index, et produit un tenseur result. Plus formellement, cette opération peut être exprimée comme une décomposition en opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Pour les types quantifiés, effectue dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1-C7)
(I2) scale Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C2), (C3)
(I3) offset Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C2), (C4)
(I4) mean Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C5)
(I5) variance Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur (C2), (C6)
(I6) epsilon constante de type f32
(I7) feature_index constante de type si64 (C1), (C3-C6)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C2), (C7)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance et result ont le même baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Sémantique

Calcule la moyenne et la variance pour toutes les dimensions, à l'exception de la dimension feature_index, et normalise le Tensor operand pour produire les Tensors output, batch_mean et batch_var. Plus formellement, cette opération peut être exprimée comme une décomposition des opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Pour les types quantifiés, effectue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)
(I2) scale Tenseur à une dimension de valeurs à virgule flottante ou quantifiées par tenseur (C2), (C3)
(I3) offset Tenseur à une dimension de valeurs à virgule flottante ou quantifiées par tenseur (C2), (C4)
(I4) epsilon constante de type f32 (C1), (C3-C6)
(I5) feature_index constante de type si64 (C1), (C3-C6)

Sorties

Nom Type Contraintes
output Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C7)
batch_mean Tenseur à une dimension de valeurs à virgule flottante ou quantifiées par tenseur (C2), (C5)
batch_var Tenseur à une dimension de valeurs à virgule flottante ou quantifiées par tenseur (C2), (C6)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var et output ont le même baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
 <   (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Sémantique

Effectue une opération bitcast sur le Tensor operand et produit un Tensor result où les bits de l'ensemble du Tensor operand sont réinterprétés à l'aide du type du Tensor result.

Plus formellement, étant donné E = element_type(operand), E' = element_type(result) et R = rank(operand) :

  • Si num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Si num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Si num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits renvoie la représentation en mémoire d'une valeur donnée. Son comportement est défini par l'implémentation, car la représentation exacte des Tensors et des types d'éléments est également définie par l'implémentation.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié (C1-C2)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C1-C2)

Contraintes

  • (C1) Étant donné E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) et R = rank(operand) :
    • Si num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Si num_bits(E') < num_bits(E) :
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) pour tous les 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Si num_bits(E') > num_bits(E) :
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) pour tous les 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Si is_complex(operand) or is_complex(result), alorsis_complex(operand) and is_complex(result).

Exemples

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

 Autres exemples

broadcast_in_dim

Sémantique

broadcast_in_dim

Développe les dimensions et/ou le rang d'un Tensor d'entrée en dupliquant les données dans le Tensor operand et produit un Tensor result. Plus formellement, result[result_index] = operand[operand_index] où pour tous les d dans axes(operand) :

  • operand_index[d] = 0 si dim(operand, d) = 1.
  • Sinon, operand_index[d] = result_index[broadcast_dimensions[d]].

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié (C1-C2), (C5-C6)
(I2) broadcast_dimensions Constante Tensor de type si64 à une dimension (C2-C6)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C1), (C3), (C5-C6)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), sauf si quantization_dimension(operand), scales(operand) et zero_points(operand) diffèrent de quantization_dimension(result), scales(result) et zero_points(result), respectivement.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Pour tous les d dans axes(operand) :
    • dim(operand, d) = 1 ou
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Si is_per_axis_quantized(result) :
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Si dim(operand, quantization_dimension(operand)) = 1, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Exemples

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 Autres exemples

coque

Sémantique

Génère le résultat de l'exécution d'une seule fonction à partir de branches en fonction de la valeur de index. Plus précisément, result = selected_branch(), où :

  • selected_branch = branches[index] si 0 <= index < size(branches).
  • Sinon, selected_branch = branches[-1].

Entrées

Libellé Nom Type Contraintes
(I1) index Tensor de dimension 0 de type si32
(I2) branches nombre variable de fonctions (C1-C4)

Sorties

Nom Type Contraintes
results un nombre variable de Tensors, de Tensors quantifiés ou de jetons. (C4)

Contraintes

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Exemples

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
  "stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]

 Autres exemples

cbrt

Sémantique

Effectue une opération de racine cubique élément par élément sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : rootn(x, 3) à partir de la norme IEEE-754.
  • Pour les nombres complexes : racine cubique complexe.
  • Pour les types quantifiés : dequantize_op_quantize(cbrt, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]

 Autres exemples

ceil

Sémantique

Effectue un plafond élément par élément du Tensor operand et produit un Tensor result. Implémente l'opération roundToIntegralTowardPositive à partir de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(ceil, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

 Autres exemples

cholesky

Sémantique

Calcule la décomposition de Cholesky d'un lot de matrices.

Plus formellement, pour tous les i dans index_space(result), result[i0, ..., iR-3, :, :] est une décomposition de Cholesky de a[i0, ..., iR-3, :, :], sous la forme d'une matrice triangulaire inférieure (si lower est true) ou supérieure (si lower est false). Les valeurs de sortie dans le triangle opposé, c'est-à-dire le triangle supérieur strict ou le triangle inférieur strict, sont définies par l'implémentation.

Si i existe et que la matrice d'entrée n'est pas une matrice hermitienne définie positive, le comportement n'est pas défini.

Pour les types quantifiés, effectue dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) a Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1-C3)
(I2) lower constante de type i1

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Exemples

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

limiter

Sémantique

Limite chaque élément du Tensor operand entre une valeur minimale et maximale, et produit un Tensor result. Plus formellement, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), où min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Pour les types quantifiés, effectue dequantize_op_quantize(clamp, min, operand, max, type(result)).

L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (#560).

Entrées

Libellé Nom Type Contraintes
(I1) min tenseur ou tenseur quantifié par tenseur (C1), (C3)
(I2) operand tenseur ou tenseur quantifié par tenseur (C1-C4)
(I3) max tenseur ou tenseur quantifié par tenseur (C2), (C3)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C4)

Contraintes

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Exemples

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]

 Autres exemples

collective_broadcast

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, envoyez la valeur du tenseur operand du processus source aux processus cibles et générez un tenseur result.

L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :

  • cross_replica(replica_groups) si channel_id <= 0.
  • cross_partition(replica_groups) si channel_id > 0.

result@process est ensuite donné par :

  • operand@process_groups[i, 0] s'il existe un i tel que le processus soit dans process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) dans le cas contraire.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C3)
(I2) replica_groups Nombre variable de constantes Tensor à une dimension de type si64 (C1), (C2)
(I3) channel_id constante de type si64

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C3)

Contraintes

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < NN est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C3) type(result) = type(operand).

Exemples

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, envoie la valeur du tenseur operand du processus source au processus cible et produit un tenseur result.

L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :

  • cross_replica(source_target_pairs) si channel_id <= 0.
  • cross_partition(source_target_pairs) si channel_id > 0.

result@process est ensuite donné par :

  • operand@process_groups[i, 0], s'il existe un i tel que process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) dans le cas contraire.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C5)
(I2) source_target_pairs Constante Tensor à deux dimensions de type si64 (C1-C4)
(I3) channel_id constante de type si64

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, où N est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C5) type(result) = type(operand).

Exemples

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

 Autres exemples

compare

Sémantique

Effectue une comparaison élément par élément des Tensors lhs et rhs selon comparison_direction et compare_type, et produit un Tensor result.

Les valeurs de comparison_direction et compare_type ont la sémantique suivante :

Pour les types d'éléments booléens et entiers :

  • EQ : lhs = rhs.
  • NE : lhs != rhs.
  • GE : lhs >= rhs.
  • GT : lhs > rhs.
  • LE : lhs <= rhs.
  • LT : lhs < rhs.

Pour les types d'éléments à virgule flottante avec compare_type = FLOAT, l'opération implémente les opérations IEEE-754 suivantes :

  • EQ : compareQuietEqual.
  • NE : compareQuietNotEqual.
  • GE : compareQuietGreaterEqual.
  • GT : compareQuietGreater.
  • LE : compareQuietLessEqual.
  • LT : compareQuietLess.

Pour les types d'éléments à virgule flottante avec compare_type = TOTALORDER, l'opération utilise la combinaison des opérations totalOrder et compareQuietEqual de la norme IEEE-754.

Pour les types d'éléments complexes, la comparaison lexicographique des paires (real, imag) est effectuée à l'aide des comparison_direction et compare_type fournis. L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes lorsque comparison_direction est GE, GT, LE ou LT (#560).

Pour les types quantifiés, effectue dequantize_compare(lhs, rhs, comparison_direction).

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur ou tenseur quantifié par tenseur (C1-C3)
(I2) rhs tenseur ou tenseur quantifié par tenseur (C1-C2)
(I3) comparison_direction Énumération de EQ, NE, GE, GT, LE et LT
(I4) compare_type Énumération de FLOAT, TOTALORDER, SIGNED et UNSIGNED (C3)

Sorties

Nom Type Contraintes
result Tensor de type booléen (C2)

Contraintes

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type est défini comme suit :
    • SIGNED si is_signed_integer(element_type(lhs)).
    • UNSIGNED si is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT ou TOTALORDER si is_float(element_type(lhs)).
    • FLOAT si is_complex(element_type(lhs)).

Exemples

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = <#stablehlocomparison_di>rection LT,
  compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]

 Autres exemples

complexe

Sémantique

Effectue une conversion élément par élément en valeur complexe à partir d'une paire de valeurs réelles et imaginaires, lhs et rhs, et produit un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type f32 ou f64 (C1-C3)
(I2) rhs Tensor de type f32 ou f64 (C1)

Sorties

Nom Type Contraintes
result Tensor de type complexe (C2), (C3)

Contraintes

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) est de type complex<E>E = element_type(lhs).

Exemples

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]

 Autres exemples

composite

Sémantique

Encapsule une opération composée d'autres opérations StableHLO, en prenant inputs et composite_attributes et en produisant results. La sémantique de l'opération est implémentée par l'attribut decomposition. L'opération composite peut être remplacée par sa décomposition sans modifier la sémantique du programme. Dans les cas où l'intégration de la décomposition ne fournit pas la même sémantique d'op, préférez utiliser custom_call.

Le champ version (par défaut 0) est utilisé pour indiquer quand la sémantique d'un composite change.

Entrées

Libellé Nom Type
(I1) inputs nombre variable de valeurs
(I2) name constante de type string
(I3) composite_attributes dictionnaire d'attributs
(I4) decomposition constante de type string
(I5) version constante de type si32

Sorties

Nom Type
results nombre variable de valeurs

Contraintes

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Exemples

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
 < ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32

 Autres exemples

concatenate

Sémantique

Concatène inputs le long de la dimension dimension dans le même ordre que les arguments fournis et produit un Tensor result. Plus formellement, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], où :

  1. id = d0 + ... + dk-1 + kd.
  2. d est égal à dimension, et d0, ... sont les tailles de la de dimension de inputs.

Entrées

Libellé Nom Type Contraintes
(I1) inputs un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1-C6)
(I2) dimension constante de type si64 (C2), (C4), (C6)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C5-C6)

Contraintes

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)), sauf dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]), sauf :
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Exemples

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

 Autres exemples

constante

Sémantique

Génère un Tensor output à partir d'une constante value.

Entrées

Libellé Nom Type Contraintes
(I1) value constante (C1)

Sorties

Nom Type Contraintes
output tenseur ou tenseur quantifié (C1)

Contraintes

  • (C1) type(value) = type(output).

Exemples

%output = "stablehlo.constant"() {
  val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]

 Autres exemples

d'effectuer une conversion

Sémantique

Effectue une conversion élément par élément d'un type d'élément à un autre sur le Tensor operand et produit un Tensor result.

Pour les conversions boolean-to-any-supported-type, la valeur false est convertie en zéro et la valeur true est convertie en un. Pour les conversions any-supported-type-to-boolean, une valeur nulle est convertie en false et les valeurs non nulles sont converties en true. Vous trouverez ci-dessous des informations sur le fonctionnement de cette fonctionnalité pour les types complexes.

Pour les conversions d'entier à entier, d'entier à virgule flottante ou de virgule flottante à virgule flottante, si la valeur source peut être représentée exactement dans le type de destination, la valeur résultante est cette représentation exacte. Sinon, le comportement est à déterminer (#180).

Pour les conversions floating-point-to-integer, la partie fractionnaire est tronquée. Si la valeur tronquée ne peut pas être représentée dans le type de destination, le comportement est à déterminer (#180).

Les conversions complexe vers complexe suivent le même comportement que les conversions virgule flottante vers virgule flottante pour convertir les parties réelles et imaginaires.

Pour les conversions complex-to-any-other-type et any-other-type-to-complex, la valeur imaginaire source est ignorée ou la valeur imaginaire de destination est définie sur zéro, respectivement. La conversion de la partie réelle suit les conversions à virgule flottante.

En principe, cette opération pourrait exprimer la déquantification (conversion de Tensors quantifiés en Tensors réguliers), la quantification (conversion de Tensors réguliers en Tensors quantifiés) et la requantification (conversion entre Tensors quantifiés), mais pour le moment, nous avons des opérations dédiées pour cela : uniform_dequantize pour le premier cas d'utilisation et uniform_quantize pour les deuxième et troisième cas d'utilisation. À l'avenir, ces deux opérations pourront être fusionnées dans convert (#1576).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur (C1)

Sorties

Nom Type Contraintes
result tenseur (C1)

Contraintes

  • (C1) shape(operand) = shape(result).

Exemples

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

 Autres exemples

convolution

Sémantique

Calcule les produits scalaires entre les fenêtres de lhs et les tranches de rhs, et génère result. Le diagramme suivant montre comment les éléments de result sont calculés à partir de lhs et rhs à l'aide d'un exemple concret.

convolution

Plus formellement, considérez la reformulation suivante des entrées en termes de lhs afin de pouvoir exprimer des fenêtres de lhs :

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1)
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0])
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Cette reformulation utilise les fonctions d'assistance suivantes :

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]j[d] = i[permutation[d]].

Si feature_group_count = 1 et batch_group_count = 1, alors pour tous les output_spatial_index dans index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product où :

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Cette fonctionnalité semble inutilisée. Nous prévoyons donc de la supprimer à l'avenir (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Si feature_group_count > 1 :

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
  • result = concatenate(results, output_feature_dimension).

Si batch_group_count > 1 :

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Pour les types quantifiés, effectue dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Pour les types quantifiés hybrides, effectue hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur ou tenseur quantifié par tenseur (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs tenseur ou tenseur quantifié (C1), (C14-C16), (C25), (C27-C29), (C31-C34)
(I3) window_strides Constante Tensor de type si64 à une dimension (C2-C3), (C25)
(I4) padding Constante Tensor à deux dimensions de type si64 (C4), (C25)
(I5) lhs_dilation Constante Tensor de type si64 à une dimension (C5-C6), (C25)
(I6) rhs_dilation Constante Tensor de type si64 à une dimension (C7-C8), (C25)
(I7) window_reversal Constante Tensor de type i1 à une dimension (C9)
(I8) input_batch_dimension constante de type si64 (C10), (C13), (C25)
(I9) input_feature_dimension constante de type si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Constante Tensor de type si64 à une dimension (C12), (C13), (C25)
(I11) kernel_input_feature_dimension constante de type si64 (C14), (C18)
(I12) kernel_output_feature_dimension constante de type si64 (C15-C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions Constante Tensor de type si64 à une dimension (C17-C18), (C25)
(I14) output_batch_dimension constante de type si64 (C20), (C25)
(I15) output_feature_dimension constante de type si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions Constante Tensor de type si64 à une dimension (C19-C20), (C25)
(I17) feature_group_count constante de type si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count constante de type si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config Nombre variable d'énums de DEFAULT, HIGH et HIGHEST (C24)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C25-C28), (C30), (C32-34)

Contraintes

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Étant donné input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] :
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Étant donné kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] :
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Étant donné output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] :
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) est défini comme suit :
    • dim(lhs, input_batch_dimension) / batch_group_count si result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) si result_dim = output_feature_dimension.
    • num_windows sinon, où :
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Si l'opération utilise des Tensors non quantifiés :
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Si l'opération utilise des Tensors quantifiés :
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Si is_per_axis_quantized(result), alorsquantization_dimension(result) = output_feature_dimension.
    • Si is_quantized(lhs) :
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Si is_per_tensor_quantized(rhs), alorsis_per_tensor_quantized(result).
    • Si !is_quantized(lhs) :
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Exemples

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strid<es = arra>yi64: 4, 4,
  paddi<n>g = dense<0 : ten>sor2x2xi64,
  lhs_dilati<on = arra>yi64: 2, 2,
  rhs_dilati<on = arra>yi64: 1, 1,
  window_revers<al = arrayi1: fa>lse, false,
  // In the StableHLO dialect, dimension numbers are encoded vi<a:
  // `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" a<re spatial dimensions.
  d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
  batch_group_count = 1 : i64,
  fea<ture_group_count >= 1 : i64,
 < precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

 Autres exemples

cosinus

Sémantique

Effectue une opération de cosinus au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : cos à partir de la norme IEEE-754.
  • Pour les nombres complexes : cosinus complexe.
  • Pour les types quantifiés : dequantize_op_quantize(cosine, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]

 Autres exemples

count_leading_zeros

Sémantique

Effectue un décompte élément par élément du nombre de bits zéro de début dans le Tensor operand et produit un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]

 Autres exemples

custom_call

Sémantique

Encapsule une opération call_target_name définie par l'implémentation qui prend inputs et called_computations et produit results. has_side_effect, backend_config et api_version peuvent être utilisés pour fournir des métadonnées supplémentaires définies par l'implémentation.

Pour le moment, cette opération contient une collection de métadonnées assez désorganisée qui reflète l'évolution organique de son opération homologue dans le compilateur XLA. À l'avenir, nous prévoyons d'unifier ces métadonnées (#741).

Entrées

Libellé Nom Type
(I1) inputs nombre variable de valeurs
(I2) call_target_name constante de type string
(I3) has_side_effect constante de type i1
(I4) backend_config constante de type string ou dictionnaire d'attributs
(I5) api_version constante de type si32
(I6) called_computations nombre variadique de constantes de type string
(I7) output_operand_aliases spécifier les parties d'alias dans les sorties et les opérandes ;

Sorties

Nom Type
results nombre variable de valeurs

(Compatibilité XLA avec les GPU) Cibles custom_call spéciales

Il existe trois call_target_name spéciaux liés aux types buffer : CreateBuffer crée un buffer non initialisé, Pin crée un buffer initialisé et Unpin libère un buffer et renvoie le contenu du buffer.

%uninitialized_buffer = "stablehlo.custom_call"() {
  call_target_name = "CreateBuffer",
  api_version> = 4 : <i32,
>} : () - memref4xf64

%initialized_buffer = "stablehlo.custom_call"(%init_value) {
  call_target_name = "Pin&quo<t;,
 > ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64

%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
  call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
  api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64

Alias

Certaines opérations custom_call peuvent nécessiter qu'une partie des sorties et une partie des opérandes partagent la même mémoire. Cela peut être exprimé via output_operand_aliases. Une représentation de paire d'alias se compose d'une liste d'indices de tuples de sortie représentant la partie de sortie, et d'un operand_index ainsi que d'une liste d'indices de tuples d'opérandes représentant la partie d'opérande. La liste des indices de tuples de sortie ou d'opérandes est vide si le type correspondant n'est pas un type tuple et peut être arbitrairement longue pour un type de tuple arbitrairement imbriqué. C'est semblable à la représentation de l'alias XLA.

La partie de sortie et la partie d'entrée d'une paire d'alias doivent être du même type. Pour les opérations custom_call qui ne sont pas des appels à CreateBuffer, Pin et Unpin, un opérande buffer peut apparaître dans une seule paire d'alias, et une sortie buffer doit apparaître dans une paire d'alias.

Exemples

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64

%updated_buffer = "stablehlo.custom_call"(%buffer) {
  call_target_name = "Update",
  api_version = 4 : i32,
  output_operand_aliases< = [
    #stablehlo.output_operand_aliasoutput_tuple_indices = [],
      operand_ind>ex = 0,
     < oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64

diviser

Sémantique

Effectue une division élément par élément des Tensors de dividende lhs et de diviseur rhs, et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les nombres entiers : division entière qui produit le quotient algébrique en supprimant toute partie fractionnaire.
  • Pour les valeurs float : division à partir de la norme IEEE-754.
  • Pour les nombres complexes : division complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)
(I2) rhs Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

 Autres exemples

dot_general

Sémantique

Calcule les produits scalaires entre les tranches de lhs et celles de rhs, et génère un Tensor result.

Plus précisément, result[result_index] = dot_product, où :

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_indexsize(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) et size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Pour les types quantifiés, effectue dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Pour les types quantifiés hybrides, effectue hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config contrôle le compromis entre vitesse et précision pour les calculs sur les backends d'accélérateur. Il peut s'agir de l'une des valeurs suivantes (pour le moment, la sémantique de ces valeurs d'énumération n'est pas suffisamment spécifiée, mais nous prévoyons de résoudre ce problème dans #755) :

  • DEFAULT : calcul le plus rapide, mais approximation la moins précise du nombre d'origine.
  • HIGH : calcul plus lent, mais approximation plus précise du nombre d'origine.
  • HIGHEST : le calcul le plus lent, mais l'approximation la plus précise du nombre d'origine.

Un DotAlgorithm définit les propriétés principales de l'algorithme utilisé pour implémenter l'opération de produit scalaire, qui définit également la précision. Si les champs d'attribut d'algorithme sont définis, precision_config doit être DEFAULT. DotAlgorithms n'ont pas de valeur par défaut, car les paramètres par défaut sont définis par l'implémentation. Par conséquent, tous les champs de l'algorithme de points peuvent être définis sur None pour spécifier un algorithme de points vide, qui utilisera plutôt la valeur precision_config.

Les champs DotAlgorithm incluent les suivants :

  • lhs_precision_type et rhs_precision_type, les précisions auxquelles les côtés gauche et droit de l'opération sont arrondis. Les types de précision sont indépendants des types de stockage des entrées et de la sortie.
  • accumulation_type : précision utilisée pour l'accumulation.
  • lhs_component_count, rhs_component_count et num_primitive_operations s'appliquent lorsque nous effectuons un algorithme qui décompose les côtés gauche et/ou droit en plusieurs composants et effectue plusieurs opérations de produit scalaire "primitives" sur ces valeurs, généralement pour émuler une précision plus élevée (par exemple, Utiliser le type de données d'intelligence artificielle bfloat16 pour des calculs plus précis : bf16_6x tf32_3x, etc.). Pour les algorithmes sans décomposition, ces valeurs doivent être définies sur 1.
  • allow_imprecise_accumulation pour spécifier si l'accumulation dans une précision inférieure est autorisée pour certaines étapes (par exemple, CUBLASLT_MATMUL_DESC_FAST_ACCUM).

Exemples d'attributs DotAlgorithm :

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

Il appartient aux implémentations de décider quelles combinaisons sont prises en charge. En général, il n'est pas garanti que chaque algorithme soit compatible avec chaque type d'accélérateur par le consommateur de StableHLO. Si un algorithme donné n'est pas pris en charge, une erreur doit être générée au lieu de revenir à une alternative. La validation StableHLO fournira la meilleure vérification possible, en empêchant les algorithmes qui ne sont pas connus pour être compatibles avec aucun matériel.

Pour obtenir des exemples de valeurs d'algorithme acceptées, consultez xla_data.proto > Algorithm. La demande 2483 décrit le plan de création d'un document centralisé sur les algorithmes compatibles par backend.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur ou tenseur quantifié par tenseur (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs tenseur ou tenseur quantifié (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions Constante Tensor de type si64 à une dimension (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions Constante Tensor de type si64 à une dimension (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions Constante Tensor de type si64 à une dimension (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions Constante Tensor de type si64 à une dimension (C2), (C4), (C8), (C10), (C16)
(I7) precision_config Nombre variable d'énums de DEFAULT, HIGH et HIGHEST (C11), (C21)
(I8) lhs_precision_type FloatType ou TensorFloat32 (C21)
(I9) rhs_precision_type FloatType ou TensorFloat32 (C21)
(I10) accumulation_type FloatType ou TensorFloat32 (C21)
(I11) lhs_component_count constante de type si32 (C21), (C22)
(I12) rhs_component_count constante de type si32 (C21), (C23)
(I13) num_primitive_operations constante de type si32 (C21), (C24)
(I14) allow_imprecise_accumulation constante de type bool (C21)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C12), (C14), (C18-C20)

Contraintes

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Si l'opération utilise des Tensors non quantifiés :
    • (C13) element_type(lhs) = element_type(rhs).
  • Si l'opération utilise des Tensors quantifiés :
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) n'est pas dans rhs_contracting_dimensions.
    • Si is_quantized(lhs) :
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) Si is_per_tensor_quantized(rhs), alors is_per_tensor_quantized(result).
    • Si !is_quantized(lhs) :
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • Si !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation) :
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

Exemples

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #sta<blehlo.dot
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimension>s = [1]
  ,
  precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
  algorithm = #stablehlo.dot<_algorithm
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation >= false
  
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

 Autres exemples

dynamic_broadcast_in_dim

Sémantique

Cette opération est fonctionnellement identique à l'opération broadcast_in_dim, mais la forme du résultat est spécifiée de manière dynamique via output_dimensions.

L'opération accepte également les attributs facultatifs known_expanding_dimensions et known_nonexpanding_dimensions pour exprimer des connaissances statiques sur le comportement d'expansion des dimensions. Si aucune n'est spécifiée, toutes les dimensions sont considérées comme pouvant être développées.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié (C1-C2), (C5-C6), (C9)
(I2) output_dimensions Tensor unidimensionnel de type entier (C7)
(I3) broadcast_dimensions Tensor constant unidimensionnel de type entier (C2-C6)
(I4) known_expanding_dimensions Tensor constant unidimensionnel de type entier (C8-C9)
(I5) known_nonexpanding_dimensions Tensor constant unidimensionnel de type entier (C8-C9)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C1), (C3), (C5-C7)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), sauf si quantization_dimension(operand), scales(operand) et zero_points(operand) diffèrent de quantization_dimension(result), scales(result) et zero_points(result), respectivement.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Pour tous les d dans axes(operand) :
    • dim(operand, d) = 1 ou
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Si is_per_axis_quantized(result) :
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Si dim(operand, quantization_dimension(operand)) = 1, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_nonexpanding_dimensions < rank(operand).

Exemples

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensio<ns = arra>yi64: 2, 1,
  known_expanding_dimensio<ns = a>rrayi64: 0,
  known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 Autres exemples

dynamic_conv

Sémantique

Cette opération est fonctionnellement identique à l'opération convolution, mais le remplissage est spécifié de manière dynamique via padding.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur ou tenseur quantifié par tenseur (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs tenseur ou tenseur quantifié (C1), (C14-C16), (C26-C28), (C30-C33)
(I3) padding Tensor bidimensionnel de type entier (C4)
(I4) window_strides Constante Tensor de type si64 à une dimension (C2-C3)
(I5) lhs_dilation Constante Tensor de type si64 à une dimension (C5-C6)
(I6) rhs_dilation Constante Tensor de type si64 à une dimension (C7-C8)
(I7) window_reversal Constante Tensor de type i1 à une dimension (C9)
(I8) input_batch_dimension constante de type si64 (C10), (C13)
(I9) input_feature_dimension constante de type si64 (C11), (C13-C14)
(I10) input_spatial_dimensions Constante Tensor de type si64 à une dimension (C12), (C13)
(I11) kernel_input_feature_dimension constante de type si64 (C14), (C18)
(I12) kernel_output_feature_dimension constante de type si64 (C15-C16), (C18), (C28)
(I13) kernel_spatial_dimensions Constante Tensor de type si64 à une dimension (C17-C18)
(I14) output_batch_dimension constante de type si64 (C20)
(I15) output_feature_dimension constante de type si64 (C20), (C29)
(I16) output_spatial_dimensions Constante Tensor de type si64 à une dimension (C19-C20)
(I17) feature_group_count constante de type si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count constante de type si64 (C10), (C15), (C22), (C23)
(I19) precision_config Nombre variable d'énums de DEFAULT, HIGH et HIGHEST (C24)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C25-C27), (C29), (C31-C33)

Contraintes

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Étant donné input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] :
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Étant donné kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] :
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Étant donné output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] :
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) est défini comme suit :
    • dim(lhs, input_batch_dimension) / batch_group_count si result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) si result_dim = output_feature_dimension.
    • num_windows sinon, où :
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim]
    • rhs_dim = kernel_spatial_dimensions[spatial_dim]
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Si l'opération utilise des Tensors non quantifiés :
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Si l'opération utilise des Tensors quantifiés :
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Si is_per_axis_quantized(rhs), alors quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Si is_per_axis_quantized(result), alorsquantization_dimension(result) = output_feature_dimension.
    • Si is_quantized(lhs) :
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Si is_per_tensor_quantized(rhs), alorsis_per_tensor_quantized(result).
    • Si !is_quantized(lhs) :
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Exemples

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strid<es = arra>yi64: 4, 4,
  lhs_dilati<on = arra>yi64: 2, 2,
  rhs_dilati<on = arra>yi64: 1, 1,
  window_revers<al = arrayi1: fa>lse, false,
  dimension_numbers = #stab<lehlo.convraw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions => [1, 2]
  ,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

 Autres exemples

dynamic_gather

Sémantique

Cette opération est fonctionnellement identique à l'opération gather, avec slice_sizes spécifié de manière dynamique en tant que valeur.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1), (C7), (C10-C12), (C14)
(I2) start_indices Tensor de type entier (C2), (C3), (C13)
(I3) slice_sizes Tensor unidimensionnel de type entier (C8), (C11-C13)
(I4) offset_dims Constante Tensor de type si64 à une dimension (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims Constante Tensor de type si64 à une dimension (C1), (C6-C8), (C13)
(I6) start_index_map Constante Tensor de type si64 à une dimension (C3), (C9), (C10)
(I7) index_vector_dim constante de type si64 (C2), (C3), (C13)
(I8) indices_are_sorted constante de type i1

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C5), (C13-C14)

Contraintes

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) où :
    • batch_dim_sizes = shape(start_indices), sauf que la taille de la dimension start_indices correspondant à index_vector_dim n'est pas incluse.
    • offset_dim_sizes = shape(slice_sizes), sauf que les tailles de dimension dans slice_sizes correspondant à collapsed_slice_dims ne sont pas incluses.
    • combine place batch_dim_sizes sur les axes correspondant à batch_dims et offset_dim_sizes sur les axes correspondant à offset_dims.
  • (C14) element_type(operand) = element_type(result).

Exemples

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stable<hlo.gather
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vect>or_dim = 2,
  indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

 Autres exemples

dynamic_iota

Sémantique

Cette opération est fonctionnellement identique à l'opération iota, mais la forme du résultat est spécifiée de manière dynamique via output_shape.

Entrées

Libellé Nom Type Contraintes
(I1) output_shape Tensor unidimensionnel de type entier (C1), (C2)
(I2) iota_dimension si64 (C1)

Sorties

Nom Type Contraintes
result Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C2)

Contraintes

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Exemples

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

 Autres exemples

dynamic_pad

Sémantique

Cette opération est fonctionnellement identique à l'opération pad, mais avec edge_padding_low, edge_padding_high et interior_padding spécifiés de manière dynamique en tant que valeurs.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1), (C2), (C4)
(I2) padding_value Tenseur de dimension 0 ou tenseur quantifié par tenseur (C1)
(I3) edge_padding_low Tensor unidimensionnel de type entier (C1), (C4)
(I4) edge_padding_high Tensor unidimensionnel de type entier (C1), (C4)
(I5) interior_padding Tensor unidimensionnel de type entier (C2-C4)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C3-C6)

Contraintes

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Exemples

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 Autres exemples

dynamic_reshape

Sémantique

Cette opération est fonctionnellement identique à l'opération reshape, mais la forme du résultat est spécifiée de manière dynamique via output_shape.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié (C1-C3)
(I2) output_shape Tensor unidimensionnel de type entier (C4)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C1-C4)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), sauf que quantization_dimension(operand) et quantization_dimension(result) peuvent être différents.
  • (C2) size(operand) = size(result).
  • (C3) Si is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]

 Autres exemples

dynamic_slice

Sémantique

Extrait une tranche de operand à l'aide d'indices de début calculés de manière dynamique et produit un Tensor result. start_indices contient les indices de début de la tranche pour chaque dimension susceptible d'être ajustée, et slice_sizes contient la taille de la tranche pour chaque dimension. Plus formellement, result[result_index] = operand[operand_index] où :

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1), (C2), (C4)
(I2) start_indices Nombre variable de tenseurs de dimension 0 de type entier (C2), (C3)
(I3) slice_sizes Constante Tensor de type si64 à une dimension (C2), (C4), (C5)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1), (C5)

Contraintes

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Exemples

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

 Autres exemples

dynamic_update_slice

Sémantique

Produit un Tensor result qui est égal au Tensor operand, sauf que la tranche commençant à start_indices est mise à jour avec les valeurs de update. Plus formellement, result[result_index] est défini comme suit :

  • update[update_index] si 0 <= update_index < shape(update) où :
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • Sinon, operand[result_index].

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1-C4), (C6)
(I2) update tenseur ou tenseur quantifié par tenseur (C2), (C3), (C6)
(I3) start_indices Nombre variable de tenseurs de dimension 0 de type entier (C4), (C5)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Exemples

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
 < : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

 Autres exemples

exponentiel

Sémantique

Effectue une opération exponentielle au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : exp à partir de la norme IEEE-754.
  • Pour les nombres complexes : exponentielle complexe.
  • Pour les types quantifiés : dequantize_op_quantize(exponential, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

 Autres exemples

exponential_minus_one

Sémantique

Effectue une opération exponentielle moins un au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : expm1 à partir de la norme IEEE-754.
  • Pour les nombres complexes : exponentielle complexe moins un.
  • Pour les types quantifiés : dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]

 Autres exemples

fft

Sémantique

Effectue les transformées de Fourier directes et inverses pour les entrées/sorties réelles et complexes.

fft_type est l'un des éléments suivants :

  • FFT : FFT complexe-complexe directe.
  • IFFT : FFT complexe-complexe inverse.
  • RFFT : FFT de réel à complexe.
  • IRFFT : FFT inverse du réel au complexe (c'est-à-dire prend un complexe et renvoie un réel).

Plus formellement, étant donné la fonction fft qui prend en entrée des Tensors unidimensionnels de types complexes, produit en sortie des Tensors unidimensionnels de mêmes types et calcule la transformée de Fourier discrète :

Pour fft_type = FFT, result est défini comme le résultat final d'une série de calculs L où L = size(fft_length). Par exemple, pour L = 3 :

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

De plus, étant donné la fonction ifft qui a la même signature de type et calcule l'inverse de fft :

Pour fft_type = IFFT, result est défini comme l'inverse des calculs pour fft_type = FFT. Par exemple, pour L = 3 :

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

De plus, étant donné la fonction rfft qui prend des tenseurs unidimensionnels de types à virgule flottante, produit des tenseurs unidimensionnels de types complexes de la même sémantique à virgule flottante et fonctionne comme suit :

  • rfft(real_operand) = truncated_result
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand)
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Lorsque la transformée de Fourier discrète est calculée pour des opérandes réels, les N/2 + 1 premiers éléments du résultat définissent sans ambiguïté le reste du résultat. Le résultat de rfft est donc tronqué pour éviter de calculer des éléments redondants.)

Pour fft_type = RFFT, result est défini comme le résultat final d'une série de calculs L où L = size(fft_length). Par exemple, pour L = 3 :

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Enfin, étant donné la fonction irfft qui a la même signature de type et qui calcule l'inverse de rfft :

Pour fft_type = IRFFT, result est défini comme l'inverse des calculs pour fft_type = RFFT. Par exemple, pour L = 3 :

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe (C1), (C2), (C4), (C5)
(I2) fft_type Énumération de FFT, IFFT, RFFT et IRFFT (C2), (C5)
(I3) fft_length Constante Tensor de type si64 à une dimension (C1), (C3), (C4)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe (C2), (C4), (C5)

Contraintes

  • (C1) size(fft_length) <= rank(operand).
  • (C2) La relation entre les types d'éléments operand et result varie :
    • Si fft_type = FFT, element_type(operand) et element_type(result) ont le même type complexe.
    • Si fft_type = IFFT, element_type(operand) et element_type(result) ont le même type complexe.
    • Si fft_type = RFFT, element_type(operand) est un type à virgule flottante et element_type(result) est un type complexe de la même sémantique à virgule flottante.
    • Si fft_type = IRFFT, element_type(operand) est un type complexe et element_type(result) est un type à virgule flottante de la même sémantique à virgule flottante.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Si, parmi operand et result, il existe un Tensor real de type à virgule flottante, alors shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) sauf :
    • Si fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Si fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Exemples

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = <#stablehloff>t_type FFT,
  fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

étage

Sémantique

Effectue le plancher élément par élément du Tensor operand et produit un Tensor result. Implémente l'opération roundToIntegralTowardNegative à partir de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(floor, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

 Autres exemples

gather

Sémantique

Collecte les tranches du Tensor operand à partir des décalages spécifiés dans start_indices et produit un Tensor result.

Le schéma suivant montre comment les éléments de result sont mappés sur les éléments de operand à l'aide d'un exemple concret. Le diagramme choisit quelques exemples d'indices result et explique en détail à quels indices operand ils correspondent.

gather

Plus précisément, result[result_index] = operand[operand_index] où :

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index est défini comme suit :
    • start_indices[bi0, ..., :, ..., biN]bi sont des éléments individuels dans batch_index et : est inséré à l'index index_vector_dim, si index_vector_dim < rank(start_indices).
    • Sinon, [start_indices[batch_index]].
  • Pour d_operand dans axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) si d_operand = start_index_map[d_start].
    • Sinon, full_start_index[d_operand] = 0.
  • Pour d_operand dans axes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]si d_operand = operand_batching_dims[i_batching] et d_start = start_indices_batching_dims[i_batching].
    • Sinon, full_batching_index[d_operand] = 0.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN]oi sont des éléments individuels dans offset_index, et 0 est inséré aux indices de collapsed_slice_dims et operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

Si indices_are_sorted est true, l'implémentation peut supposer que start_indices est trié par rapport à start_index_map. Sinon, le comportement n'est pas défini. Plus précisément, pour tous les i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices tenseur de type entier (C2-C3), (C14), (C17), (C22)
(I3) offset_dims Constante Tensor de type si64 à une dimension (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims Constante Tensor de type si64 à une dimension (C1), (C6-C9), (C22)
(I5) operand_batching_dims Constante Tensor de type si64 à une dimension (C1), (C6), (C10-C12), (C16-C18), (C22)
(I6) start_indices_batching_dims Constante Tensor de type si64 à une dimension (C13-C17)
(I7) start_index_map Constante Tensor de type si64 à une dimension (C3), (C18-C19)
(I8) index_vector_dim constante de type si64 (C2-C3), (C15), (C22)
(I9) slice_sizes Constante Tensor de type si64 à une dimension (C9), (C12), (C20-C22)
(I10) indices_are_sorted constante de type i1

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C5), (C22-C23)

Contraintes

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) où :
    • batch_dim_sizes = shape(start_indices), sauf que la taille de la dimension start_indices correspondant à index_vector_dim n'est pas incluse.
    • offset_dim_sizes = slice_sizes, sauf que les tailles de dimension dans slice_sizes correspondant à collapsed_slice_dims et operand_batching_dims ne sont pas incluses.
    • combine place batch_dim_sizes sur les axes correspondant à batch_dims et offset_dim_sizes sur les axes correspondant à offset_dims.
  • (C23) element_type(operand) = element_type(result).

Exemples

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stable<hlo.gather
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vect>or_dim = 3,
  slice_siz<es = arrayi64: >1, 1, 2, 2,
  indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

 Autres exemples

get_dimension_size

Sémantique

Produit la taille de l'dimension donné de l'operand. Plus précisément, result = dim(operand, dimension). La sémantique ne concerne que le composant de forme du type. Le type d'élément peut être n'importe quoi.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié (C1)
(I2) dimension constante de type si64 (C1)

Sorties

Nom Type
result Tensor de dimension 0 de type si32

Contraintes

  • (C1) 0 <= dimension < rank(operand).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3

 Autres exemples

get_tuple_element

Sémantique

Extrait l'élément à la position index du tuple operand et génère un result. Plus précisément, result = operand[index].

Entrées

Libellé Nom Type Contraintes
(I1) operand tuple (C1), (C2)
(I2) index constante de type si32 (C1), (C2)

Sorties

Nom Type Contraintes
result toute valeur (C2)

Contraintes

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Exemples

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]

 Autres exemples

si

Sémantique

Génère le résultat de l'exécution d'une seule fonction à partir de true_branch ou false_branch en fonction de la valeur de pred. Plus précisément, result = pred ? true_branch() : false_branch().

Entrées

Libellé Nom Type Contraintes
(I1) pred Tensor de dimension 0 de type i1
(I2) true_branch fonction (C1-C3)
(I3) false_branch fonction (C1), (C2)

Sorties

Nom Type Contraintes
results un nombre variable de Tensors, de Tensors quantifiés ou de jetons. (C3)

Contraintes

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Exemples

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
  "stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10

 Autres exemples

imag

Sémantique

Extrait la partie imaginaire, élément par élément, de operand et produit un tenseur result. Plus précisément, pour chaque élément x : imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • Sinon, element_type(operand).

Exemples

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]

 Autres exemples

in-feed

Sémantique

Lit les données du flux In-Feed et génère results.

La sémantique de infeed_config est définie par l'implémentation.

results se compose de valeurs de charge utile qui viennent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour plus de clarté (#670).

Entrées

Libellé Nom Type
(I1) token token
(I2) infeed_config constante de type string

Sorties

Nom Type Contraintes
results un nombre variable de Tensors, de Tensors quantifiés ou de jetons. (C1-C3)

Contraintes

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Exemples

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

 Autres exemples

iota

Sémantique

Remplit un Tensor output avec des valeurs dans l'ordre croissant à partir de zéro le long de la dimension iota_dimension. Plus précisément,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Entrées

Libellé Nom Type Contraintes
(I1) iota_dimension si64 (C1)

Sorties

Nom Type Contraintes
output Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) 0 <= iota_dimension < rank(output).

Exemples

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

 Autres exemples

is_finite

Sémantique

Effectue une vérification élément par élément pour déterminer si la valeur dans x est finie (c'est-à-dire ni +Inf, ni -Inf, ni NaN) et produit un Tensor y. Implémente l'opération isFinite de la spécification IEEE-754. Pour les types quantifiés, le résultat est toujours true.

Entrées

Libellé Nom Type Contraintes
(I1) x Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
y Tensor de type booléen (C1)

Contraintes

  • (C1) shape(x) = shape(y).

Exemples

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]

 Autres exemples

log

Sémantique

Effectue une opération logarithmique au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : log à partir de la norme IEEE-754.
  • Pour les nombres complexes : logarithme complexe.
  • Pour les types quantifiés : dequantize_op_quantize(log, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

 Autres exemples

log_plus_one

Sémantique

Effectue une opération de logarithme plus un au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : logp1 à partir de la norme IEEE-754.
  • Pour les nombres complexes : complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))
  • Pour les types quantifiés : dequantize_op_quantize(log_plus_one, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

 Autres exemples

logistique

Sémantique

Effectue une opération logistique au niveau des éléments sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : division(1, addition(1, exp(-x))) à partir de la norme IEEE-754.
  • Pour les nombres complexes : logistique complexe.
  • Pour les types quantifiés : dequantize_op_quantize(logistic, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

 Autres exemples

carte

Sémantique

Applique une fonction de mappage computation à inputs le long de dimensions et produit un Tensor result.

Plus précisément, result[result_index] = computation(inputs...[result_index]).

Entrées

Libellé Nom Type Contraintes
(I1) inputs un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1-C4)
(I2) dimensions Constante Tensor de type si64 à une dimension (C3)
(I3) computation fonction (C4)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1), (C4)

Contraintes

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation est de type (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>Ei = element_type(inputs[i]) et E' = element_type(result).

Exemples

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
    stablehlo.return %<0 :> tensori64
}) {
  dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]

 Autres exemples

maximum

Sémantique

Effectue une opération max au niveau des éléments sur les Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs booléennes : OR logique.
  • Pour les nombres entiers : nombre entier maximal.
  • Pour les valeurs float : maximum à partir de la norme IEEE-754.
  • Pour les nombres complexes : maximum lexicographique pour la paire (real, imaginary). L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (#560).
  • Pour les types quantifiés :
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur ou tenseur quantifié par tenseur (C1)
(I2) rhs tenseur ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]

 Autres exemples

minimum

Sémantique

Effectue une opération min élément par élément sur les Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs booléennes : AND logique.
  • Pour les nombres entiers : nombre entier minimal.
  • Pour les valeurs float : minimum à partir de la norme IEEE-754.
  • Pour les nombres complexes : minimum lexicographique pour la paire (real, imaginary). L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (#560).
  • Pour les types quantifiés :
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur ou tenseur quantifié par tenseur (C1)
(I2) rhs tenseur ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]

 Autres exemples

multiplier

Sémantique

Effectue le produit élément par élément de deux tenseurs lhs et rhs, et produit un tenseur result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs booléennes : AND logique.
  • Pour les nombres entiers : multiplication d'entiers.
  • Pour les valeurs float : multiplication à partir de la norme IEEE-754.
  • Pour les nombres complexes : multiplication complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur ou tenseur quantifié par tenseur (C1)
(I2) rhs tenseur ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]

 Autres exemples

negate

Sémantique

Effectue la négation élément par élément du tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les entiers signés : négation d'un entier.
  • Pour les entiers non signés : bitcast vers un entier signé, négation de l'entier, bitcast vers un entier non signé.
  • Pour les valeurs float : negate à partir de la norme IEEE-754.
  • Pour les nombres complexes : négation complexe.
  • Pour les types quantifiés : dequantize_op_quantize(negate, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]

 Autres exemples

not

Sémantique

Effectue un NOT élément par élément du Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs booléennes : NOT logique.
  • Pour les entiers : NOT au niveau du bit.

Arguments

Nom Type Contraintes
operand Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]

 Autres exemples

optimization_barrier

Sémantique

Garantit que les opérations qui produisent le operand sont exécutées avant toute opération qui dépend du result et empêche les transformations du compilateur de déplacer des opérations au-delà de la barrière. À part cela, l'opération est une identité, c'est-à-dire result = operand.

Arguments

Nom Type Contraintes
operand Nombre variable de tenseurs, de tenseurs ou de jetons quantifiés par tenseur (C1)

Sorties

Nom Type Contraintes
result Nombre variable de tenseurs, de tenseurs ou de jetons quantifiés par tenseur (C1)

Contraintes

  • (C1) type(operand...) = type(result...).

Exemples

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0

 Autres exemples

ou

Sémantique

Effectue un OR élément par élément de deux Tensors lhs et rhs et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs booléennes : OR logique.
  • Pour les entiers : opérateur OR (OU) au niveau du bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier ou booléen (C1)
(I2) rhs Tensor de type entier ou booléen (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier ou booléen (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]

 Autres exemples

sortie

Sémantique

Écrit inputs dans le flux de sortie et génère un jeton result.

La sémantique de outfeed_config est définie par l'implémentation.

Entrées

Libellé Nom Type
(I1) inputs un nombre variable de Tensors ou de Tensors quantifiés.
(I2) token token
(I3) outfeed_config constante de type string

Sorties

Nom Type
result token

Exemples

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token

 Autres exemples

pad

Sémantique

Développe operand en ajoutant une marge intérieure autour du Tensor ainsi qu'entre les éléments du Tensor avec le padding_value donné.

edge_padding_low et edge_padding_high spécifient la quantité de marge intérieure ajoutée respectivement à l'extrémité inférieure (à côté de l'index 0) et à l'extrémité supérieure (à côté de l'index le plus élevé) de chaque dimension. La quantité de marge intérieure peut être négative, où la valeur absolue de la marge intérieure négative indique le nombre d'éléments à supprimer de la dimension spécifiée.

interior_padding spécifie la marge intérieure ajoutée entre deux éléments dans chaque dimension, qui ne peut pas être négative. La marge intérieure se produit avant la marge extérieure, de sorte qu'une marge extérieure négative supprimera les éléments de l'opérande avec marge intérieure.

Plus formellement, result[result_index] est défini comme suit :

  • operand[operand_index] siresult_index = edge_padding_low + operand_index * (interior_padding + 1).
  • Sinon, padding_value.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1), (C2), (C4)
(I2) padding_value Tenseur de dimension 0 ou tenseur quantifié par tenseur (C1)
(I3) edge_padding_low Constante Tensor de type si64 à une dimension (C1), (C4)
(I4) edge_padding_high Constante Tensor de type si64 à une dimension (C1), (C4)
(I5) interior_padding Constante Tensor de type si64 à une dimension (C2-C4)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C3-C6)

Contraintes

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Exemples

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_l<ow = arra>yi64: 0, 1,
  edge_padding_hi<gh = arra>yi64: 2, 1,
  interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 Autres exemples

partition_id

Sémantique

Produit partition_id du processus actuel.

Sorties

Nom Type
result Tensor de dimension 0 de type ui32

Exemples

%result = "stablehlo.partition_id">;() : (<) - >tensorui32

 Autres exemples

popcnt

Sémantique

Effectue un décompte élément par élément du nombre de bits définis dans le Tensor operand et produit un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]

 Autres exemples

puissance

Sémantique

Effectue une exponentiation élément par élément du Tensor lhs par le Tensor rhs et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les nombres entiers : exponentiation d'entiers.
  • Pour les valeurs float : pow à partir de la norme IEEE-754.
  • Pour les nombres complexes : exponentiation complexe.
  • Pour les types quantifiés : dequantize_op_quantize(power, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)
(I2) rhs Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

 Autres exemples

real

Sémantique

Extrait la partie réelle, élément par élément, de operand et produit un Tensor result. Plus précisément, pour chaque élément x : real(x) = is_complex(x) ? real_part(x) : x.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • Sinon, element_type(operand).

Exemples

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]

 Autres exemples

recv

Sémantique

Reçoit des données d'un canal avec channel_id et produit results.

Si is_host_transfer est défini sur true, l'opération transfère les données depuis l'hôte. Sinon, il transfère les données depuis un autre appareil en fonction des valeurs de source_target_pairs. Cette option duplique les informations fournies dans channel_type. À l'avenir, nous prévoyons de n'en conserver qu'une seule (#666). Si is_host_transferfalse et que source_target_pairs est None ou vide, le comportement est considéré comme indéfini.

results se compose de valeurs de charge utile qui viennent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour plus de clarté (#670).

Entrées

Libellé Nom Type Contraintes
(I1) token token
(I2) channel_id constante de type si64
(I3) channel_type Énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST (C5)
(I4) is_host_transfer constante de type i1 (C5-C6)
(I5) source_target_pairs Constante Tensor à deux dimensions de type si64 (C1-C4), (C6)

Sorties

Nom Type Contraintes
results un nombre variable de Tensors, de Tensors quantifiés ou de jetons. (C2-C4)

Contraintes

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, où N est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C5) channel_type est défini comme suit :
    • DEVICE_TO_HOST si is_host_transfer = true,
    • Sinon, DEVICE_TO_DEVICE.

Exemples

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
  is_host_transfer = false,
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)

 Autres exemples

reduce

Sémantique

Applique une fonction de réduction body à inputs et init_values le long de dimensions et produit des Tensors results.

L'ordre des réductions est défini par l'implémentation, ce qui signifie que body et init_values doivent former un monoïde pour garantir que l'opération produit les mêmes résultats pour toutes les entrées sur toutes les implémentations. Toutefois, cette condition ne s'applique pas à de nombreuses réductions populaires. Par exemple, l'addition à virgule flottante pour body et zéro pour init_values ne forment pas réellement un monoïde, car l'addition à virgule flottante n'est pas associative.

Plus précisément, results...[j0, ..., jR-1] = reduce(input_slices_converted) où :

  • input_slices = inputs...[j0, ..., :, ..., jR-1], où : sont insérés à dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) pour un arbre binaire donné schedule où :
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule est un arbre binaire complet défini par l'implémentation dont la traversée dans l'ordre se compose de :
    • Valeurs input_slices_converted...[index], pour tous les index dans index_space(input_slices_converted) dans l'ordre lexicographique croissant de index.
    • Entrecroisé avec une quantité de init_values_converted définie par l'implémentation à des positions définies par l'implémentation.

Entrées

Libellé Nom Type Contraintes
(I1) inputs un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1-C4), (C6), (C7)
(I2) init_values nombre variable de Tensors de dimension 0 ou de Tensors quantifiés par Tensor (C2), (C3)
(I3) dimensions Constante Tensor de type si64 à une dimension (C4), (C5), (C7)
(I4) body fonction (C6)

Sorties

Nom Type Contraintes
results un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C3), (C7), (C8)

Contraintes

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), sauf que les tailles de dimension de inputs... correspondant à dimensions ne sont pas incluses.
  • (C8) element_type(results[i]) = Ei pour tous les i dans [0,N).

Exemples

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
  dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]

 Autres exemples

reduce_precision

Sémantique

Effectue la conversion élément par élément de operand vers un autre type à virgule flottante qui utilise exponent_bits et mantissa_bits, puis de nouveau vers le type à virgule flottante d'origine, et produit un Tensor output.

Plus précisément :

  • Les bits de mantisse de la valeur d'origine sont mis à jour pour arrondir la valeur d'origine à la valeur la plus proche pouvant être représentée avec mantissa_bits à l'aide de la sémantique roundToIntegralTiesToEven.
  • Ensuite, si mantissa_bits est inférieur au nombre de bits de mantisse de la valeur d'origine, les bits de mantisse sont tronqués à mantissa_bits.
  • Ensuite, si les bits d'exposant du résultat intermédiaire ne tiennent pas dans la plage fournie par exponent_bits, le résultat intermédiaire déborde à l'infini en utilisant le signe d'origine ou déborde à zéro en utilisant le signe d'origine.
  • Pour les types quantifiés, effectue dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)
(I2) exponent_bits constante de type si32 (C2)
(I3) mantissa_bits constante de type si32 (C3)

Sorties

Nom Type Contraintes
output Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Exemples

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

 Autres exemples

reduce_scatter

Sémantique

reduce_scatter

Dans chaque groupe de processus de la grille de processus StableHLO, effectue une réduction, à l'aide de computations, sur les valeurs du Tensor operand de chaque processus, divise le résultat de la réduction le long de scatter_dimension en parties et distribue les parties divisées entre les processus pour produire result.

L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group :

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] pour tous les sender dans process_group, où receiver_index = process_group.index(receiver).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1), (C2), (C7), (C8)
(I2) scatter_dimension constante de type si64 (C1), (C2), (C8)
(I3) replica_groups Constante Tensor à deux dimensions de type si64 (C3-C5)
(I4) channel_id constante de type si64 (C6)
(I5) use_global_device_ids constante de type i1 (C6)
(I6) computation fonction (C7)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C8-C9)

Contraintes

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Si use_global_device_ids = true, alors channel_id > 0.
  • (C7) computation est de type (tensor<E>, tensor<E>) -> (tensor<E>)is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) sauf :
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
  %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
  "stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
  scatter_dimension = 1 :< i64,
  >replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
  channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

 Autres exemples

reduce_window

Sémantique

Applique une fonction de réduction body aux fenêtres inputs et init_values et génère results.

Le schéma suivant montre comment les éléments de results... sont calculés à partir de inputs... à l'aide d'un exemple concret.

reduce_window

Plus formellement, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (voir reduce) où :

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Entrées

Libellé Nom Type Contraintes
(I1) inputs un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values nombre variable de Tensors de dimension 0 ou de Tensors quantifiés par Tensor (C1), (C13)
(I3) window_dimensions Constante Tensor de type si64 à une dimension (C4), (C5), (C15)
(I4) window_strides Constante Tensor de type si64 à une dimension (C6), (C7), (C15)
(I5) base_dilations Constante Tensor de type si64 à une dimension (C8), (C9), (C15)
(I6) window_dilations Constante Tensor de type si64 à une dimension (C10), (C11), (C15)
(I7) padding Constante Tensor à deux dimensions de type si64 (C12), (C15)
(I8) body fonction (C13)

Sorties

Nom Type Contraintes
results un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1), (C14-C16)

Contraintes

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows où :
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei pour tous les i dans [0,N).

Exemples

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
  wind>ow_dimensions = arrayi64: <2, 1,
  w>indow_strides = arrayi64: <4, 1,
  b>ase_dilations = arrayi64: 2,< 1,
  win>dow_dilations = arr<ayi64: 3, 1,
  p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]

 Autres exemples

reste

Sémantique

Effectue le reste élément par élément des Tensors de dividende lhs et de diviseur rhs, et produit un Tensor result.

Plus formellement, le signe du résultat est tiré du dividende, et la valeur absolue du résultat est toujours inférieure à la valeur absolue du diviseur. Le reste est calculé comme suit : lhs - d * rhs, où d est donné par :

  • Pour les nombres entiers : stablehlo.divide(lhs, rhs).
  • Pour les valeurs flottantes : division(lhs, rhs) à partir de IEEE-754 avec l'attribut d'arrondi roundTowardZero.
  • Pour les nombres complexes : à déterminer (#997).
  • Pour les types quantifiés :
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Pour les types d'éléments à virgule flottante, cette opération est différente de l'opération remainder de la spécification IEEE-754, où d est une valeur entière la plus proche de la valeur exacte de lhs/rhs avec des liens vers le nombre pair.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)
(I2) rhs Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]

 Autres exemples

replica_id

Sémantique

Produit replica_id du processus actuel.

Sorties

Nom Type
result Tensor de dimension 0 de type ui32

Exemples

%result = "stablehlo.replica_id">;() : (<) - >tensorui32

 Autres exemples

reshape

Sémantique

Remodele le Tensor operand en un Tensor result. Conceptuellement, cela revient à conserver la même représentation canonique, mais à modifier potentiellement la forme, par exemple de tensor<2x3xf32> à tensor<3x2xf32> ou tensor<6xf32>.

Plus précisément, result[result_index] = operand[operand_index]result_index et operand_index ont la même position dans l'ordre lexicographique de index_space(result) et index_space(operand).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié (C1-C3)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C1-C3)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), sauf que quantization_dimension(operand) et quantization_dimension(result) peuvent être différents.
  • (C2) size(operand) = size(result).
  • (C3) Si is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]

 Autres exemples

inverser

Sémantique

Inverse l'ordre des éléments dans operand le long de la dimensions spécifiée et génère un Tensor result. Plus formellement, result[result_index] = operand[operand_index] où :

  • operand_index[d] = dim(result, d) - result_index[d] - 1 si d dans dimensions.
  • Sinon, operand_index[d] = result_index[d].

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1), (C3)
(I2) dimensions Constante Tensor de type si64 à une dimension (C2), (C3)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1), (C3)

Contraintes

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Exemples

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]

 Autres exemples

rng

Sémantique

Génère des nombres aléatoires à l'aide de l'algorithme rng_distribution et produit un Tensor result d'une forme donnée shape.

Si la valeur est rng_distribution = UNIFORM, les nombres aléatoires sont générés selon la distribution uniforme sur l'intervalle [a, b). Si a >= b, le comportement n'est pas défini.

Si rng_distribution = NORMAL, les nombres aléatoires sont générés selon la distribution normale avec la moyenne = a et l'écart-type = b. Si la valeur est b < 0, le comportement n'est pas défini.

La méthode exacte de génération des nombres aléatoires est définie par l'implémentation. Par exemple, ils peuvent être déterministes ou non, et utiliser ou non un état caché.

Lors de conversations avec de nombreuses parties prenantes, cette opération a été considérée comme obsolète. Nous prévoyons donc de l'explorer à l'avenir (#597).

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor de type entier, booléen ou à virgule flottante de dimension 0 (C1), (C2)
(I2) b Tensor de type entier, booléen ou à virgule flottante de dimension 0 (C1), (C2)
(I3) shape Constante Tensor de type si64 à une dimension (C3)
(I4) rng_distribution Énumération de UNIFORM et NORMAL (C2)

Sorties

Nom Type Contraintes
result Tensor de type entier, booléen ou à virgule flottante (C1-C3)

Contraintes

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Si rng_distribution = NORMAL, alors is_float(a).
  • (C3) shape(result) = shape.

Exemples

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Sémantique

Renvoie un output rempli de bits aléatoires uniformes et un état de sortie mis à jour output_state à l'aide de l'algorithme de générateur de nombres pseudo-aléatoires rng_algorithm, étant donné un état initial initial_state. La sortie est garantie comme étant une fonction déterministe de initial_state, mais elle n'est pas garantie comme étant déterministe entre les implémentations.

rng_algorithm est l'un des éléments suivants :

  • DEFAULT : algorithme défini par l'implémentation.
  • THREE_FRY : variante de l'algorithme Threefry définie par l'implémentation*.
  • PHILOX : variante de l'algorithme Philox définie par l'implémentation*.

* Voir Salmon et al. SC 2011. Nombres aléatoires parallèles : un jeu d'enfant.

Entrées

Libellé Nom Type Contraintes
(I1) rng_algorithm Énumération de DEFAULT, THREE_FRY et PHILOX (C2)
(I2) initial_state Tensor de type ui64 à une dimension (C1), (C2)

Sorties

Nom Type Contraintes
output_state Tensor de type ui64 à une dimension (C1)
output Tensor de type entier ou à virgule flottante

Contraintes

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) est défini comme suit :
    • défini par l'implémentation si rng_algorithm = DEFAULT.
    • 2 si rng_algorithm = THREE_FRY.
    • 2 ou 3 si rng_algorithm = PHILOX.

Exemples

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Sémantique

Effectue un arrondi élément par élément vers l'entier le plus proche, en cas d'égalité, l'arrondi s'éloigne de zéro, sur le Tensor operand et produit un Tensor result. Implémente l'opération roundToIntegralTiesToAway de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

 Autres exemples

round_nearest_even

Sémantique

Effectue un arrondi élément par élément vers l'entier le plus proche, en cas d'égalité vers l'entier pair, sur le Tensor operand et produit un Tensor result. Implémente l'opération roundToIntegralTiesToEven à partir de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(round_nearest_even, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

 Autres exemples

rsqrt

Sémantique

Effectue une opération de racine carrée réciproque au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : rSqrt à partir de la norme IEEE-754.
  • Pour les nombres complexes : racine carrée réciproque complexe.
  • Pour les types quantifiés : dequantize_op_quantize(rsqrt, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

 Autres exemples

disperser

Sémantique

Produit des Tensors results égaux aux Tensors inputs, sauf que plusieurs tranches spécifiées par scatter_indices sont mises à jour avec les valeurs updates à l'aide de update_computation.

Le schéma suivant montre comment les éléments de updates... sont mappés sur les éléments de results... à l'aide d'un exemple concret. Le diagramme choisit quelques exemples d'indices updates... et explique en détail à quels indices results... ils correspondent.

disperser

Plus formellement, pour tout update_index dans index_space(updates[0]) :

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index est défini comme suit :
    • scatter_indices[si0, ..., :, ..., siN]si sont des éléments individuels dans update_scatter_index et : est inséré à l'index index_vector_dim, si index_vector_dim < rank(scatter_indices).
    • Sinon, [scatter_indices[update_scatter_index]].
  • Pour d_input dans axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] sid_input = scatter_dims_to_operand_dims[d_start].
    • Sinon, full_start_index[d_input] = 0.
  • Pour d_input dans axes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]si d_input = input_batching_dims[i_batching] et d_start = scatter_indices_batching_dims[i_batching].
    • Sinon, full_batching_index[d_input] = 0.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN]wi sont des éléments individuels dans update_window_index, et 0 est inséré aux indices de inserted_window_dims et input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

Dans ce cas, results = exec(schedule, inputs), où :

  • schedule est une permutation de index_space(updates[0]) définie par l'implémentation.
  • exec([update_index, ...], results) = exec([...], updated_results) où :
    • Si result_index est dans les limites de shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results est une copie de results avec results...[result_index] défini sur updated_values....
    • Sinon, procédez comme suit :
    • updated_results = results.
  • exec([], results) = results.

Si indices_are_sorted est true, l'implémentation peut supposer que scatter_indices est trié par rapport à scatter_dims_to_operand_dims. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Si unique_indices est true, l'implémentation peut supposer que tous les indices result_index dispersés sont uniques. Si unique_indices est true, mais que les indices vers lesquels la dispersion est effectuée ne sont pas uniques, le comportement n'est pas défini.

Entrées

Libellé Nom Type Contraintes
(I1) inputs un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices tenseur de type entier (C4), (C15), (C19), (C22)
(I3) updates un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C3-C6), (C8)
(I4) update_window_dims Constante Tensor de type si64 à une dimension (C2), (C4), (C7-C8)
(I5) inserted_window_dims Constante Tensor de type si64 à une dimension (C2), (C4), (C9-C11)
(I6) input_batching_dims Constante Tensor de type si64 à une dimension (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims Constante Tensor de type si64 à une dimension (C14-C18)
(I8) scatter_dims_to_operand_dims Constante Tensor de type si64 à une dimension (C19-C21)
(I9) index_vector_dim constante de type si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted constante de type i1
(I11) unique_indices constante de type i1
(I12) update_computation fonction (C23)

Sorties

Nom Type Contraintes
results un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C24-C25)

Contraintes

  • (C1) same(shape(inputs...)).
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) où :
    • update_scatter_dim_sizes = shape(scatter_indices), sauf que la taille de la dimension scatter_indices correspondant à index_vector_dim n'est pas incluse.
    • update_window_dim_sizes <= shape(inputs[0]), sauf que les tailles de dimension dans inputs[0] correspondant à inserted_window_dims et input_batching_dims ne sont pas incluses.
    • combine place update_scatter_dim_sizes sur les axes correspondant à update_scatter_dims et update_window_dim_sizes sur les axes correspondant à update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), où is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei pour tous les i dans [0,N).

Exemples

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//           ],
//           [
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
  scatter_dimensio<n_numbers = #stablehlo.scatter
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2>, 1],
    index_vector_dim = 3,
  indices_are_sorted = false,
  uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

 Autres exemples

sélectionner

Sémantique

Génère un Tensor result où chaque élément est sélectionné à partir du Tensor on_true ou on_false en fonction de la valeur de l'élément pred correspondant. Plus formellement, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], où pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Pour les types quantifiés, effectue dequantize_select_quantize(pred, on_true, on_false, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) pred tenseur de type i1 (C1)
(I2) on_true tenseur ou tenseur quantifié par tenseur (C1-C2)
(I3) on_false tenseur ou tenseur quantifié par tenseur (C2)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C2)

Contraintes

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Exemples

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]

 Autres exemples

select_and_scatter

Sémantique

Disperse les valeurs du Tensor source à l'aide de scatter en fonction du résultat de reduce_window du Tensor input à l'aide de select et produit un Tensor result.

Le diagramme suivant montre comment les éléments de result sont calculés à partir de operand et source à l'aide d'un exemple concret.

select_and_scatter

Plus précisément :

  • selected_values = reduce_window_without_init(...) avec les entrées suivantes :

    • inputs = [operand].
    • window_dimensions, window_strides et padding sont utilisés tels quels.
    • base_dilations = windows_dilations = 1.
    • body est défini comme suit :
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    E = element_type(operand) et reduce_window_without_init fonctionnent exactement comme reduce_window, sauf que le schedule du reduce sous-jacent (voir reduce) n'inclut pas les valeurs d'initialisation. Le comportement de la fonction lorsque la fenêtre correspondante ne contient pas de valeurs n'est pas spécifié pour le moment (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) où :

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index si selected_values[source_index] comporte l'élément operand de operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1-C4), (C6), (C8-C11)
(I2) source tenseur ou tenseur quantifié par tenseur (C1), (C2)
(I3) init_value Tenseur de dimension 0 ou tenseur quantifié par tenseur (C3)
(I4) window_dimensions Constante Tensor de type si64 à une dimension (C2), (C4), (C5)
(I5) window_strides Constante Tensor de type si64 à une dimension (C2), (C6), (C7)
(I6) padding Constante Tensor à deux dimensions de type si64 (C2), (C8)
(I7) select fonction (C9)
(I8) scatter fonction (C10)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C11-C12)

Contraintes

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows où :
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select est de type (tensor<E>, tensor<E>) -> tensor<i1>E = element_type(operand).
  • (C10) scatter est de type (tensor<E>, tensor<E>) -> tensor<E>is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Exemples

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_di<rection = #stablehlocom>parison_directio<n G>E
    } <: (>ten>sori64,< t>ensori64) - tensori1
    "stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
  ^bb0(%<arg>0: tensori64, %arg1: tensori64):
    %0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
  >  "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
  window_dim<ensions => arrayi64: 3, 1,
  <window_strides => arrayi64<: 2, 1,>
  padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

 Autres exemples

envoyer

Sémantique

Envoie inputs à la chaîne channel_id. Les entrées sont ensuite envoyées aux autres appareils dans l'ordre spécifié par source_target_pairs. L'opération génère un jeton result.

Si is_host_transfer est défini sur true, l'opération transfère les données vers l'hôte. Sinon, il transfère les données vers un autre appareil en fonction des valeurs de source_target_pairs. Cette option duplique les informations fournies dans channel_type. À l'avenir, nous prévoyons de n'en conserver qu'une seule (#666). Si is_host_transferfalse et que source_target_pairs est None ou vide, le comportement est considéré comme indéfini.

Entrées

Libellé Nom Type Contraintes
(I1) inputs un nombre variable de Tensors ou de Tensors quantifiés.
(I2) token token
(I3) channel_id constante de type si64
(I4) channel_type Énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST (C5)
(I5) is_host_transfer constante de type i1 (C5-C6)
(I6) source_target_pairs Constante Tensor à deux dimensions de type si64 (C1-C4), (C6)

Sorties

Nom Type
result token

Contraintes

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, où N est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C5) channel_type est défini comme suit :
    • DEVICE_TO_HOST si is_host_transfer = true,
    • Sinon, DEVICE_TO_DEVICE.

Exemples

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
  is_host_transfer = false,
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token

 Autres exemples

shift_left

Sémantique

Effectue une opération de décalage à gauche au niveau des éléments sur le Tensor lhs par le nombre de bits rhs et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur de type entier (C1)
(I2) rhs tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]

 Autres exemples

shift_right_arithmetic

Sémantique

Effectue une opération de décalage arithmétique vers la droite au niveau des éléments sur le Tensor lhs par rhs bits et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur de type entier (C1)
(I2) rhs tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]

 Autres exemples

shift_right_logical

Sémantique

Effectue une opération de décalage logique vers la droite au niveau des éléments sur le Tensor lhs par rhs bits et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs tenseur de type entier (C1)
(I2) rhs tenseur de type entier (C1)

Sorties

Nom Type Contraintes
result tenseur de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]

 Autres exemples

signer

Sémantique

Renvoie le signe de l'élément operand et produit un Tensor result. Plus formellement, pour chaque élément x, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit :

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Pour les types quantifiés, effectue dequantize_op_quantize(sign, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type entier signé, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type entier signé, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

 Autres exemples

sinus

Sémantique

Effectue une opération sinus au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : sin à partir de la norme IEEE-754.
  • Pour les nombres complexes : sinus complexe.
  • Pour les types quantifiés : dequantize_op_quantize(sine, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]

 Autres exemples

slice

Sémantique

Extrait une tranche de operand à l'aide d'indices de début calculés de manière statique et produit un Tensor result. start_indices contient les index de début de la tranche pour chaque dimension, limit_indices contient les index de fin (exclusifs) de la tranche pour chaque dimension et strides contient les foulées pour chaque dimension.

Plus précisément, result[result_index] = operand[operand_index]operand_index = start_indices + result_index * strides.

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié par tenseur (C1-C3), (C5)
(I2) start_indices Constante Tensor de type si64 à une dimension (C2), (C3), (C5)
(I3) limit_indices Constante Tensor de type si64 à une dimension (C2), (C3), (C5)
(I4) strides Constante Tensor de type si64 à une dimension (C2), (C4)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié par tenseur (C1), (C5)

Contraintes

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Exemples

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indic<es = arra>yi64: 1, 2,
  limit_indic<es = arra>yi64: 3, 4,
  strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

 Autres exemples

trier

Sémantique

Trie ensemble les tranches à une dimension de inputs le long de la dimension dimension, selon un comparator, et produit results.

Contrairement aux entrées similaires dans d'autres opérations, dimension autorise les valeurs négatives, avec la sémantique décrite ci-dessous. À l'avenir, cela pourra être interdit pour des raisons de cohérence (#1377).

Si is_stable est défini sur "true", le tri est stable, c'est-à-dire que l'ordre relatif des éléments considérés comme égaux par le comparateur est conservé. Dans le cas où il n'y a qu'une seule entrée, deux éléments e1 et e2 sont considérés comme égaux par le comparateur si et seulement si comparator(e1, e2) = comparator(e2, e1) = false. Consultez la formalisation ci-dessous pour savoir comment cela se généralise à plusieurs entrées.

Plus formellement, pour tout result_index dans index_space(results[0]) :

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1]riN sont des éléments individuels dans result_index, et : est inséré à adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • sort trie une tranche unidimensionnelle par ordre croissant, en s'attendant à ce que comparator_together renvoie true si l'argument de gauche est inférieur au second argument de droite.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Entrées

Libellé Nom Type Contraintes
(I1) inputs un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C1-C5)
(I2) dimension constante de type si64 (C4)
(I3) is_stable constante de type i1
(I4) comparator fonction (C5)

Sorties

Nom Type Contraintes
results un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. (C2), (C3)

Contraintes

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, où R = rank(inputs[0]).
  • (C5) comparator est de type (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, où Ei = element_type(inputs[i]).

Exemples

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_di<rection = #stablehlocom>parison_directio<n G>T
    } <: (>ten>sori64,< t>ensori64) - tensori1
    "stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
  dimension = 0 : i64,
<  is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

 Autres exemples

sqrt

Sémantique

Effectue une opération de racine carrée élément par élément sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : squareRoot à partir de la norme IEEE-754.
  • Pour les nombres complexes : racine carrée complexe.
  • Pour les types quantifiés : dequantize_op_quantize(sqrt, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]

 Autres exemples

subtract

Sémantique

Effectue une soustraction élément par élément de deux Tensors lhs et rhs, et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les nombres entiers : soustraction d'entiers.
  • Pour les valeurs float : subtraction à partir de la norme IEEE-754.
  • Pour les nombres complexes : soustraction complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)
(I2) rhs Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]

 Autres exemples

tan

Sémantique

Effectue une opération tangente au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : tan à partir de la norme IEEE-754.
  • Pour les nombres complexes : tangente complexe.
  • Pour les types quantifiés : dequantize_op_quantize(tan, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

 Autres exemples

tanh

Sémantique

Effectue une opération de tangente hyperbolique élément par élément sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs float : tanh à partir de la norme IEEE-754.
  • Pour les nombres complexes : tangente hyperbolique complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(tanh, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]

 Autres exemples

transpose

Sémantique

Permute les dimensions du Tensor operand à l'aide de permutation et produit un Tensor result. Plus formellement, result[result_index] = operand[operand_index]result_index[d] = operand_index[permutation[d]].

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur ou tenseur quantifié (C1-C4)
(I2) permutation Constante Tensor de type si64 à une dimension (C2-C4)

Sorties

Nom Type Contraintes
result tenseur ou tenseur quantifié (C1), (C3-C4)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), sauf que quantization_dimension(operand) et quantization_dimension(result) peuvent être différents.
  • (C2) permutation est une permutation de range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Si is_per_axis_quantized(result), alorsquantization_dimension(operand) = permutation(quantization_dimension(result)).

Exemples

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

 Autres exemples

triangular_solve

Sémantique

Résout des lots de systèmes d'équations linéaires avec des matrices de coefficients triangulaires inférieures ou supérieures.

Plus formellement, étant donné a et b, result[i0, ..., iR-3, :, :] est la solution de op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] lorsque left_side est true ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] lorsque left_side est false, en résolvant la variable xop(a) est déterminé par transpose_a, qui peut être l'une des valeurs suivantes :

  • NO_TRANSPOSE : effectuez l'opération en utilisant a tel quel.
  • TRANSPOSE : effectue l'opération sur la transposée de a.
  • ADJOINT : effectue l'opération sur la transposée conjuguée de a.

Les données d'entrée ne sont lues qu'à partir du triangle inférieur de a, si lower est true ou du triangle supérieur de a, dans le cas contraire. Les données de sortie sont renvoyées dans le même triangle. Les valeurs de l'autre triangle sont définies par l'implémentation.

Si unit_diagonal est défini sur "true", l'implémentation peut supposer que les éléments diagonaux de a sont égaux à 1. Sinon, le comportement est indéfini.

Pour les types quantifiés, effectue dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) a Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1-C3)
(I2) b Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1-C4)
(I3) left_side constante de type i1 (C3)
(I4) lower constante de type i1
(I5) unit_diagonal constante de type i1
(I6) transpose_a Énumération de NO_TRANSPOSE, TRANSPOSE et ADJOINT

Sorties

Nom Type Contraintes
result Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur (C1)

Contraintes

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) La relation entre shape(a) et shape(b) est définie comme suit :
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Exemples

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Sémantique

Génère un tuple result à partir des valeurs val.

Entrées

Libellé Nom Type Contraintes
(I1) val nombre variable de valeurs (C1)

Sorties

Nom Type Contraintes
result tuple (C1)

Contraintes

  • (C1) result est de type tuple<E0, ..., EN-1>Ei = type(val[i]).

Exemples

// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))

 Autres exemples

uniform_dequantize

Sémantique

Effectue une conversion élément par élément du tenseur quantifié operand en un tenseur à virgule flottante result en fonction des paramètres de quantification définis par le type operand.

Plus précisément, result = dequantize(operand).

Entrées

Libellé Nom Type Contraintes
(I1) operand tenseur quantifié (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Exemples

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]

uniform_quantize

Sémantique

Effectue une conversion élément par élément du tenseur à virgule flottante ou du tenseur quantifié operand en un tenseur quantifié result en fonction des paramètres de quantification définis par le type result.

Plus précisément,

  • Si is_float(operand) :
    • result = quantize(operand, type(result)).
  • Si is_quantized(operand) :
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tenseur de type à virgule flottante ou quantifié (C1), (C2)

Sorties

Nom Type Contraintes
result tenseur quantifié (C1), (C2)

Contraintes

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Exemples

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]

pendant que

Sémantique

Produit la sortie de l'exécution de la fonction body zéro fois ou plus, tandis que la fonction cond génère true. Plus formellement, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit :

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Le comportement d'une boucle infinie est à déterminer (#383).

Entrées

Libellé Nom Type Contraintes
(I1) operand nombre variable de valeurs (C1-C3)
(I2) cond fonction (C1)
(I3) body fonction (C2)

Sorties

Nom Type Contraintes
results nombre variable de valeurs (C3)

Contraintes

  • (C1) cond est de type (T0, ..., TN-1) -> tensor<i1>, où Ti = type(operand[i]).
  • (C2) body est de type (T0, ..., TN-1) -> (T0, ..., TN-1), où Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Exemples

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_di<rection = #stablehlocom>parison_directio<n L>T
    } <: (>ten>sori64,< t>ensori64) - tensori1
    stablehlo.r<et>urn %cond : tensori1
  }, {
<  ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
    %new_sum = stablehlo.add <%ar>g1, %one : tensori64
    %new_i = stablehlo.add <%ar>g0, %one : tensori64
    stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10

 Autres exemples

xor

Sémantique

Effectue une opération XOR (OU exclusif) élément par élément sur deux Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :

  • Pour les valeurs booléennes : XOR logique.
  • Pour les nombres entiers : XOR (OU exclusif) bit à bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type booléen ou entier (C1)
(I2) rhs Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]

 Autres exemples

Interopérabilité des dialectes

À l'heure actuelle, les programmes StableHLO en circulation contiennent parfois des opérations qui ne sont pas définies par StableHLO.

Module, fonction, appel et retour

StableHLO utilise des opérations MLIR en amont pour ModuleOp, FuncOp, CallOp et ReturnOp. Cela a été fait pour une meilleure interopérabilité avec les mécanismes MLIR existants, car de nombreux passes utiles sont écrits en ciblant FuncOp et ModuleOp, et de nombreux pipelines de compilation s'attendent à ce que ces opérations soient présentes. Une compatibilité totale est garantie pour ces opérations. Si ces opérations changent de manière incompatible (par exemple, si elles sont supprimées), des équivalents StableHLO seront ajoutés pour préserver la compatibilité.

CHLO

L'opset CHLO contient des opérations de niveau supérieur qui se décomposent en StableHLO. Actuellement, aucune garantie de compatibilité n'est disponible pour CHLO. Pour garantir la compatibilité, le chlo-legalize-to-stablehlo pass doit être utilisé avant la sérialisation.

Opérations sur les formes

Il est courant dans la communauté d'utiliser certaines opérations des dialectes MLIR de base dans les programmes StableHLO dynamiques pour effectuer des calculs de forme. Le plus souvent, il s'agit d'opérations shape telles que shape_of ou num_elements, d'opérations tensor telles que dim ou from_elements, et du type index intégré.

Le RFC sur le dynamisme > O2 indique que ces types sont hors champ d'application. Toutefois, une certaine prise en charge des types index est incluse à des fins d'interopérabilité. Il n'existe aucune garantie de compatibilité pour ces opérations ou types. Le pass shape-legalize-to-stablehlo peut être utilisé pour convertir ces opérations en opérations StableHLO entièrement compatibles.

Opérations obsolètes

Plusieurs opérations StableHLO héritées de MHLO sont obsolètes et seront bientôt supprimées de StableHLO. Pour en savoir plus sur ces suppressions, consultez StableHLO v1.0 Cleanup #2283. Le problème de suivi de ces abandons est le n° 2340.

Ces opérations se répartissent en plusieurs catégories :

  • Catégorie "Not in HLO" des opérations StableHLO : elles faisaient initialement partie de l'opset StableHLO, mais ont ensuite été jugées inadaptées : broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (#3).
  • Opérations inutilisées : ces opérations ont peut-être été utiles à un moment donné, mais elles étaient soit sous-développées, soit les pipelines qui les utilisaient ont été refactorisés pour ne plus en avoir besoin. Cela inclut les comparaisons map, tuple (#598), get_tuple_element, rng, complex #560 et la convolution window_reversal (#1181).

Certaines de ces opérations peuvent être facilement supprimées, car elles peuvent être exprimées à l'aide d'opérations existantes (broadcast, create_token, cross-replica-sum, dot, unary_einsum) et seront supprimées une fois la période de compatibilité existante (six mois) écoulée. D'autres sont encore à l'étude pour être supprimées (comparaisons einsum, get_tuple_element, map, rng, torch_index_select, tuple, complex, window_reversal). En fonction des commentaires de la communauté, ces opérations seront supprimées ou ajoutées à la spécification avec une prise en charge complète. Tant que ces futures opérations ne sont pas connues, la compatibilité n'est garantie que pendant six mois.

Exécution

Exécution séquentielle

Un programme StableHLO est exécuté en fournissant des valeurs d'entrée à la fonction main et en calculant les valeurs de sortie. Les valeurs de sortie d'une fonction sont calculées en exécutant le graphique des opérations ancré dans l'opération return correspondante.

L'ordre d'exécution est défini par l'implémentation tant qu'il est aligné sur le flux de données, c'est-à-dire si les opérations sont exécutées avant leurs utilisations. Dans StableHLO, toutes les opérations à effet secondaire consomment un jeton et en produisent un (plusieurs jetons peuvent être multiplexés en un seul jeton via after_all). L'ordre d'exécution des effets secondaires est donc également aligné sur le flux de données. Par exemple, dans le programme ci-dessous, il existe deux ordres d'exécution possibles : %0 → %1 → %2 → return et %1 → %0 → %2 → return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Plus précisément, un processus StableHLO est une combinaison des éléments suivants : 1) un programme StableHLO, 2) des états d'opération (pas encore exécutée, déjà exécutée) et 3) des valeurs intermédiaires sur lesquelles le processus travaille. Le processus commence par des valeurs d'entrée pour la fonction main, progresse dans le graphique des opérations en mettant à jour les états des opérations et les valeurs intermédiaires, et se termine par des valeurs de sortie. La formalisation supplémentaire est à déterminer (#484).

Exécution parallèle

Les programmes StableHLO peuvent être exécutés en parallèle, organisés dans une grille de processus 2D de num_replicas par num_partitions, qui sont tous deux de type ui32.

Dans la grille de processus StableHLO, num_replicas * num_partitions processus StableHLO s'exécutent en même temps. Chaque processus possède un process_id = (replica_id, partition_id) unique, où replica_id dans replica_ids = range(num_replicas) et partition_id dans partition_ids = range(num_partitions) ont tous deux le type ui32.

La taille de la grille de processus est connue de manière statique pour chaque programme (à l'avenir, nous prévoyons d'en faire une partie explicite des programmes StableHLO #650), et la position dans la grille de processus est connue de manière statique pour chaque processus. Chaque processus a accès à sa position dans la grille de processus via les opérations replica_id et partition_id.

Dans la grille de processus, les programmes peuvent être identiques (style "Programme unique, données multiples"), tous différents (style "Programmes multiples, données multiples") ou quelque chose entre les deux. À l'avenir, nous prévoyons d'ajouter la compatibilité avec d'autres idiomes de définition des programmes StableHLO parallèles, y compris GSPMD (#619).

Dans la grille de processus, les processus sont généralement indépendants les uns des autres. Ils ont des états d'opération distincts, des valeurs d'entrée/intermédiaires/de sortie distinctes et la plupart des opérations sont exécutées séparément entre les processus, à l'exception d'un petit nombre d'opérations collectives décrites ci-dessous.

Étant donné que l'exécution de la plupart des opérations n'utilise que des valeurs provenant du même processus, il est généralement clair de faire référence à ces valeurs par leur nom. Toutefois, cela ne suffit pas pour décrire la sémantique des opérations collectives, ce qui donne lieu à la notation name@process_id pour faire référence à la valeur name dans un processus particulier. (De ce point de vue, un name non qualifié peut être considéré comme une abréviation de name@(replica_id(), partition_id()).)

L'ordre d'exécution des processus est défini par l'implémentation, à l'exception de la synchronisation introduite par la communication point à point et les opérations collectives, comme décrit ci-dessous.

Communication point à point

Les processus StableHLO peuvent communiquer entre eux via des canaux StableHLO. Un canal est représenté par un ID positif de type si64. Grâce à différentes opérations, il est possible d'envoyer des valeurs aux canaux et de les recevoir des canaux.

Une formalisation plus poussée (par exemple, d'où proviennent ces ID de canal, comment les programmes de processus en prennent connaissance et quel type de synchronisation est introduit par eux) est à déterminer (#484).

Communication en streaming

Chaque processus StableHLO a accès à deux interfaces de streaming :

  • Infeed à partir duquel la lecture peut être effectuée.
  • Outfeed dans lequel écrire.

Contrairement aux canaux, qui sont utilisés pour communiquer entre les processus et ont donc des processus à leurs deux extrémités, les flux d'entrée et de sortie ont leur autre extrémité définie par l'implémentation.

La formalisation supplémentaire, par exemple la façon dont la communication en streaming influence l'ordre d'exécution et le type de synchronisation qu'elle introduit, est à déterminer (#484).

Opérations collectives

Il existe six opérations collectives dans StableHLO : all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute et reduce_scatter. Toutes ces opérations divisent les processus dans la grille de processus StableHLO en groupes de processus StableHLO et exécutent un calcul conjoint dans chaque groupe de processus, indépendamment des autres groupes de processus.

Dans chaque groupe de processus, les opérations collectives peuvent introduire une barrière de synchronisation. Une formalisation plus poussée, par exemple en précisant le moment exact où cette synchronisation se produit, la manière exacte dont les processus atteignent cette barrière et ce qui se passe s'ils ne l'atteignent pas, est à déterminer (#484).

Si le groupe de processus implique une communication entre partitions (c'est-à-dire qu'il existe des processus dans le groupe de processus dont les ID de partition sont différents), l'exécution de l'opération collective nécessite un canal, et l'opération collective doit fournir un channel_id positif de type si64. La communication entre les répliques n'a pas besoin de canaux.

Les calculs effectués par les opérations collectives sont spécifiques aux opérations individuelles et sont décrits dans les sections d'opérations individuelles ci-dessus. Toutefois, les stratégies selon lesquelles la grille de processus est divisée en groupes de processus sont partagées entre ces opérations et sont décrites dans cette section. Plus précisément, StableHLO prend en charge les quatre stratégies suivantes.

cross_replica

Seules les communications entre les répliques ont lieu au sein de chaque groupe de processus. Cette stratégie prend replica_groups (liste de listes d'ID de répliques) et calcule un produit cartésien de replica_groups par partition_ids. replica_groups doit comporter des éléments uniques et couvrir tous les replica_ids. Plus formellement, en utilisant la syntaxe Python :

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica générera [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Seules les communications entre partitions ont lieu au sein de chaque groupe de processus. Cette stratégie prend partition_groups (une liste de listes d'ID de partition) et calcule un produit cartésien de partition_groups par replica_ids. partition_groups doit comporter des éléments uniques et couvrir tous les partition_ids. Plus précisément, en utilisant la syntaxe Python :

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Par exemple, pour partition_groups = [[0, 1]] et num_replicas = 4, cross_partition générera [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

Les communications entre réplicas et entre partitions peuvent avoir lieu dans chaque groupe de processus. Cette stratégie prend replica_groups (liste de listes d'ID de répliques) et calcule les produits cartésiens de chaque replica_group par partition_ids. replica_groups doit comporter des éléments uniques et couvrir tous les replica_ids. Plus précisément, en utilisant la syntaxe Python :

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica_and_partition générera [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Cette stratégie prend flattened_id_groups (une liste de listes d'ID de processus "aplatis" sous la forme replica_id * num_partitions + partition_id) et les transforme en ID de processus. flattened_id_groups doit comporter des éléments uniques et couvrir tous les process_ids. Plus précisément, en utilisant la syntaxe Python :

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Par exemple, pour flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 et num_partitions = 2, flattened_ids générera [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Précision

Pour le moment, StableHLO ne fournit aucune garantie concernant la précision numérique, mais cela peut changer à l'avenir (#1156).

Sémantique d'exécution de l'opération quantifiée

L'interprétation des opérations StableHLO quantifiées peut varier en fonction des exigences et des capacités matérielles. Par exemple, certains matériels peuvent choisir d'interpréter les opérations quantifiées à l'aide d'une stratégie "déquantifier, effectuer une opération à virgule flottante et enfin quantifier". D'autres peuvent effectuer l'intégralité du calcul avec une arithmétique entière. Par conséquent, l'interprétation des opérations StableHLO quantifiées est exclusivement déterminée par l'implémentation spécifique. L'interprétation de la quantification hybride (#1575) doit être basée sur sa sémantique telle que prescrite dans la spécification (via 1792).

Erreurs

Les programmes StableHLO sont validés à l'aide d'un ensemble complet de contraintes pour les opérations individuelles, ce qui exclut de nombreuses classes d'erreurs avant l'exécution. Toutefois, des conditions d'erreur sont toujours possibles, par exemple en cas de dépassement de capacité d'entier, d'accès hors limites, etc. Sauf indication explicite, toutes ces erreurs entraînent un comportement défini par l'implémentation, mais cela peut changer à l'avenir (#1157).

Exceptions à virgule flottante

Par exception à cette règle, les exceptions à virgule flottante dans les programmes StableHLO ont un comportement bien défini. Les opérations qui entraînent des exceptions définies par la norme IEEE-754 (opération non valide, division par zéro, dépassement de capacité, dépassement de capacité négatif ou exceptions inexactes) produisent des résultats par défaut (tels que définis dans la norme) et poursuivent l'exécution sans déclencher l'indicateur d'état correspondant, comme la gestion des exceptions raiseNoFlag de la norme. Les exceptions pour les opérations non standards (par exemple, l'arithmétique complexe et certaines fonctions transcendantales) sont définies par l'implémentation.

Incompatibilités de forme

StableHLO est compatible avec les tenseurs de forme dynamique. Toutefois, les formes doivent être identiques au moment de l'exécution, sinon le comportement n'est pas défini. StableHLO ne fournit pas explicitement d'opération permettant d'affirmer qu'un Tensor a une forme donnée au moment de l'exécution. Il incombe au producteur de générer le code correct.

Par exemple, le programme ci-dessous est valide. Toutefois, au moment de l'exécution, les formes exactes de %arg0 et %arg1 devront être identiques, sinon le comportement du programme sera indéfini :

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notation

Pour décrire la syntaxe, ce document utilise la variante ISO modifiée de la syntaxe EBNF (ISO/IEC 14977:1996, Wikipédia), avec deux modifications : 1) les règles sont définies à l'aide de ::= plutôt que de =,

2) La concaténation est exprimée par juxtaposition plutôt que par ,.

Pour décrire la sémantique (c'est-à-dire dans les sections "Types", "Constantes" et "Ops"), nous utilisons des formules basées sur la syntaxe Python, qui sont étendues pour permettre d'exprimer de manière concise les opérations sur les tableaux, comme décrit ci-dessous. Cela fonctionne bien pour les petits extraits de code, mais dans de rares cas où des extraits de code plus volumineux sont nécessaires, nous utilisons la syntaxe Python de base, qui est toujours introduite de manière explicite.

Formules

Voyons comment fonctionnent les formules à l'aide d'un exemple tiré de la spécification dot_general. L'une des contraintes de cette opération se présente comme suit : dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Les noms utilisés dans cette formule proviennent de deux sources : 1) les fonctions globales, c'est-à-dire dim ; 2) les définitions de membre de l'élément de programme correspondant, c'est-à-dire les entrées lhs, lhs_batching_dimensions, rhs et rhs_batching_dimensions définies dans la section "Entrées" de dot_general.

Comme indiqué ci-dessus, la syntaxe de cette formule est basée sur Python, avec quelques extensions axées sur la concision. Pour comprendre la formule, transformons-la en syntaxe Python standard.

A) Dans ces formules, nous utilisons = pour représenter l'égalité. La première étape pour obtenir la syntaxe Python consiste donc à remplacer = par ==, comme suit : dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Ces formules sont également compatibles avec les points de suspension (...), qui transforment les expressions scalaires en expressions Tensor. En résumé, f(xs...) signifie à peu près "pour chaque scalaire x dans le Tensor xs, calculez un scalaire f(x), puis renvoyez tous ces résultats scalaires ensemble sous forme de Tensor". Dans la syntaxe Python de base, notre exemple de formule devient : [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Grâce aux points de suspension, il est souvent possible d'éviter de travailler au niveau des scalaires individuels. Toutefois, dans certains cas complexes, une syntaxe semi-informelle de niveau inférieur peut être utilisée, comme dans la formule start_indices[bi0, ..., :, ..., biN] de la spécification gather. Par souci de concision, nous ne fournissons pas de formalisme exact pour traduire une telle syntaxe en Python standard, dans l'espoir qu'elle reste intuitivement compréhensible au cas par cas. N'hésitez pas à nous signaler les formules spécifiques qui vous semblent obscures. Nous essaierons de les améliorer.

Vous remarquerez également que les formules utilisent des points de suspension pour développer toutes sortes de listes, y compris les Tensors, les listes de Tensors (qui peuvent, par exemple, provenir d'un nombre variable de Tensors), etc. Il s'agit d'un autre domaine dans lequel nous ne fournissons pas de formalisme exact (par exemple, les listes ne font même pas partie du système de type StableHLO) et nous nous appuyons plutôt sur une compréhension intuitive.

C) Le dernier véhicule de notation notable que nous utilisons est la diffusion implicite. Bien que l'opset StableHLO ne soit pas compatible avec le broadcasting implicite, les formules le sont, également dans un souci de concision. En résumé, si un scalaire est utilisé dans un contexte où un Tensor est attendu, le scalaire est diffusé à la forme attendue.

Pour poursuivre l'exemple dot_general, voici une autre contrainte : 0 <= lhs_batching_dimensions < rank(lhs). Comme défini dans la spécification dot_general, lhs_batching_dimensions est un Tensor, mais 0 et rank(lhs) sont des scalaires. Après l'application de la diffusion implicite, la formule devient [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Lorsqu'elle est appliquée à une opération dot_general spécifique, cette formule est évaluée en tant que Tensor de booléens. Lorsque des formules sont utilisées comme contraintes, la contrainte est respectée si la formule correspond à true ou à un Tensor qui ne comporte que des éléments true.

Noms

Dans les formules, la portée lexicale inclut : 1) les fonctions globales, 2) les définitions de membres,

3) les définitions locales. La liste des fonctions globales est fournie ci-dessous. La liste des définitions d'éléments dépend de l'élément de programme auquel la notation est appliquée :

  • Pour les opérations, les définitions des membres incluent les noms introduits dans les sections "Entrées" et "Sorties".
  • Pour tout le reste, les définitions des membres incluent les parties structurelles de l'élément de programme, nommées d'après les non-terminaux EBNF correspondants. La plupart du temps, les noms de ces parties structurelles sont obtenus en convertissant les noms des non-terminaux en snake case (par exemple, IntegerLiteral => integer_literal), mais parfois les noms sont abrégés au cours du processus (par exemple, QuantizationStorageType => storage_type). Dans ce cas, les noms sont introduits explicitement de la même manière que les sections "Entrées" / "Sorties" dans les spécifications des opérations.
  • De plus, les définitions de membre incluent toujours self pour faire référence à l'élément de programme correspondant.

Valeurs

Lorsqu'elles sont évaluées, les formules fonctionnent avec les types de valeurs suivants : 1) Value (valeurs réelles, par exemple dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> ; leurs types sont toujours connus), 2) Placeholder (valeurs futures, par exemple lhs, rhs ou result ; leurs valeurs réelles ne sont pas encore connues, seuls leurs types le sont), 3) Type (types tels que définis dans la section "Types"), 4) Function (fonctions globales telles que définies dans la section "Fonctions").

Selon le contexte, les noms peuvent faire référence à différentes valeurs. Plus précisément, la section "Sémantique" pour les opérations (et les équivalents pour les autres éléments du programme) définit la logique d'exécution. Tous les inputs sont donc disponibles en tant que Value. En revanche, la section "Constraints" (Contraintes) pour les opérations (et les équivalents) définit la logique de "compilation", c'est-à-dire quelque chose qui est généralement exécuté avant l'exécution. Par conséquent, seules les entrées constantes sont disponibles en tant que Value, et les autres entrées ne sont disponibles qu'en tant que Placeholder.

Noms Dans "Sémantique" Dans "Contraintes"
Fonctions globales Function Function
Entrées constantes Value Value
Entrées non constantes Value Placeholder
Sorties Value Placeholder
Définitions locales Selon la définition Selon la définition

Prenons l'exemple d'une opération transpose :

%result = "stablehlo.transpose"(%operand) {
  permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32

Pour cette opération, permutation est une constante. Elle est donc disponible en tant que Value dans la sémantique et les contraintes. En revanche, operand et result sont disponibles en tant que Value dans la sémantique, mais uniquement en tant que Placeholder dans les contraintes.

Fonctions

Construction des types

Il n'existe aucune fonction permettant de construire des types. Nous utilisons plutôt directement la syntaxe de type, car elle est généralement plus concise. Par exemple, (tensor<E>, tensor<E>) -> (tensor<E>) plutôt que function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Fonctions sur les types

  • element_type est défini sur les types de tenseurs et les types de tenseurs quantifiés, et renvoie respectivement la partie TensorElementType ou QuantizedTensorElementType du TensorType ou du QuantizedTensorType correspondant.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool vérifie si le type x peut être promu au type y. Lorsque x et y sont des QuantizedTensorElementType, la promotion ne s'applique qu'à storage_type. Cette version spécifique de la promotion est actuellement utilisée dans le contexte du calcul des réductions (pour en savoir plus, consultez la RFC).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Disponible pour tous les types. Par exemple, is_float(x) renvoie true si x est un FloatType. Si x est une valeur ou un espace réservé, cette fonction est un raccourci pour is_type_name(type(x)).

  • max_value(x: Type) -> Value renvoie la valeur maximale d'un TensorElementType. Si x n'est pas un TensorElementType, renvoie None.

  • min_value(x: Type) -> Value renvoie la valeur minimale possible d'un TensorElementType. Si x n'est pas un TensorElementType, renvoie None.

  • member_name(x: Value | Placeholder | Type) -> Any. Disponible pour toutes les définitions de membres member_name de tous les types. Par exemple, tensor_element_type(x) renvoie la partie TensorElementType d'un TensorType correspondant. Si x est une valeur ou un espace réservé, cette fonction est un raccourci pour member_name(type(x)). Si x n'est pas un type qui possède un membre approprié, ou une valeur ou un espace réservé de ce type, renvoie None.

  • is_empty_algorithm(*args: Type) vérifie si tous les champs de l'algorithme de points sont définis sur None. Cela est nécessaire, car les algorithmes de points ont des comportements par défaut définis par l'implémentation. Il serait donc incorrect de spécifier une valeur par défaut.

Construction des valeurs

  • operation_name(*xs: Value | Type) -> Value. Disponible pour toutes les opérations. Par exemple, add(lhs, rhs) prend deux valeurs de Tensor lhs et rhs et renvoie le résultat de l'évaluation de l'opération add avec ces entrées. Pour certaines opérations (par exemple, broadcast_in_dim), les types de leurs sorties sont "porteurs", c'est-à-dire nécessaires pour évaluer une opération. Dans ce cas, la fonction prend ces types comme arguments.

Fonctions sur les valeurs

  • Tous les opérateurs et fonctions Python sont disponibles. Par exemple, les notations subscription et slicing de Python sont disponibles pour indexer les Tensors, les Tensors quantifiés et les tuples.

  • to_destination_type(x: Value, destination_type: Type) -> Value est défini sur les Tensors et renvoie la valeur convertie de x en fonction de type(x) et destination_type comme suit :

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Une discussion préliminaire est en cours concernant la fusion des opérations convert, uniform_quantize et uniform_dequantize (#1576). Après la fusion, nous n'avons plus besoin de la fonction ci-dessus et nous pouvons utiliser le nom de l'opération pour convert.

  • is_nan(x: Value) -> Value est défini sur les Tensors et renvoie true si tous les éléments de x sont NaN ou false dans le cas contraire. Si x n'est pas un Tensor, renvoie None.

  • is_sorted(x: Value) -> Value est défini sur les Tensors et renvoie true si les éléments de x sont triés par ordre croissant par rapport à l'ordre lexicographique croissant de leurs indices, ou false dans le cas contraire. Si x n'est pas un Tensor, renvoie None.

  • is_unique(x: Value) -> Value est défini sur les Tensors et renvoie true si x ne comporte pas d'éléments en double, ou false dans le cas contraire. Si x n'est pas un Tensor, renvoie None.

  • member_name(x: Value) -> Any est défini pour toutes les définitions de membresmember_name de toutes les valeurs. Par exemple, real_part(x) renvoie la partie RealPart d'un ComplexConstant correspondant. Si x n'est pas une valeur qui possède un membre approprié, renvoie None.

  • same(x: Value) -> Value est défini sur les Tensors et renvoie true si les éléments de x sont tous égaux les uns aux autres, ou false dans le cas contraire. Si le Tensor ne comporte aucun élément, cela est considéré comme "tous égaux les uns aux autres ", c'est-à-dire que la fonction renvoie true. Si x n'est pas un Tensor, renvoie None.

  • split(x: Value, num_results: Value, axis: Value) -> Value est défini sur les Tensors et renvoie des tranches num_results de x le long de l'axe axis. Si x n'est pas un Tensor ou dim(x, axis) % num_results != 0, renvoie None.

  • is_defined_in_parent_scope(x: Value) -> Value est défini sur les chaînes et renvoie true si x est le nom d'une fonction définie dans le même champ d'application que la fonction parente de l'opération concernée.

  • is_namespaced_op_name(x: Value) -> Value est défini sur les chaînes et renvoie true si x est un nom d'opération valide, c'est-à-dire qu'il respecte l'expression régulière suivante : [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+.

Calculs de forme

  • axes(x: Value | Placeholder | Type) -> Value est un raccourci pour range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value est un raccourci pour shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List est un raccourci pour list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value est défini sur les Tensors et renvoie les indices size(x) pour le TensorType correspondant trié par ordre lexicographique croissant, c'est-à-dire [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Si x n'est pas un type Tensor, un type Tensor quantifié, une valeur ou un espace réservé de l'un de ces types, renvoie None.

  • rank(x: Value | Placeholder | Type) -> Value est un raccourci pour size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value est défini dans la section "Fonctions sur les types" via member_name.

  • size(x: Value | Placeholder | Type) -> Value est un raccourci pour reduce(lambda x, y: x * y, shape(x)).

Calculs de quantification

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type est un raccourci pour element_type(baseline_type(x)).

  • baseline_type est défini sur les types de Tensor et les types de Tensor quantifiés, et les transforme en "baseline", c'est-à-dire un type avec la même forme, mais avec les paramètres de quantification du type d'élément réinitialisés sur les valeurs par défaut. Il s'agit d'une astuce pratique pour comparer uniformément les types de Tensor et de Tensor quantifiés, ce qui est souvent nécessaire. Pour les types quantifiés, cela permet de comparer les types en ignorant les paramètres de quantification, c'est-à-dire que shape, storage_type, expressed_type, storage_min, storage_max et quantization_dimension (pour le type quantifié par axe) doivent tous correspondre, mais scales et zero points peuvent être différents.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize est défini sur les types de tenseurs quantifiés et les transforme en types de tenseurs à virgule flottante. Cela se produit en convertissant les éléments quantifiés qui représentent les valeurs entières du type de stockage en valeurs à virgule flottante correspondantes du type exprimé à l'aide du point zéro et de l'échelle associés au type d'élément quantifié.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize est défini sur les types de tenseurs à virgule flottante et les transforme en types de tenseurs quantifiés. Pour ce faire, les valeurs à virgule flottante du type exprimé sont converties en valeurs entières correspondantes du type de stockage à l'aide du point zéro et de l'échelle associés au type d'élément quantifié.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize permet de spécifier des calculs élément par élément sur les Tensors quantifiés. Il déquantifie, c'est-à-dire qu'il transforme les éléments quantifiés en leurs types exprimés, puis effectue une opération, puis quantifie, c'est-à-dire qu'il transforme les résultats en leurs types de stockage. Pour le moment, cette fonction ne fonctionne que pour la quantification par Tensor. La quantification par axe est en cours de développement (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op est utilisé pour spécifier la quantification du poids uniquement pour l'opération hybride qui accepte lhs en virgule flottante et rhs dans les types quantifiés. Il déquantifie les entrées quantifiées dans leurs types exprimés et effectue des calculs en float. Le type d'élément du Tensor lhs float et le type exprimé du Tensor rhs quantifié doivent être identiques.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Calculs de grille

  • cross_partition(replica_groups: Value) -> Value. Consultez la section "cross_replica" ci-dessus.

  • cross_replica(replica_groups: Value) -> Value. Consultez la section "cross_replica" ci-dessus.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Consultez la section "cross_replica_and_partition" ci-dessus.

  • flattened_ids(replica_groups: Value) -> Value. Consultez la section "flattened_ids" ci-dessus.

Dynamisme

Les valeurs StableHLO peuvent avoir des tailles de dimension dynamiques, par exemple tensor<?xi64>. Toutefois, les valeurs StableHLO ne peuvent pas avoir un nombre dynamique de dimensions (dynamisme non classé, par exemple tensor<*xi64>). Les opérandes et les résultats sont autorisés à utiliser des tailles de dimension dynamiques, même s'il existe des contraintes sur les tailles. Les contraintes seront validées de manière statique si possible. Sinon, elles seront différées jusqu'à l'exécution et les incohérences entraîneront un comportement indéfini. Vous trouverez des exemples ci-dessous.

Incompatibilités de forme pour les opérations unaires élément par élément

Prenons l'exemple de programme suivant :

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Un tel programme est inhabituel, car il est rare de connaître la forme du résultat sans connaître celle de l'entrée. Néanmoins, il s'agit d'un programme StableHLO valide. Il n'est pas possible de valider statiquement l'opération abs dans ce programme, car la forme exacte de l'opérande est inconnue. Toutefois, les formes sont certainement compatibles, et cela peut être vérifié de manière statique : ? pourrait s'avérer être 2 au moment de l'exécution, et il n'y aurait aucun problème. Toutefois, ? peut également s'avérer être un autre entier, auquel cas le comportement n'est pas défini.

Notez que si la taille d'une dimension est dynamique dans le résultat, il ne peut pas y avoir de comportement indéfini. En effet, il n'y a pas de taille "attendue", il ne peut donc pas y avoir de décalage.

Incompatibilité de forme pour les opérations binaires élément par élément

Prenons l'exemple de programme suivant :

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

En ce qui concerne les opérations binaires élément par élément, les formes des entrées et du résultat doivent correspondre au moment de l'exécution. Au moment de la compilation, les dimensions statiques doivent être égales. Sinon, elles doivent simplement être compatibles. Si une dimension est dynamique dans les entrées, un comportement indéfini peut se produire au moment de l'exécution, car la taille dynamique peut ne pas correspondre à la taille correspondante dans l'autre opérande (qu'elle soit statique ou dynamique). Si toutes les entrées sont statiques, la nature dynamique ou non du résultat n'a pas d'importance : les dimensions connues de manière statique seront vérifiées de manière statique, et les dimensions dynamiques n'imposent aucune contrainte.

Incompatibilités de forme pour les opérations qui prennent leur forme de sortie comme opérande

Prenons l'exemple de programme suivant :

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

Les valeurs de l'opérande de forme lors de l'exécution doivent correspondre à la forme du résultat, sinon le comportement n'est pas défini. Autrement dit, au moment de l'exécution, %arg0 doit avoir une valeur de dense<[3, 4]> : tensor<2xi32>. Si l'opérande de forme est constant, cela peut être vérifié de manière statique. Si la forme du résultat est entièrement dynamique, il ne peut pas y avoir d'incohérence.