Spécification StableHLO

StableHLO est un ensemble d'opérations destiné aux opérations de haut niveau (HLO) dans les modèles de machine learning (ML). StableHLO fonctionne comme une couche de portabilité entre différents frameworks de ML et compilateurs de ML: les frameworks de ML qui génèrent des programmes StableHLO sont compatibles avec les compilateurs de ML qui utilisent des programmes StableHLO.

Notre objectif est de simplifier et d'accélérer le développement du ML en favorisant l'interopérabilité entre différents frameworks de ML (tels que TensorFlow, JAX et PyTorch) et des compilateurs de ML (tels que XLA et IREE). À cette fin, ce document fournit une spécification pour le langage de programmation StableHLO.

Cette spécification contient trois sections principales. Tout d'abord, la section Programmes décrit la structure des programmes StableHLO, qui sont constitués de fonctions StableHLO, qui sont elles-mêmes constituées d'opérations StableHLO. Au sein de cette structure, la section Ops spécifie la sémantique de chaque opération. La section Exécution fournit la sémantique de toutes ces opérations exécutées ensemble au sein d'un programme. Enfin, la section Notation décrit la notation utilisée tout au long de la spécification.

Programmes

Program ::= {Func}

Les programmes StableHLO sont constitués d'un nombre arbitraire de fonctions StableHLO. Vous trouverez ci-dessous un exemple de programme avec une fonction @main qui comporte trois entrées (%image, %weights et %bias) et une sortie. Le corps de la fonction comporte six opérations.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Fonctions

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Les fonctions StableHLO (également appelées fonctions nommées) ont un identifiant, des entrées/sorties et un corps. Nous prévoyons d'ajouter des métadonnées supplémentaires pour les fonctions afin d'améliorer la compatibilité avec HLO (#425, #626, #740, 744).

Identifiants

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Les identifiants StableHLO sont semblables aux identifiants dans de nombreux langages de programmation, avec deux particularités: 1) tous les identifiants ont des sigles qui distinguent différents types d'identifiants, 2) les identifiants de valeur peuvent être entièrement numériques pour simplifier la génération de programmes StableHLO.

Types

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Les types StableHLO sont classés en types de valeurs (également appelés types de première classe), qui représentent des valeurs StableHLO et des types sans valeur, qui décrivent d'autres éléments du programme. Les types StableHLO sont semblables aux types dans de nombreux langages de programmation, la principale particularité étant la nature spécifique au domaine de StableHLO, qui entraîne des résultats inhabituels (par exemple, les types scalaires ne sont pas des types de valeur).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}

Les types de Tensors représentent des Tensors, c'est-à-dire des tableaux multidimensionnels. Elles ont une forme et un type d'élément, où elles représentent des tailles de dimensions non négatives dans l'ordre croissant des dimensions correspondantes (également appelées axes) numérotées de 0 à R-1. Le nombre de dimensions R est appelé rang. Par exemple, tensor<2x3xf32> est un type de Tensor avec la forme 2x3 et le type d'élément f32. Elle comporte deux dimensions (ou, en d'autres termes, deux axes) : la 0e dimension et la 1re dimension, dont les tailles sont 2 et 3. Son classement est de 2.

Cela permet de prendre en charge les formes statiques dans lesquelles les tailles de dimension sont connues de manière statique. À l'avenir, nous prévoyons d'ajouter également la compatibilité avec les formes dynamiques pour lesquelles les tailles de dimension sont partiellement ou entièrement inconnues (#8). De plus, nous prévoyons d'explorer l'extension des types de Tensor au-delà des tailles de dimension et des types d'éléments, par exemple pour inclure les mises en page (#629) et la parcimonie (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nom Type Contraintes
storage_type Type entier (C1-C4), (C9)
storage_min constante entière (C2), (C4), (C8)
storage_max constante entière (C3), (C4), (C8)
expressed_type type à virgule flottante (C1), (C5)
quantization_dimension constante entière facultative (C11-C13)
scales nombre varié de constantes à virgule flottante (C5-C7), (C10), (C11), (C13)
zero_points nombre varié de constantes entières (C8-C10)

Les types d'éléments quantifiés représentent des valeurs entières d'un type de stockage comprises entre storage_min et storage_max (inclus) qui correspondent à des valeurs à virgule flottante d'un type exprimé. Pour une valeur entière i donnée, la valeur à virgule flottante correspondante f peut être calculée comme f = (i - zero_point) * scale, où scale et zero_point sont appelés paramètres de quantification. Les éléments storage_min et storage_max sont facultatifs dans la grammaire, mais leurs valeurs par défaut sont respectivement min_value(storage_type) et max_value(storage_type). Les types d'éléments quantifiés présentent les contraintes suivantes:

  • (C1) num_bits(storage_type) < num_bits(expressed_type).
  • (C2) type(storage_min) = storage_type.
  • (C3) type(storage_max) = storage_type.
  • (C4) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C5) type(scales...) = expressed_type.
  • (C6) 0 < scales.
  • (C7) is_finite(scales...).
  • (C8) storage_min <= zero_points <= storage_max.
  • (C9) type(zero_points...) = storage_type.
  • (C10) size(scales) = size(zero_points).
  • (C11) Si la valeur est is_empty(quantization_dimension), alors size(scales) = 1.
  • (C12) 0 <= quantization_dimension.

Pour le moment, QuantizationScale est une constante à virgule flottante, mais il existe un intérêt important pour les échelles basées sur des entiers, représentées par des multiplicateurs et des décalages. Nous prévoyons d'explorer cette option prochainement (#1404).

Une discussion en cours sur la sémantique de QuantizationZeroPoint, y compris le type et les valeurs, et s'il peut y avoir un seul ou potentiellement plusieurs points zéros dans un type de Tensor quantifié Sur la base des résultats de cette discussion, la spécification autour de zéro point est susceptible d'être modifiée à l'avenir (#1405).

Une autre discussion en cours concerne la sémantique de QuantizationStorageMin et QuantizationStorageMax pour déterminer si des contraintes doivent être imposées à ces valeurs et aux valeurs des Tensors quantifiés (#1406).

Enfin, nous prévoyons d'explorer la représentation d'échelles inconnues et de zéros points, de la même manière que nous prévoyons d'explorer la représentation de tailles de dimension inconnues (#1407).

Les types de Tensor quantifiés représentent des Tensors avec des éléments quantifiés. Ces Tensors sont exactement identiques aux Tensors standards, si ce n'est que leurs éléments ont des types d'éléments quantifiés, et non des types d'éléments standards.

Dans les Tensors quantifiés, la quantification peut être par Tensor, c'est-à-dire avoir un scale et un zero_point pour l'ensemble du Tensor, ou par axe, c'est-à-dire avoir plusieurs scales et zero_points, une paire par tranche d'un quantization_dimension particulier de dimension. Plus formellement, dans un Tensor t avec quantification par axe, il existe des tranches dim(t, quantization_dimension) de quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], etc. Tous les éléments de la ie tranche utilisent scales[i] et zero_points[i] comme paramètres de quantification. Les types de Tensors quantifiés présentent les contraintes suivantes:

  • Pour la quantification par Tensor :
    • Aucune contrainte supplémentaire.
  • Pour la quantification par axe :
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Les types de jetons représentent des jetons, c'est-à-dire des valeurs opaques produites et consommées par certaines opérations. Les jetons sont utilisés pour imposer l'ordre d'exécution des opérations, comme décrit dans la section Exécution.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Les types de uple représentent des tuples, c'est-à-dire des listes hétérogènes. Tuples est une ancienne fonctionnalité qui n'existe que pour assurer la compatibilité avec HLO. Dans HLO, les tuples sont utilisés pour représenter les entrées et les sorties variables. Dans StableHLO, les entrées et les sorties variables sont prises en charge de manière native. La seule utilisation de tuples dans StableHLO consiste à représenter de manière exhaustive l'ABI HLO, où T, tuple<T> et tuple<tuple<T>> peuvent, par exemple, être sensiblement différents en fonction d'une implémentation particulière. À l'avenir, nous prévoyons de modifier l'ABI HLO pour pouvoir supprimer les types de tuples de StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Les types d'éléments représentent des éléments de types de Tensor. Contrairement à de nombreux langages de programmation, ces types ne constituent pas la première classe de StableHLO. Cela signifie que les programmes StableHLO ne peuvent pas représenter directement des valeurs de ces types (par conséquent, il est idiomatique de représenter des valeurs scalaires de type T avec des valeurs de Tensor à 0 dimension de type tensor<T>).

  • Le type booléen représente les valeurs booléennes true et false.
  • Les types d'entiers peuvent être signés (si) ou non signés (ui) et avoir l'une des largeurs de bits compatibles (4, 8, 16, 32 ou 64). Les types siN signés représentent des valeurs entières comprises entre -2^(N-1) et 2^(N-1)-1, et les types uiN non signés représentent des valeurs entières comprises entre 0 et 2^N-1 inclus.
  • Les types à virgule flottante peuvent être :
  • Les types complexes représentent des valeurs complexes qui ont une partie réelle et une partie imaginaire du même type d'élément. Les types complexes compatibles sont complex<f32> (les deux parties sont de type f32) et complex<f64> (les deux parties sont de type f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Les types de fonctions représentent à la fois des fonctions nommées et anonymes. Ils ont des types d'entrée (la liste des types se trouvant à gauche dans ->) et des types de sortie (liste des types à droite de ->). Dans de nombreux langages de programmation, les types de fonctions sont la première classe, mais pas dans StableHLO.

StringType ::= 'string'

Le type de chaîne représente des séquences d'octets. Contrairement à de nombreux langages de programmation, le type de chaîne n'est pas la première classe de StableHLO. Il n'est utilisé que pour spécifier des métadonnées statiques pour les éléments de programme.

Opérations

Les opérations StableHLO (également appelées opérations) représentent un ensemble fermé d'opérations de haut niveau dans les modèles de machine learning. Comme indiqué ci-dessus, la syntaxe StableHLO s'inspire fortement de MLIR, qui n'est pas nécessairement l'alternative la plus ergonomique, mais elle est sans doute la mieux adaptée à l'objectif de StableHLO, qui est de créer une plus grande interopérabilité entre les frameworks de ML et les compilateurs de ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Les opérations StableHLO (également appelées opérations) ont un nom, des entrées/sorties et une signature. Le nom se compose du préfixe stablehlo. et d'un mnémonique qui identifie de manière unique l'une des opérations compatibles. Vous trouverez ci-dessous la liste complète de toutes les opérations compatibles.

À l'heure actuelle, les programmes StableHLO à l'état actuel contiennent parfois des opérations qui ne sont pas décrites dans ce document. À l'avenir, nous prévoyons d'absorber ces opérations dans l'opération StableHLO ou de les interdire d'apparaître dans les programmes StableHLO. En attendant, voici la liste de ces opérations:

  • builtin.module, func.func, func.call et func.return (#425).
  • Opérations chlo (#602).
  • Catégorie "Pas dans l'ordre HLO" des opérations StableHLO : elles faisaient initialement partie de l'opération StableHLO, mais ont par la suite été considérées comme ne convenant pas bien : broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (n° 3).
  • Catégorie "Dynamism" des opérations StableHLO : elles ont été lancées à partir du MHLO, mais nous ne les avons pas encore spécifiées : compute_reshape_shape, cstr_reshapable, dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, real_dynamic_slice, set_dimension_size (n° 8).
  • Calculs de forme, y compris les opérations arith, shape et tensor (#8).
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Les opérations consomment des entrées et produisent des sorties. Les entrées sont classées en valeurs d'entrée (calculées lors de l'exécution), fonctions d'entrée (fournies de manière statique, car les fonctions StableHLO ne sont pas des valeurs de première classe) et attributs d'entrée (également fournis de manière statique). Le type d'entrées et de sorties consommées et produites par une opération dépend de son mnémonique. Par exemple, l'opération add utilise deux valeurs d'entrée et génère une valeur de sortie. En comparaison, l'opération select_and_scatter utilise 3 valeurs d'entrée, 2 fonctions d'entrée et 3 attributs d'entrée.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Les fonctions d'entrée (également appelées fonctions anonymes) sont très semblables aux fonctions nommées, sauf que: 1) elles n'ont pas d'identifiant (d'où leur nom "anonyme"), 2) elles ne déclarent pas de types de sortie (les types de sortie sont déduits de l'opération return dans la fonction).

La syntaxe des fonctions d'entrée inclut une partie actuellement inutilisée (voir la production de Unused ci-dessus), qui assure la compatibilité avec MLIR. Dans MLIR, il existe un concept plus général de "régions" qui peut comporter plusieurs "blocs" d'opérations connectés entre eux via des opérations de saut. Ces blocs ont des identifiants qui correspondent à la production Unused, de sorte qu'ils peuvent être distingués les uns des autres. StableHLO n'a pas d'opérations de saut. Par conséquent, la partie correspondante de la syntaxe MLIR n'est pas utilisée (mais est toujours présente).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Les attributs d'entrée ont un nom et une valeur qui est l'une des constantes acceptées. Il s'agit du moyen principal de spécifier des métadonnées statiques pour les éléments d'un programme. Par exemple, l'opération concatenate utilise l'attribut dimension pour spécifier la dimension avec laquelle ses valeurs d'entrée sont concaténées. De même, l'opération slice utilise plusieurs attributs tels que start_indices et limit_indices pour spécifier les limites utilisées pour diviser la valeur d'entrée.

À l'heure actuelle, les programmes StableHLO dans la nature contiennent parfois des attributs qui ne sont pas décrits dans ce document. À l'avenir, nous prévoyons d'intégrer ces attributs dans l'opération StableHLO ou d'en interdire l'affichage dans les programmes StableHLO. En attendant, voici la liste de ces attributs:

  • layout (#629)
  • mhlo.frontend_attributes (#628)
  • mhlo.sharding (#619)
  • output_operand_aliases (#740)
  • Métadonnées de localisation (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

La signature Op comprend les types de toutes les valeurs d'entrée (la liste des types à gauche de ->) et les types de toutes les valeurs de sortie (la liste des types à droite de ->). À proprement parler, les types d'entrée sont redondants et les types de sortie sont presque toujours redondants également (car, pour la plupart des opérations StableHLO, les types de sortie peuvent être déduits à partir des entrées). Néanmoins, la signature d'opération est délibérément intégrée à la syntaxe StableHLO pour assurer la compatibilité avec MLIR.

Vous trouverez ci-dessous un exemple d'opération dont le mnémonique est select_and_scatter. Elle utilise 3 valeurs d'entrée (%operand, %source et %init_value), 2 fonctions d'entrée et 3 attributs d'entrée (window_dimensions, window_strides et padding). Notez que la signature de l'opération n'inclut que les types de ses valeurs d'entrée (mais pas les types de fonctions d'entrée et d'attributs fournis de façon intégrée).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Constantes

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Les constantes StableHLO ont un littéral et un type, qui représentent ensemble une valeur StableHLO. En règle générale, le type fait partie de la syntaxe constante, sauf lorsqu'il est sans ambiguïté (par exemple, une constante booléenne est sans ambiguïté de type i1, alors qu'une constante entière peut avoir plusieurs types possibles).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Les constantes booléennes représentent les valeurs booléennes true et false. Les constantes booléennes sont de type i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Les constantes entières représentent des valeurs entières via des chaînes utilisant la notation décimale ou hexadécimale. Les autres bases (binaires ou octales, par exemple) ne sont pas compatibles. Les constantes entières présentent les contraintes suivantes:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Les constantes à virgule flottante représentent les valeurs à virgule flottante via des chaînes utilisant la notation décimale ou scientifique. En outre, la notation hexadécimale peut être utilisée pour spécifier directement les bits sous-jacents au format à virgule flottante du type correspondant. Les constantes à virgule flottante sont soumises aux contraintes suivantes:

  • (C1) Si vous utilisez une notation non hexadécimale, is_wellformed(float_literal, float_type).
  • (C2) Si vous utilisez la notation hexadécimale, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Les constantes complexes représentent des valeurs complexes à l'aide de listes composées d'une partie réelle (située en premier) et d'une partie imaginaire (située en deuxième). Par exemple, (1.0, 0.0) : complex<f32> représente 1.0 + 0.0i, et (0.0, 1.0) : complex<f32> représente 0.0 + 1.0i. L'ordre dans lequel ces parties sont ensuite stockées en mémoire est défini par l'implémentation. Les constantes complexes présentent les contraintes suivantes:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Les constantes Tensor représentent les valeurs de Tensor à l'aide de listes imbriquées spécifiées via la notation NumPy. Par exemple, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> représente une valeur de Tensor avec les mappages suivants des index aux éléments : {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. L'ordre dans lequel ces éléments sont ensuite stockés en mémoire est défini par l'implémentation. Les constantes de Tensor présentent les contraintes suivantes:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), où :
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), où :
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • Sinon, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Les constantes de Tensor quantifiées représentent les valeurs de Tensor quantifiées en utilisant la même notation que les constantes de Tensor, avec des éléments spécifiés en tant que constantes de leur type de stockage. Les constantes de Tensor quantifiées présentent les contraintes suivantes:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Les littéraux de chaîne sont constitués d'octets spécifiés à l'aide de caractères ASCII et de séquences d'échappement. Ils sont indépendants de l'encodage. L'interprétation de ces octets est donc définie par l'implémentation. Les littéraux de chaîne sont de type string.

Opérations

abs

Sémantique

Effectue une opération Absolu par élément sur le Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les entiers signés: module d'entiers.
  • Pour les flottants: abs à partir de la norme IEEE-754.
  • Pour les nombres complexes: module complexe.
  • Pour les types quantifiés: dequantize_op_quantize(abs, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor d'un entier signé, d'un type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1-C2)

Sorties

Nom Type Contraintes
result Tensor de type entier signé, à virgule flottante ou d'un Tensor quantifié par Tensor (C1-C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • baseline_element_type(operand) dans les autres cas.

Exemples

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

Autres exemples

add

Sémantique

Effectue une addition par élément de deux Tensors lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: opérateur logique OR.
  • Pour les nombres entiers: addition de nombres entiers.
  • Pour les flottants: addition à partir de la norme IEEE-754.
  • Pour les nombres complexes: addition complexe.
  • Pour les types quantifiés: dequantize_op_quantize(add, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Autres exemples

after_all

Sémantique

Permet de s'assurer que les opérations qui génèrent le inputs sont exécutées avant toute opération qui dépend de result. L'exécution de cette opération n'a aucun effet. Elle sert uniquement à établir des dépendances de données entre result et inputs.

Entrées

Libellé Nom Type
(I1) inputs nombre variable de token

Sorties

Nom Type
result token

Exemples

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

Autres exemples

all_gather

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, concatène les valeurs du Tensor operand de chaque processus avec all_gather_dim et génère un Tensor result.

L'opération divise la grille de processus StableHLO en process_groups, défini comme suit:

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group:

  • operands@receiver = [operand@sender for sender in process_group] pour tous les receiver de process_group.
  • result@process = concatenate(operands@process, all_gather_dim) pour tous les process de process_group.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C6)
(I2) all_gather_dim constante de type si64 (C1), (C6)
(I3). replica_groups Constante de Tensor bidimensionnelle de type si64 (C2-C4)
(I4). channel_id constante de type si64 (C5)
(I5). use_global_device_ids constante de type i1 (C5)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C6)

Contraintes

  • (C1) 0 <= all_gather_dim < rank(operand).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Si la valeur est use_global_device_ids = true, alors channel_id > 0.
  • (C6) type(result) = type(operand), sauf :
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

Autres exemples

all_reduce

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, applique une fonction de réduction computation aux valeurs du Tensor operand de chaque processus et génère un Tensor result.

L'opération divise la grille de processus StableHLO en process_groups, défini comme suit:

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group:

  • result@process[result_index] = exec(schedule) pour une arborescence binaire schedule où :
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule est une arborescence binaire définie par l'implémentation dont le balayage dans l'ordre est to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0])).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C5), (C6)
(I2) replica_groups nombre varié de constantes de Tensor unidimensionnelles de type si64 (C1-C3)
(I3). channel_id constante de type si64 (C4)
(I4). use_global_device_ids constante de type i1 (C4)
(I5). computation function (C5)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C6-C7)

Contraintes

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Si la valeur est use_global_device_ids = true, alors channel_id > 0.
  • (C5) computation est de type (tensor<E>, tensor<E>) -> (tensor<E>), où is_promotable(element_type(operand), E).
  • (C6) shape(result) = shape(operand).
  • (C7) element_type(result) = E.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

Autres exemples

all_to_all

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, il divise les valeurs du Tensor operand le long de split_dimension en plusieurs parties, disperse les parties divisées entre les processus, concatène les parties dispersées le long de concat_dimension et génère un Tensor result.

L'opération divise la grille de processus StableHLO en process_groups, défini comme suit:

  • cross_replica(replica_groups) si channel_id <= 0.
  • cross_partition(replica_groups) si channel_id > 0.

Ensuite, dans chaque process_group:

  • split_parts@sender = split(operand@sender, split_count, split_dimension) pour tous les sender dans process_group.
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group]receiver_index = process_group.index(receiver).
  • result@process = concatenate(scattered_parts@process, concat_dimension).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1-C3), (C9)
(I2) split_dimension constante de type si64 (C1), (C2), (C9)
(I3). concat_dimension constante de type si64 (C3), (C9)
(I4). split_count constante de type si64 (C2), (C4), (C8), (C9)
(I5). replica_groups Constante de Tensor bidimensionnelle de type si64 (C5-C8)
(I6). channel_id constante de type si64

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C9)

Contraintes

  • (C1) 0 <= split_dimension < rank(operand).
  • (C2) dim(operand, split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operand).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(result) = type(operand), sauf :
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

Autres exemples

et

Sémantique

Effectue l'opérateur ET par élément des deux Tensors lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: AND.
  • Pour les entiers: AND au niveau du bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type booléen ou entier (C1)
(I2) rhs Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

atan2

Sémantique

Effectue une opération atan2 par élément sur le Tensor lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: atan2 à partir de la norme IEEE-754.
  • Pour les nombres complexes: atan2 complexe.
  • Pour les types quantifiés: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Autres exemples

batch_norm_grad

Sémantique

Calcule les gradients de plusieurs entrées de rétropropagation batch_norm_training à partir de grad_output, et génère les Tensors grad_operand, grad_scale et grad_offset. Plus formellement, cette opération peut être exprimée sous forme de décomposition d'opérations StableHLO existantes à l'aide de la syntaxe Python comme suit:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Pour les types quantifiés, exécute dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1-C3), (C5)
(I2) scale Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C2), (C4), (C5)
(I3). mean Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C2), (C4)
(I4). variance Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C2), (C4)
(I5). grad_output Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C2), (C3)
(I6). epsilon constante de type f32
(I7) feature_index constante de type si64 (C1), (C5)

Sorties

Nom Type Contraintes
grad_operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C2), (C3)
grad_scale Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C2), (C4)
grad_offset Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C2), (C4)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale et grad_offset ont le même baseline_element_type.
  • (C3) operand, grad_output et grad_operand ont la même forme.
  • (C4) scale, mean, variance, grad_scale et grad_offset ont la même forme.
  • (C5) size(scale) = dim(operand, feature_index).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Sémantique

Normalise le Tensor operand pour toutes les dimensions, à l'exception de la dimension feature_index, et génère un Tensor result. Plus formellement, cette opération peut être exprimée sous la forme d'une décomposition d'opérations StableHLO existantes à l'aide de la syntaxe Python comme suit:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Pour les types quantifiés, exécute dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1-C7)
(I2) scale Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C2), (C3)
(I3). offset Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C2), (C4)
(I4). mean Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C5)
(I5). variance Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor (C2), (C6)
(I6). epsilon constante de type f32
(I7) feature_index constante de type si64 (C1), (C3-C6)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C2), (C7)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance et result ont le même baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Sémantique

Calcule la moyenne et la variance pour toutes les dimensions, à l'exception de la dimension feature_index, et normalise le Tensor operand qui produit les Tensors output, batch_mean et batch_var. Plus formellement, cette opération peut être exprimée sous la forme d'une décomposition d'opérations StableHLO existantes à l'aide de la syntaxe Python comme suit:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Pour les types quantifiés, exécute dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)
(I2) scale Tensor unidimensionnel à virgule flottante ou quantifié par Tensor (C2), (C3)
(I3). offset Tensor unidimensionnel à virgule flottante ou quantifié par Tensor (C2), (C4)
(I4). epsilon constante de type f32 (C1), (C3-C6)
(I5). feature_index constante de type si64 (C1), (C3-C6)

Sorties

Nom Type Contraintes
output Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C7)
batch_mean Tensor unidimensionnel à virgule flottante ou quantifié par Tensor (C2), (C5)
batch_var Tensor unidimensionnel à virgule flottante ou quantifié par Tensor (C2), (C6)

Contraintes

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var et output ont le même baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Exemples

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Sémantique

Effectue une opération de bitcast sur le Tensor operand et génère un Tensor result, dans lequel les bits de l'ensemble du Tensor operand sont réinterprétés à l'aide du type du Tensor result.

Plus formellement, avec E = element_type(operand), E' = element_type(result) et R = rank(operand):

  • Si la valeur est num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Si la valeur est num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Si la valeur est num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits renvoie la représentation en mémoire d'une valeur donnée, et son comportement est défini par l'implémentation, car la représentation exacte des Tensors est définie par l'implémentation et la représentation exacte des types d'éléments est également définie par l'implémentation.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C2)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1-C2)

Contraintes

  • (C1) Avec E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) et R = rank(operand) :
    • Si la valeur est num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Si la valeur est num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) pour tous les 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Si la valeur est num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) pour tous les 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Si la valeur est is_complex(operand) or is_complex(result), alors is_complex(operand) and is_complex(result).

Exemples

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Autres exemples

broadcast_in_dim

Sémantique

Développe les dimensions et/ou le rang d'un Tensor d'entrée en dupliquant les données dans le Tensor operand et génère un Tensor result. Plus formellement, result[result_index] = operand[operand_index] où pour tous les d dans axes(operand):

  • operand_index[d] = 0 si dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C2), (C5-C6)
(I2) broadcast_dimensions Constante de Tensor unidimensionnelle de type si64 (C2-C6)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1), (C3), (C5-C6)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), si ce n'est que quantization_dimension(operand), scales(operand) et zero_points(operand) peuvent être différents de quantization_dimension(result), scales(result) et zero_points(result), sinon.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Pour tous les d de axes(operand) :
    • dim(operand, d) = 1 ou
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Si is_per_axis_quantized(result) :
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Si la valeur est dim(operand, quantization_dimension(operand)) = 1, alors scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Exemples

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Autres exemples

demande

Sémantique

Génère le résultat de l'exécution d'une seule fonction à partir de branches en fonction de la valeur de index. Plus formellement, result = selected_branch() où:

  • selected_branch = branches[index] si 0 <= index < size(branches).
  • selected_branch = branches[-1] dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) index Tensor à 0 dimension de type si32
(I2) branches nombre variable de fonctions (C1-C4)

Sorties

Nom Type Contraintes
results nombre varié de Tensors, de Tensors quantifiés ou de jetons (C4)

Contraintes

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Exemples

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Autres exemples

RTL

Sémantique

Effectue une opération de racine cubique par élément sur le Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: rootn(x, 3) à partir de la norme IEEE-754.
  • Pour les nombres complexes: racine cubique complexe.
  • Pour les types quantifiés: dequantize_op_quantize(cbrt, operand, type(result))

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

Autres exemples

ceil

Sémantique

Effectue le ceil par élément du Tensor operand et génère un Tensor result. Met en œuvre l'opération roundToIntegralTowardPositive à partir de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(ceil, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Autres exemples

Cholesky

Sémantique

Calcule la décomposition de cholésique d'un lot de matrices.

Plus formellement, pour tous les i de index_space(result), result[i0, ..., iR-3, :, :] est une décomposition cholématique de a[i0, ..., iR-3, :, :], sous la forme d'une matrice triangulaire inférieure (si lower correspond à true) ou d'une matrice triangulaire supérieure (si lower est false). Les valeurs de sortie dans le triangle opposé, c'est-à-dire le triangle supérieur strict ou le triangle inférieur strict, sont définies par l'implémentation.

S'il existe une valeur i dans laquelle la matrice d'entrée n'est pas une matrice positive positive de l'hermite, le comportement n'est pas défini.

Pour les types quantifiés, exécute dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1-C3)
(I2) lower Constante de Tensor à 0 dimension de type i1

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Exemples

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

limiter

Sémantique

Applique chaque élément du Tensor operand entre une valeur minimale et une valeur maximale, et génère un Tensor result. Plus formellement, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), où min_element = rank(min) = 0 ? min[] : min[result_index] et max_element = rank(max) = 0 ? max[] : max[result_index]. Pour les types quantifiés, effectue dequantize_op_quantize(clamp, min, operand, max, type(result)).

L'application d'un ordre à des nombres complexes implique une sémantique surprenante. Par conséquent, nous prévoyons de supprimer la prise en charge des nombres complexes pour cette opération (#560).

Entrées

Libellé Nom Type Contraintes
(I1) min Tensor ou Tensor quantifié par Tensor (C1), (C3)
(I2) operand Tensor ou Tensor quantifié par Tensor (C1-C4)
(I3). max Tensor ou Tensor quantifié par Tensor (C2), (C3)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C4)

Contraintes

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Exemples

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Autres exemples

collective_broadcast

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, envoyez la valeur du Tensor operand du processus source aux processus cibles et générez un Tensor result.

L'opération divise la grille de processus StableHLO en process_groups, défini comme suit:

  • cross_replica(replica_groups) si channel_id <= 0.
  • cross_partition(replica_groups) si channel_id > 0.

Ensuite, result@process est donné par:

  • operand@process_groups[i, 0] s'il existe un élément i tel que le processus se trouve dans process_groups[i].
  • broadcast_in_dim(constant(0, element_type(result)), [], type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor (C3)
(I2) replica_groups nombre varié de constantes de Tensor unidimensionnelles de type si64 (C1), (C2)
(I3). channel_id constante de type si64

Sorties

Nom Type Contraintes
result Tensor (C3)

Contraintes

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, où N est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C3) type(result) = type(operand).

Exemples

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, envoie la valeur du Tensor operand du processus source au processus cible et génère un Tensor result.

L'opération divise la grille de processus StableHLO en process_groups, défini comme suit:

  • cross_replica(source_target_pairs) si channel_id <= 0.
  • cross_partition(source_target_pairs) si channel_id > 0.

Ensuite, result@process est donné par:

  • operand@process_groups[i, 0], s'il existe un i tel que process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C5)
(I2) source_target_pairs Constante de Tensor bidimensionnelle de type si64 (C1-C4)
(I3). channel_id constante de type si64

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, où N est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_partitions si cross_partition est utilisé.
  • (C5) type(result) = type(operand).

Exemples

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Autres exemples

compare

Sémantique

Effectue une comparaison élément par élément des Tensors lhs et rhs selon comparison_direction et compare_type, et génère un Tensor result.

Les valeurs de comparison_direction et compare_type ont la sémantique suivante:

Pour les types d'éléments booléens et entiers:

  • EQ : lhs = rhs.
  • NE : lhs != rhs.
  • GE : lhs >= rhs.
  • GT : lhs > rhs.
  • LE : lhs <= rhs.
  • LT : lhs < rhs.

Pour les types d'éléments à virgule flottante avec compare_type = FLOAT, l'opération implémente les opérations IEEE-754 suivantes:

  • EQ : compareQuietEqual.
  • NE : compareQuietNotEqual.
  • GE : compareQuietGreaterEqual.
  • GT : compareQuietGreater.
  • LE : compareQuietLessEqual.
  • LT : compareQuietLess.

Pour les types d'éléments à virgule flottante avec compare_type = TOTALORDER, l'opération utilise la combinaison des opérations totalOrder et compareQuietEqual de la norme IEEE-754. Cette fonctionnalité semble inutilisée. Nous prévoyons donc de la supprimer à l'avenir (#584).

Pour les types d'éléments complexes, la comparaison lexicographique des paires (real, imag) est effectuée à l'aide des éléments comparison_direction et compare_type fournis. L'application d'un ordre aux nombres complexes implique une sémantique surprenante. Par conséquent, nous prévoyons de supprimer la prise en charge des nombres complexes lorsque comparison_direction correspond à GE, GT, LE ou LT (#560).

Pour les types quantifiés. Exécute dequantize_compare(lhs, rhs, comparison_direction).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1-C3)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1-C2)
(I3). comparison_direction énumération de EQ, NE, GE, GT, LE et LT
(I4). compare_type énumération de FLOAT, TOTALORDER, SIGNED et UNSIGNED (C3)

Sorties

Nom Type Contraintes
result Tensor de type booléen (C2)

Contraintes

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type est défini comme suit :
    • SIGNED si is_signed_integer(element_type(lhs)).
    • UNSIGNED si is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT ou TOTALORDER si is_float(element_type(lhs)).
    • FLOAT si is_complex(element_type(lhs)).

Exemples

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

Autres exemples

complexe

Sémantique

Effectue une conversion par élément en valeur complexe à partir d'une paire de valeurs réelles et imaginaires, lhs et rhs, et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type f32 ou f64 (C1-C3)
(I2) rhs Tensor de type f32 ou f64 (C1)

Sorties

Nom Type Contraintes
result Tensor de type complexe (C2), (C3)

Contraintes

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) est de type complex<E>, où E = element_type(lhs).

Exemples

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

Autres exemples

concatenate

Sémantique

Concatène inputs selon la dimension dimension dans le même ordre que les arguments donnés et génère un Tensor result. Plus formellement, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], où:

  1. id = d0 + ... + dk-1 + kd.
  2. d est égal à dimension, et d0, ... est la de taille de dimension de inputs.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre varié de Tensors ou Tensors quantifiés par Tensor (C1-C6)
(I2) dimension constante de type si64 (C2), (C4), (C6)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C5-C6)

Contraintes

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)), sauf pour dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]), sauf :
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Exemples

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Autres exemples

constante

Sémantique

Génère un Tensor output à partir d'une constante value.

Entrées

Libellé Nom Type Contraintes
(I1) value constante (C1)

Sorties

Nom Type Contraintes
output Tensor ou Tensor quantifié (C1)

Contraintes

  • (C1) type(value) = type(output).

Exemples

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Autres exemples

d'effectuer une conversion

Sémantique

Effectue une conversion par élément d'un type d'élément à un autre sur le Tensor operand et génère un Tensor result.

Pour les conversions boolean-to-any-supported-type, la valeur false est convertie en zéro, et la valeur true est convertie en un. Pour les conversions de type any-supported-type-to-boolean, une valeur nulle est convertie en false, et les valeurs non nulles sont converties en true. Reportez-vous à la section ci-dessous pour en savoir plus sur le fonctionnement des types complexes.

Pour les conversions impliquant un nombre entier à entier, un nombre entier à virgule flottante ou un point flottant à virgule flottante, si la valeur source peut être exactement représentée dans le type de destination, la valeur obtenue sera cette représentation exacte. Sinon, le comportement est à déterminer (#180).

Pour les conversions impliquant une floating-point-to-integer, la partie fractionnaire est tronquée. Si la valeur tronquée ne peut pas être représentée dans le type de destination, le comportement est à déterminer (#180).

Les conversions impliquant complexe à complexe suivent le même comportement que les conversions de point flottant à virgule flottante pour la conversion de pièces réelles et imaginaires.

Pour les conversions de type complex-to-any-other-type et complex-to-any-other-type, la valeur imaginaire source est ignorée ou la valeur imaginaire de la destination est remise à zéro, respectivement. La conversion de la partie réelle suit les conversions à virgule flottante.

En principe, cette opération pourrait exprimer la déquantification (conversion de Tensors quantifiés en Tensors réguliers), la quantification (conversion de Tensors standards en Tensors quantifiés) et la requantification (conversion entre des Tensors quantifiés). Toutefois, pour le moment, nous disposons d'opérations dédiées (uniform_dequantize pour le premier cas d'utilisation et uniform_quantize pour le deuxième et le troisième). À l'avenir, ces deux opérations pourront être fusionnées dans convert (#1576).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor (C1)

Contraintes

  • (C1) shape(operand) = shape(result).

Exemples

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Autres exemples

convolution

Sémantique

Calcule les produits scalaires entre les fenêtres de lhs et les tranches de rhs, et génère la valeur result. Le schéma suivant montre comment les éléments de result sont calculés à partir de lhs et rhs à l'aide d'un exemple concret.

Plus formellement, considérons le recadrage suivant des entrées en termes de lhs afin de pouvoir exprimer des fenêtres de lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Ce recadrage utilise les fonctions d'assistance suivantes:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1], où j[d] = i[permutation[d]].

Si feature_group_count = 1 et batch_group_count = 1, alors pour tous les output_spatial_index dans index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product où:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Cette fonctionnalité semble inutilisée. Nous prévoyons donc de la supprimer à l'avenir (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Si la valeur est feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Si la valeur est batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Pour les types quantifiés, exécute dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1), (C10-C11), (C14) (C25), (C27-C30)
(I2) rhs Tensor ou Tensor quantifié (C1), (C14-C16), (C25), (C27-C32)
(I3). window_strides Constante de Tensor unidimensionnelle de type si64 (C2-C3), (C25)
(I4). padding Constante de Tensor bidimensionnelle de type si64 (C4), (C25)
(I5). lhs_dilation Constante de Tensor unidimensionnelle de type si64 (C5-C6), (C25)
(I6). rhs_dilation Constante de Tensor unidimensionnelle de type si64 (C7-C8), (C25)
(I7) window_reversal Constante de Tensor unidimensionnelle de type i1 (C9)
(I8). input_batch_dimension constante de type si64 (C10), (C13), (C25)
(I9). input_feature_dimension constante de type si64 (C11), (C13-C14)
(I10). input_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C12), (C13), (C25)
(I11). kernel_input_feature_dimension constante de type si64 (C14), (C18)
(I12). kernel_output_feature_dimension constante de type si64 (C15-C16), (C18), (C25), (C32)
(I13). kernel_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C17-C18), (C25)
(I14). output_batch_dimension constante de type si64 (C20), (C25)
(I15). output_feature_dimension constante de type si64 (C20), (C25), (C33)
(I16). output_spatial_dimensions Constante de Tensor unidimensionnelle de type si64 (C19-C20), (C25)
(I17). feature_group_count constante de type si64 (C11), (C14), (C16), (C21), (C23)
(I18). batch_group_count constante de type si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config nombre variable d'énumérations DEFAULT, HIGH et HIGHEST (C24)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C25-C28), (C30-C31), (C33)

Contraintes

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Compte input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension] donné :
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Compte kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension] donné :
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Compte output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension] donné :
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) est défini comme suit :
    • dim(lhs, input_batch_dimension) / batch_group_count si result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) si result_dim = output_feature_dimension.
    • num_windows dans les autres cas, où:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Si l'opération utilise des Tensors non quantifiés :
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Si l'opération utilise des Tensors quantifiés :
    • (C28) is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result).
    • (C29) storage_type(lhs) = storage_type(rhs).
    • (C30) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C31) Si la valeur est is_per_tensor_quantized(rhs), alors is_per_tensor_quantized(result).
    • (C32) Si la valeur est is_per_axis_quantized(rhs), alors quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C33) Si la valeur est is_per_axis_quantized(result), alors quantization_dimension(result) = output_feature_dimension.

Exemples

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs : [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = dense<4> : tensor<2xi64>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = dense<2> : tensor<2xi64>,
  rhs_dilation = dense<1> : tensor<2xi64>,
  window_reversal = dense<false> : tensor<2xi1>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

cosinus

Sémantique

Effectue une opération de cosinus par élément sur le Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: cos à partir de la norme IEEE-754.
  • Pour les nombres complexes: cosinus complexe.
  • Pour les types quantifiés: dequantize_op_quantize(cosine, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Autres exemples

count_leading_zeros

Sémantique

Effectue un comptage par élément du nombre de bits zéros au début dans le Tensor operand et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Autres exemples

custom_call

Sémantique

Encapsule une opération définie par l'implémentation call_target_name qui reçoit inputs et called_computations, et génère results. Vous pouvez utiliser has_side_effect, backend_config et api_version pour fournir des métadonnées supplémentaires définies par l'implémentation.

À l'heure actuelle, cette opération contient une collection de métadonnées relativement désorganisée qui reflète l'évolution organique de son équivalent dans le compilateur XLA. Nous prévoyons d'unifier ces métadonnées à l'avenir (#741).

Entrées

Libellé Nom Type
(I1) inputs nombre varié de valeurs
(I2) call_target_name constante de type string
(I3). has_side_effect constante de type i1
(I4). backend_config constante de type string
(I5). api_version constante de type si32
(I6). called_computations nombre varié de constantes de type string

Sorties

Nom Type
results nombre varié de valeurs

Exemples

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

diviser

Sémantique

Effectue une division par élément des Tensors de dividende lhs et des Tensors rhs du diviseur, et produit un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les entiers: division d'entiers qui produit le quotient algébrique avec toute partie fractionnaire rejetée.
  • Pour les flottants: division à partir de la norme IEEE-754.
  • Pour les nombres complexes: division complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)
(I2) rhs Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Autres exemples

dot_general

Sémantique

Calcule les produits scalaires entre les tranches de lhs et les tranches de rhs, et génère un Tensor result.

Plus formellement, result[result_index] = dot_product, où:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index, où size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) et size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Pour les types quantifiés, exécute dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Cette valeur spécifie uniquement la sémantique pour la quantification par Tensor. La quantification par axe est en cours (#1574). Nous pourrions également envisager d'ajouter la compatibilité avec la quantification hybride à l'avenir (#1575).

precision_config contrôle le compromis entre vitesse et précision des calculs sur les backends d'accélérateur. Il peut s'agir de l'un des éléments suivants (à l'heure actuelle, la sémantique de ces valeurs d'énumération est sous-spécifiée, mais nous prévoyons de résoudre ce problème à la n° 755):

  • DEFAULT: calcul le plus rapide, mais approximation la moins précise du nombre d'origine.
  • HIGH: calcul plus lent, mais approximation plus précise du nombre d'origine.
  • HIGHEST: calcul le plus lent, mais approximation la plus précise par rapport au nombre d'origine.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C5-C6), (C9-C10), (C12-C16)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C7-C10), (C12)
(I3). lhs_batching_dimensions Constante de Tensor unidimensionnelle de type si64 (C1), (C3), (C5), (C9), (C12)
(I4). rhs_batching_dimensions Constante de Tensor unidimensionnelle de type si64 (C1), (C4), (C7), (C9)
(I5). lhs_contracting_dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C3), (C6), (C10)
(I6). rhs_contracting_dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C8), (C10)
(I7) precision_config nombre variable d'énumérations DEFAULT, HIGH et HIGHEST (C11)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C12), (C14), (C16)

Contraintes

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Si l'opération utilise des Tensors non quantifiés :
    • (C13) element_type(lhs) = element_type(rhs).
  • Si l'opération utilise des Tensors quantifiés :
    • (C14) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C15) storage_type(lhs) = storage_type(rhs).
    • (C16) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C17) zero_points(rhs) = 0.

Exemples

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Autres exemples

dynamic_slice

Sémantique

Extrait une tranche de operand à l'aide d'indices de départ calculés de manière dynamique et génère un Tensor result. start_indices contient les index de départ de la tranche pour chaque dimension susceptible d'être ajustée, et slice_sizes contient les tailles de la tranche pour chaque dimension. Plus formellement, result[result_index] = operand[operand_index] où:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C2), (C4)
(I2) start_indices nombre varié de Tensors à 0 dimension de type entier (C2), (C3)
(I3). slice_sizes Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C5)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1), (C5)

Contraintes

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Exemples

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Autres exemples

dynamic_update_slice

Sémantique

Génère un Tensor result égal au Tensor operand, à la différence que la tranche commençant à start_indices est mise à jour avec les valeurs de update. Plus formellement, result[result_index] est défini comme suit:

  • update[update_index] si 0 <= update_index < shape(update), où :
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1-C4), (C6)
(I2) update Tensor ou Tensor quantifié par Tensor (C2), (C3), (C6)
(I3). start_indices nombre varié de Tensors à 0 dimension de type entier (C4), (C5)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Exemples

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Autres exemples

exponentiel

Sémantique

Effectue une opération exponentielle par élément sur le Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: exp à partir de la norme IEEE-754.
  • Pour les nombres complexes: exponentielle complexe.
  • Pour les types quantifiés : dequantize_op_quantize(exponential, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Autres exemples

exponential_minus_one

Sémantique

Effectue une opération exponentielle moins 1 par élément sur le Tensor operand et produit un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: expm1 à partir de la norme IEEE-754.
  • Pour les nombres complexes: puissance exponentielle complexe moins un.
  • Pour les types quantifiés : dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

Autres exemples

FFT

Sémantique

Effectue les transformations de Fourier avant et inverse pour des entrées/sorties réelles et complexes.

fft_type est l'un des éléments suivants :

  • FFT: transfère FFT complexe à complexe.
  • IFFT: FFT inverse complexe à complexe.
  • RFFT: transfère FFT réel vers complexe.
  • IRFFT: FFT inverse de vrai à complexe (l'opération prend des valeurs complexes et renvoie des valeurs réelles).

Plus formellement, avec la fonction fft, qui prend en entrée des Tensors unidimensionnels de types complexes, elle produit des Tensors unidimensionnels des mêmes types qu'en sortie et calcule la transformation de Fourier discrète:

Pour fft_type = FFT, result est défini comme le résultat final d'une série de calculs L, où L = size(fft_length). Par exemple, pour L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

De plus, compte tenu de la fonction ifft, qui a le même type de signature et calcule l'inverse de fft:

Pour fft_type = IFFT, result est défini comme l'inverse des calculs de fft_type = FFT. Par exemple, pour L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

En outre, grâce à la fonction rfft, qui prend les Tensors unidimensionnels de types à virgule flottante, elle produit des Tensors unidimensionnels de types complexes de la même sémantique à virgule flottante, et fonctionne comme suit:

  • rfft(real_operand) = truncated_result
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Lorsque la transformation de Fourier discrète est calculée pour des opérandes réels, les premiers éléments N/2 + 1 du résultat définissent sans ambiguïté le reste du résultat. Le résultat de rfft est donc tronqué pour éviter de calculer des éléments redondants.)

Pour fft_type = RFFT, result est défini comme le résultat final d'une série de calculs L, où L = size(fft_length). Par exemple, pour L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Enfin, avec la fonction irfft, qui a le même type de signature et calcule l'inverse de rfft:

Pour fft_type = IRFFT, result est défini comme l'inverse des calculs de fft_type = RFFT. Par exemple, pour L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe (C1), (C2), (C4), (C5)
(I2) fft_type énumération de FFT, IFFT, RFFT et IRFFT (C2), (C5)
(I3). fft_length Constante de Tensor unidimensionnelle de type si64 (C1), (C3), (C4)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe (C2), (C4), (C5)

Contraintes

  • (C1) size(fft_length) <= rank(operand).
  • (C2) La relation entre les types d'éléments operand et result varie :
    • Si fft_type = FFT, element_type(operand) et element_type(result) ont le même type complexe.
    • Si fft_type = IFFT, element_type(operand) et element_type(result) ont le même type complexe.
    • Si la valeur est fft_type = RFFT, element_type(operand) est un type à virgule flottante et element_type(result) est un type complexe de la même sémantique à virgule flottante.
    • Si la valeur est fft_type = IRFFT, element_type(operand) est un type complexe et element_type(result) est un type à virgule flottante de la même sémantique à virgule flottante.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) S'il existe entre operand et result, il existe un Tensor real de type à virgule flottante, alors shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand), sauf :
    • Si la valeur est fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Si la valeur est fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Exemples

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

Sémantique

Effectue le plancher élément par élément du Tensor operand et génère un Tensor result. Met en œuvre l'opération roundToIntegralTowardNegative à partir de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(floor, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Autres exemples

rassembler

Sémantique

Recueille les tranches du Tensor operand à partir des décalages spécifiés dans start_indices et génère un Tensor result.

Le schéma suivant montre comment les éléments de result sont mappés sur des éléments de operand à l'aide d'un exemple concret. Le schéma choisit quelques exemples d'indices result et explique en détail à quels index operand ils correspondent.

Plus formellement, result[result_index] = operand[operand_index] où:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index est défini comme suit :
    • start_indices[bi0, ..., :, ..., biN], où bi sont des éléments individuels dans batch_index, et : est inséré au niveau de l'index index_vector_dim, si index_vector_dim < rank(start_indices).
    • [start_indices[batch_index]] dans les autres cas.
  • Pour d_operand dans axes(operand) :
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) si d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 dans les autres cas.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], où oi sont des éléments individuels dans offset_index, et 0 est inséré aux index de collapsed_slice_dims.
  • operand_index = full_start_index + full_offset_index.

Si indices_are_sorted est défini sur true, l'implémentation peut supposer que les start_indices sont triés par rapport à start_index_map. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2 de indices(result) : full_start_index(i1) <= full_start_index(i2).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C7), (C10-C12), (C14)
(I2) start_indices Tensor de type entier (C2), (C3), (C13)
(I3). offset_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C4-C5), (C13)
(I4). collapsed_slice_dims Constante de Tensor unidimensionnelle de type si64 (C1), (C6-C8), (C13)
(I5). start_index_map Constante de Tensor unidimensionnelle de type si64 (C3), (C9), (C10)
(I6). index_vector_dim constante de type si64 (C2), (C3), (C13)
(I7) slice_sizes Constante de Tensor unidimensionnelle de type si64 (C8), (C11-C13)
(I8). indices_are_sorted constante de type i1

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C5), (C13-C14)

Contraintes

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) où :
    • batch_dim_sizes = shape(start_indices), si ce n'est que la taille de la dimension de start_indices correspondant à index_vector_dim n'est pas incluse.
    • offset_dim_sizes = shape(slice_sizes), si ce n'est que les dimensions des dimensions dans slice_sizes correspondant à collapsed_slice_dims ne sont pas incluses.
    • combine place batch_dim_sizes sur les axes correspondant à batch_dims et offset_dim_sizes aux axes correspondant à offset_dims.
  • (C14) element_type(operand) = element_type(result).

Exemples

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
  indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Autres exemples

get_dimension_size

Sémantique

Génère la taille de l'élément dimension donné pour operand. Plus formellement, result = dim(operand, dimension).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor (C1)
(I2) dimension constante de type si64 (C1)

Sorties

Nom Type
result Tensor à 0 dimension de type si32

Contraintes

  • (C1) 0 <= dimension < rank(operand).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Autres exemples

get_tuple_element

Sémantique

Extrait l'élément à la position index du tuple operand et génère une result. Plus formellement : result = operand[index].

Entrées

Libellé Nom Type Contraintes
(I1) operand tuple (C1), (C2)
(I2) index constante de type si32 (C1), (C2)

Sorties

Nom Type Contraintes
result tous les types compatibles (C2)

Contraintes

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Exemples

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

Autres exemples

if

Sémantique

Génère le résultat de l'exécution d'une seule fonction à partir de true_branch ou false_branch en fonction de la valeur de pred. Plus formellement : result = pred ? true_branch() : false_branch().

Entrées

Libellé Nom Type Contraintes
(I1) pred Tensor à 0 dimension de type i1
(I2) true_branch function (C1-C3)
(I3). false_branch function (C1), (C2)

Sorties

Nom Type Contraintes
results nombre varié de Tensors, de Tensors quantifiés ou de jetons (C3)

Contraintes

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Exemples

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Autres exemples

Imageg

Sémantique

Extrait la partie imaginaire, par élément, de la operand et génère un Tensor result. Plus formellement, pour chaque élément x : imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • element_type(operand) dans les autres cas.

Exemples

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Autres exemples

flux d'entrée

Sémantique

Lit les données du flux d'entrée et génère results.

La sémantique de infeed_config est définie par l'implémentation.

results sont constitués de valeurs de charge utile qui apparaissent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour améliorer la clarté (#670).

Entrées

Libellé Nom Type
(I1) token token
(I2) infeed_config constante de type string

Sorties

Nom Type Contraintes
results nombre varié de Tensors, de Tensors quantifiés ou de jetons (C1-C3)

Contraintes

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Exemples

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Autres exemples

Ito

Sémantique

Remplit un Tensor output avec des valeurs dans l'ordre croissant à partir de zéro le long de la dimension iota_dimension. Plus formellement,

output[result_index] = constant(is_quantized(output) ? quantize(result_index[iota_dimension], element_type(output)) : result_index[iota_dimension], element_type(output)).

Entrées

Libellé Nom Type Contraintes
(I1) iota_dimension si64 (C1)

Sorties

Nom Type Contraintes
output Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) 0 <= iota_dimension < rank(output).

Exemples

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Autres exemples

is_finite

Sémantique

Vérifie au niveau des éléments si la valeur de x est finie (c'est-à-dire si elle n'est ni +Inf, -Inf, ni NaN) et génère un Tensor y. Met en œuvre l'opération isFinite à partir de la spécification IEEE-754. Pour les types quantifiés, le résultat est toujours true.

Entrées

Libellé Nom Type Contraintes
(I1) x Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
y Tensor de type booléen (C1)

Contraintes

  • (C1) shape(x) = shape(y).

Exemples

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

Autres exemples

log

Sémantique

Effectue une opération de logarithme par élément sur le Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: log à partir de la norme IEEE-754.
  • Pour les nombres complexes: logarithme complexe.
  • Pour les types quantifiés: dequantize_op_quantize(log, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Autres exemples

log_plus_one

Sémantique

Effectue un logarithme par élément plus une opération sur le Tensor operand et produit un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: logp1 à partir de la norme IEEE-754.
  • Pour les nombres complexes: logarithme complexe plus un.
  • Pour les types quantifiés : dequantize_op_quantize(log_plus_one, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Autres exemples

logistique

Sémantique

Effectue une opération logistique au niveau des éléments sur le Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: division(1, addition(1, exp(-x))) à partir de la norme IEEE-754.
  • Pour les nombres complexes: logistique complexe.
  • Pour les types quantifiés : dequantize_op_quantize(logistic, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Autres exemples

carte

Sémantique

Applique une fonction de mappage computation à inputs le long de dimensions et génère un Tensor result.

Plus formellement : result[result_index] = computation(inputs...[result_index]). Notez que les dimensions ne sont pas utilisés actuellement et seront probablement supprimés à l'avenir (#487).

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre varié de Tensors ou Tensors quantifiés par Tensor (C1-C4)
(I2) dimensions Constante de Tensor unidimensionnelle de type si64 (C3)
(I3). computation function (C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1), (C4)

Contraintes

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation est de type (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, où Ei = element_type(inputs[i]) et E' = element_type(result).

Exemples

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Autres exemples

maximum

Sémantique

Effectue une opération maximale par élément sur les Tensors lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: opérateur logique OR.
  • Pour les nombres entiers: nombre entier maximum.
  • Pour les flottants: maximum à partir de la norme IEEE-754.
  • Pour les nombres complexes: valeur lexicographique maximale pour la paire (real, imaginary). L'application d'un ordre à des nombres complexes implique une sémantique surprenante. Par conséquent, nous prévoyons de supprimer la prise en charge des nombres complexes pour cette opération (#560).
  • Pour les types quantifiés :
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

Autres exemples

minimum

Sémantique

Effectue une opération minimale par élément sur les Tensors lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: AND.
  • Pour les nombres entiers: nombre entier minimal.
  • Pour les flottants: minimum à partir de la norme IEEE-754.
  • Pour les nombres complexes: valeur lexicographique minimale pour la paire (real, imaginary). L'application d'un ordre à des nombres complexes implique une sémantique surprenante. Par conséquent, nous prévoyons de supprimer la prise en charge des nombres complexes pour cette opération (#560).
  • Pour les types quantifiés :
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Autres exemples

multiplier

Sémantique

Effectue un produit par élément de deux Tensors lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: AND.
  • Pour les nombres entiers: multiplication de nombres entiers.
  • Pour les flottants: multiplication à partir de la norme IEEE-754.
  • Pour les nombres complexes: multiplication complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

Autres exemples

negate

Sémantique

Effectue une négation élément par élément du Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les entiers signés: négation des entiers.
  • Pour les entiers non signés: bitcast en entier signé, négation entière, retransmission de bit en entier non signé.
  • Pour les flottants: negate à partir de la norme IEEE-754.
  • Pour les nombres complexes: négation complexe.
  • Pour les types quantifiés : dequantize_op_quantize(negate, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Autres exemples

(ne vivant pas

Sémantique

Effectue l'opérateur NOT au niveau des éléments du Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: logique NOT.
  • Pour les nombres entiers: NOT au niveau du bit.

Arguments

Nom Type Contraintes
operand Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

optimization_barrier

Sémantique

S'assure que les opérations qui génèrent la operand sont exécutées avant toute opération qui dépend de la result et empêche les transformations du compilateur de déplacer les opérations au-delà de la barrière. En dehors de cela, l'opération est une identité, par exemple result = operand.

Arguments

Nom Type Contraintes
operand nombre varié de Tensors, Tensors quantifiés par Tensor ou jetons (C1)

Sorties

Nom Type Contraintes
result nombre varié de Tensors, Tensors quantifiés par Tensor ou jetons (C1)

Contraintes

  • (C1) type(operand...) = type(result...).

Exemples

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Autres exemples

ou

Sémantique

Effectue l'opérateur OR par élément des deux Tensors lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: opérateur logique OR.
  • Pour les entiers: OR au niveau du bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier ou booléen (C1)
(I2) rhs Tensor de type entier ou booléen (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier ou booléen (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

flux de sortie

Sémantique

Écrit inputs dans le flux de sortie et génère un jeton result.

La sémantique de outfeed_config est définie par l'implémentation.

Entrées

Libellé Nom Type
(I1) inputs nombre varié de Tensors ou de Tensors quantifiés
(I2) token token
(I3). outfeed_config constante de type string

Sorties

Nom Type
result token

Exemples

%result = "stablehlo.outfeed"(%inputs0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Autres exemples

pad

Sémantique

Développe operand par une marge intérieure autour du Tensor, ainsi qu'entre les éléments du Tensor avec la padding_value donnée.

edge_padding_low et edge_padding_high spécifient la quantité de marge intérieure ajoutée dans la limite inférieure (à côté de l'index 0) et dans la limite supérieure (à côté de l'indice le plus élevé) de chaque dimension. La valeur de la marge intérieure négative peut être négative, la valeur absolue indiquant le nombre d'éléments à supprimer de la dimension spécifiée.

interior_padding spécifie la quantité de marge intérieure ajoutée entre deux éléments de chaque dimension, qui ne peut pas être négative. La marge intérieure intérieure se produit avant le remplissage du bord, de sorte que la marge intérieure négative supprime les éléments de l'opérande de remplissage intérieur.

Plus formellement, result[result_index] est défini comme suit:

  • operand[operand_index] si result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • padding_value dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C2), (C4)
(I2) padding_value Tensor à 0 dimension ou Tensor quantifié par Tensor (C1)
(I3). edge_padding_low Constante de Tensor unidimensionnelle de type si64 (C1), (C4)
(I4). edge_padding_high Constante de Tensor unidimensionnelle de type si64 (C1), (C4)
(I5). interior_padding Constante de Tensor unidimensionnelle de type si64 (C2-C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C3-C6)

Contraintes

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Exemples

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
  edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
  interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Autres exemples

partition_id

Sémantique

Génère partition_id du processus actuel.

Sorties

Nom Type
result Tensor à 0 dimension de type ui32

Exemples

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

Autres exemples

Popcnt

Sémantique

Effectue un comptage par élément du nombre de bits définis dans le Tensor operand et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(operand) = type(result).

Exemples

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Autres exemples

puissance

Sémantique

Effectue une exponentielle par élément du Tensor lhs par le Tensor rhs et produit un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les entiers: exponentielle entière.
  • Pour les flottants: pow à partir de la norme IEEE-754.
  • Pour les nombres complexes: exponentielle complexe.
  • Pour les types quantifiés: dequantize_op_quantize(power, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Autres exemples

real

Sémantique

Extrait la partie réelle, élément par élément, de operand et génère un Tensor result. Plus formellement, pour chaque élément x : real(x) = is_complex(x) ? real_part(x) : x.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) est défini comme suit :
    • complex_element_type(element_type(operand)) si is_complex(operand).
    • element_type(operand) dans les autres cas.

Exemples

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

Autres exemples

recv

Sémantique

Reçoit les données d'un canal avec channel_id et génère results.

Si is_host_transfer est défini sur true, l'opération transfère les données de l'hôte. Sinon, les données seront transférées depuis un autre appareil. Cela signifie que l'implémentation est définie. Cet indicateur duplique les informations fournies dans channel_type. Nous prévoyons donc de n'en conserver qu'un seul à l'avenir (#666).

results sont constitués de valeurs de charge utile qui apparaissent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour améliorer la clarté (#670).

Entrées

Libellé Nom Type Contraintes
(I1) token token (C4)
(I2) channel_id constante de type si64
(I3). channel_type énumération de DEVICE_TO_DEVICE et HOST_TO_DEVICE (C1)
(I4). is_host_transfer constante de type i1 (C1)

Sorties

Nom Type Contraintes
results nombre varié de Tensors, de Tensors quantifiés ou de jetons (C2-C4)

Contraintes

  • (C1) channel_type est défini comme suit :
    • HOST_TO_DEVICE si is_host_transfer = true,
    • DEVICE_TO_DEVICE dans les autres cas.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Exemples

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Autres exemples

reduce

Sémantique

Applique une fonction de réduction body à inputs et init_values le long de dimensions et génère des Tensors results.

L'ordre des réductions est défini par l'implémentation, ce qui signifie que body et init_values doivent former un monoid pour garantir que l'opération produit les mêmes résultats pour toutes les entrées de toutes les implémentations. Cependant, cette condition ne s'applique pas à de nombreuses réductions courantes. Par exemple, l'addition à virgule flottante pour body et le zéro pour init_values ne forment pas réellement un monoid, car l'addition à virgule flottante n'est pas associative.

Plus formellement, results...[j0, ..., jR-1] = reduce(input_slices_converted) où:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], où les : sont insérés au niveau de dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) pour une arborescence binaire schedule où :
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule est une arborescence binaire complète définie par l'implémentation dont le balayage dans l'ordre comprend les éléments suivants :
    • Valeurs input_slices_converted...[index], pour tous les index de index_space(input_slices_converted) dans l'ordre lexicographique croissant de index.
    • Il est intercalé avec un nombre défini de init_values_converted à des positions définies par l'implémentation.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre varié de Tensors ou Tensors quantifiés par Tensor (C1-C4), (C6), (C7)
(I2) init_values nombre varié de Tensors à 0 dimension ou de Tensors quantifiés par Tensor (C2), (C3)
(I3). dimensions Constante de Tensor unidimensionnelle de type si64 (C4), (C5), (C7)
(I4). body function (C6)

Sorties

Nom Type Contraintes
results nombre varié de Tensors ou Tensors quantifiés par Tensor (C3), (C7), (C8)

Contraintes

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), où is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), si ce n'est que les tailles de dimensions de inputs... correspondant à dimensions ne sont pas incluses.
  • (C8) element_type(results[i]) = Ei pour tous les i dans [0,N).

Exemples

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

Autres exemples

reduce_precision

Sémantique

Effectue la conversion de operand au niveau de l'élément en un autre type à virgule flottante qui utilise exponent_bits et mantissa_bits, puis revient au type à virgule flottante d'origine et génère un Tensor output.

Plus formellement:

  • Les bits de mantisse de la valeur d'origine sont mis à jour pour arrondir la valeur d'origine à la valeur représentable la plus proche avec mantissa_bits à l'aide de la sémantique roundToIntegralTiesToEven.
  • Ensuite, si la valeur de mantissa_bits est inférieure au nombre de bits de mantisse de la valeur d'origine, les bits de mantisse sont tronqués à mantissa_bits.
  • Ensuite, si les bits d'exposant du résultat intermédiaire ne correspondent pas à la plage fournie par exponent_bits, le résultat intermédiaire déborde à l'infini avec le signe d'origine ou est insuffisant à zéro à l'aide du signe d'origine.
  • Pour les types quantifiés, exécute dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)
(I2) exponent_bits constante de type si32 (C2)
(I3). mantissa_bits constante de type si32 (C3)

Sorties

Nom Type Contraintes
output Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Exemples

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Autres exemples

reduce_scatter

Sémantique

Dans chaque groupe de processus de la grille de processus StableHLO, effectue une réduction sur les valeurs du Tensor operand de chaque processus à l'aide de computations, divise le résultat de la réduction en plusieurs parties avec scatter_dimension, puis disperse les parties divisées entre les processus pour produire result.

L'opération divise la grille de processus StableHLO en process_groups, défini comme suit:

  • cross_replica(replica_groups) si channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si channel_id > 0 and use_global_device_ids = true.

Ensuite, dans chaque process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] pour tous les sender dans process_group, où receiver_index = process_group.index(receiver).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension constante de type si64 (C1), (C2), (C8)
(I3). replica_groups Constante de Tensor bidimensionnelle de type si64 (C3-C5)
(I4). channel_id constante de type si64 (C6)
(I5). use_global_device_ids constante de type i1 (C6)
(I6). computation function (C7)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C8-C9)

Contraintes

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) est défini comme suit :
    • num_replicas si cross_replica est utilisé.
    • num_replicas si cross_replica_and_partition est utilisé.
    • num_processes si flattened_ids est utilisé.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Si la valeur est use_global_device_ids = true, alors channel_id > 0.
  • (C7) computation est de type (tensor<E>, tensor<E>) -> (tensor<E>), où is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand), sauf :
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Exemples

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Autres exemples

reduce_window

Sémantique

Applique une fonction de réduction body aux fenêtres de inputs et init_values, et génère results.

Le schéma suivant montre comment les éléments de results... sont calculés à partir de inputs... à l'aide d'un exemple concret.

Plus formellement, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (voir réduire) où:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre varié de Tensors ou Tensors quantifiés par Tensor (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values nombre varié de Tensors à 0 dimension ou de Tensors quantifiés par Tensor (C1), (C13)
(I3). window_dimensions Constante de Tensor unidimensionnelle de type si64 (C4), (C5), (C15)
(I4). window_strides Constante de Tensor unidimensionnelle de type si64 (C6), (C7), (C15)
(I5). base_dilations Constante de Tensor unidimensionnelle de type si64 (C8), (C9), (C15)
(I6). window_dilations Constante de Tensor unidimensionnelle de type si64 (C10), (C11), (C15)
(I7) padding Constante de Tensor bidimensionnelle de type si64 (C12), (C15)
(I8). body function (C13)

Sorties

Nom Type Contraintes
results nombre varié de Tensors ou Tensors quantifiés par Tensor (C1), (C14-C16)

Contraintes

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), où is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows où :
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei pour tous les i dans [0,N).

Exemples

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[2, 1]> : tensor<2xi64>,
  window_strides = dense<[4, 1]> : tensor<2xi64>,
  base_dilations = dense<[2, 1]> : tensor<2xi64>,
  window_dilations = dense<[3, 1]> : tensor<2xi64>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

Autres exemples

reste

Sémantique

Effectue le reste par élément des Tensors de dividende lhs et rhs du diviseur, et produit un Tensor result.

Plus formellement, le signe du résultat est tiré du dividende, et la valeur absolue du résultat est toujours inférieure à la valeur absolue du diviseur. Le reste est calculé comme suit : lhs - d * rhs, où d est calculé comme suit :

  • Pour les entiers: stablehlo.divide(lhs, rhs).
  • Pour les valeurs à virgule flottante: division(lhs, rhs) à partir de la norme IEEE-754 avec l'attribut d'arrondi roundTowardZero.
  • Pour les nombres complexes: à déterminer (#997).
  • Pour les types quantifiés :
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Pour les types d'éléments à virgule flottante, cette opération diffère de l'opération remainder de la spécification IEEE-754, où d est une valeur intégrale la plus proche de la valeur exacte de lhs/rhs avec un lien égal.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)
(I2) rhs Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Autres exemples

replica_id

Sémantique

Génère replica_id du processus actuel.

Sorties

Nom Type
result Tensor à 0 dimension de type ui32

Exemples

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

Autres exemples

remodeler

Sémantique

Effectue le remodelage du Tensor operand en un Tensor result. Conceptuellement, cela revient à conserver la même représentation canonique, mais en modifiant éventuellement la forme, par exemple de tensor<2x3xf32> à tensor<3x2xf32> ou tensor<6xf32>.

Plus formellement, result[result_index] = operand[operand_index], où result_index et operand_index ont la même position dans l'ordre lexicographique de index_space(result) et index_space(operand).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C3)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1-C3)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), si ce n'est que quantization_dimension(operand) et quantization_dimension(result) peuvent différer.
  • (C2) size(operand) = size(result).
  • (C3) Si is_per_axis_quantized(operand) :
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Exemples

// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

Autres exemples

reverse

Sémantique

Inverse l'ordre des éléments dans operand selon le dimensions spécifié et génère un Tensor result. Plus formellement, result[result_index] = operand[operand_index] où:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 si d dans dimensions.
  • operand_index[d] = result_index[d] dans les autres cas.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1), (C3)
(I2) dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C3)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1), (C3)

Contraintes

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Exemples

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

Autres exemples

rng

Sémantique

Génère des nombres aléatoires à l'aide de l'algorithme rng_distribution et génère un Tensor result d'une forme donnée shape.

Si la valeur est rng_distribution = UNIFORM, les nombres aléatoires sont générés en suivant la distribution uniforme sur l'intervalle [a, b). Si la valeur est a >= b, le comportement n'est pas défini.

Si la valeur est rng_distribution = NORMAL, les nombres aléatoires sont générés en suivant la distribution normale avec une moyenne de a et un écart-type de b. Si la valeur est b < 0, le comportement n'est pas défini.

La manière exacte dont les nombres aléatoires sont générés est définie par l'implémentation. Par exemple, elles peuvent être déterministes ou non, et utiliser ou non un état caché.

Lors de conversations avec de nombreuses personnes concernées, cette opération est devenue obsolète. Nous prévoyons donc d'envisager de la supprimer à l'avenir (#597).

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor à 0 dimension de type entier, booléen ou à virgule flottante (C1), (C2)
(I2) b Tensor à 0 dimension de type entier, booléen ou à virgule flottante (C1), (C2)
(I3). shape Constante de Tensor unidimensionnelle de type si64 (C3)
(I4). rng_distribution énumération de UNIFORM et NORMAL (C2)

Sorties

Nom Type Contraintes
result Tensor de type entier, booléen ou à virgule flottante (C1-C3)

Contraintes

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Si la valeur est rng_distribution = NORMAL, alors is_float(a).
  • (C3) shape(result) = shape.

Exemples

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Sémantique

Renvoie une output remplie de bits aléatoires uniformes et un état de sortie mis à jour output_state à l'aide de l'algorithme de générateur de nombres pseudo-aléatoires rng_algorithm en fonction d'un état initial initial_state. Le résultat sera une fonction déterministe de initial_state, mais il n'est pas garanti qu'il soit déterministe entre les implémentations.

rng_algorithm est l'un des éléments suivants :

  • DEFAULT: algorithme défini par l'implémentation.
  • THREE_FRY: variante définie par l'implémentation de l'algorithme Threefry*.
  • PHILOX: variante de l'algorithme Philox définie par l'implémentation*.

* Voir: Salmon et al. SC 2011. Les nombres aléatoires parallèles: 1, 2, 3...

Entrées

Libellé Nom Type Contraintes
(I1) rng_algorithm énumération de DEFAULT, THREE_FRY et PHILOX (C2)
(I2) initial_state Tensor unidimensionnel de type ui64 (C1), (C2)

Sorties

Nom Type Contraintes
output_state Tensor unidimensionnel de type ui64 (C1)
output Tensor de type entier ou à virgule flottante

Contraintes

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) est défini comme suit :
    • L'implémentation est définie si rng_algorithm = DEFAULT.
    • 2 si rng_algorithm = THREE_FRY.
    • 2 ou 3 si rng_algorithm = PHILOX.

Exemples

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Sémantique

Effectue un arrondi par élément par élément vers l'entier le plus proche, en dissociant les liaisons de zéro sur le Tensor operand, et en générant un Tensor result. Met en œuvre l'opération roundToIntegralTiesToAway à partir de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Autres exemples

round_nearest_even

Sémantique

Effectue un arrondi par élément par élément vers l'entier le plus proche, en rompant les liens avec l'entier pair sur le Tensor operand et génère un Tensor result. Met en œuvre l'opération roundToIntegralTiesToEven à partir de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(round_nearest_even, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Autres exemples

rigueur

Sémantique

Effectue une opération de racine carrée réciproque par élément sur le Tensor operand et produit un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: rSqrt à partir de la norme IEEE-754.
  • Pour les nombres complexes: racine carrée réciproque complexe.
  • Pour les types quantifiés: dequantize_op_quantize(rsqrt, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Autres exemples

scatter

Sémantique

Génère des Tensors results, qui sont égaux aux Tensors inputs, à la différence que plusieurs tranches spécifiées par scatter_indices sont mises à jour avec les valeurs updates à l'aide de update_computation.

Le schéma suivant montre comment les éléments de updates... sont mappés sur des éléments de results... à l'aide d'un exemple concret. Le schéma choisit quelques exemples d'indices updates... et explique en détail à quels indices results... ils correspondent.

Plus formellement, pour tous les update_index de index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index est défini comme suit :
    • scatter_indices[si0, ..., :, ..., siN], où les si sont des éléments individuels dans update_scatter_index, et : est inséré au niveau de l'index index_vector_dim, si index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] dans les autres cas.
  • Pour d_input dans axes(inputs[0]) :
    • full_start_index[d_input] = start_index[d_start] si d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 dans les autres cas.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], où wi sont des éléments individuels dans update_window_index, et 0 est inséré aux index de inserted_window_dims.
  • result_index = full_start_index + full_window_index.

Compte tenu de cela, results = exec(schedule, inputs), où:

  • schedule est une permutation de index_space(updates[0]) définie par l'implémentation.
  • exec([update_index, ...], results) = exec([...], updated_results) où :
    • Si result_index est compris dans les limites de shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results est une copie de results avec results...[result_index] défini sur updated_values....
    • Sinon, procédez comme suit :
    • updated_results = results.
  • exec([], results) = results.

Si indices_are_sorted est défini sur true, l'implémentation peut supposer que les scatter_indices sont triés par rapport à scatter_dims_to_operand_dims. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Si unique_indices est défini sur true, l'implémentation peut supposer que tous les indices de dispersion de result_index sont uniques. Si unique_indices est défini sur true, mais que les indices de dispersion ne sont pas uniques, le comportement n'est pas défini.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre varié de Tensors ou Tensors quantifiés par Tensor (C1), (C2), (C4-C6), (C10), (C13), (C15-C16)
(I2) scatter_indices Tensor de type entier (C4), (C11), (C14)
(I3). updates nombre varié de Tensors ou Tensors quantifiés par Tensor (C3-C6), (C8)
(I4). update_window_dims Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C7), (C8)
(I5). inserted_window_dims Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C9), (C10)
(I6). scatter_dims_to_operand_dims Constante de Tensor unidimensionnelle de type si64 (C11-C13)
(I7) index_vector_dim constante de type si64 (C4), (C11), (C14)
(I8). indices_are_sorted constante de type i1
(I9). unique_indices constante de type i1
(I10). update_computation function (C15)

Sorties

Nom Type Contraintes
results nombre varié de Tensors ou Tensors quantifiés par Tensor (C15-C17)

Contraintes

  • (C1) same(shape(inputs...)).
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) où :
    • update_scatter_dim_sizes = shape(scatter_indices), si ce n'est que la taille de la dimension de scatter_indices correspondant à index_vector_dim n'est pas incluse.
    • update_window_dim_sizes <= shape(inputs[0]), si ce n'est que les dimensions des dimensions dans inputs[0] correspondant à inserted_window_dims ne sont pas incluses.
    • combine place update_scatter_dim_sizes sur les axes correspondant à update_scatter_dims et update_window_dim_sizes aux axes correspondant à update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(inserted_window_dims) and is_sorted(update_window_dims).
  • (C10) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C11) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C12) is_unique(scatter_dims_to_operand_dims).
  • (C13) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C14) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C15) update_computation est de type (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), où is_promotable(element_type(inputs[i]), Ei).
  • (C16) shape(inputs...) = shape(results...).
  • (C17) element_type(results[i]) = Ei pour tous les i dans [0,N).

Exemples

// %input: [
//          [[1, 2], [3, 4], [5, 6], [7, 8]],
//          [[9, 10], [11, 12], [13, 14], [15, 16]],
//          [[17, 18], [19, 20], [21, 22], [23, 24]]
//         ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [2, 3],
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
//           [[1, 2], [5, 6], [7, 8], [7, 8]],
//           [[10, 11], [12, 13], [14, 15], [16, 17]],
//           [[18, 19], [20, 21], [21, 22], [23, 24]]
//          ]

Autres exemples

select

Sémantique

Génère un Tensor result, où chaque élément est sélectionné à partir du Tensor on_true ou on_false en fonction de la valeur de l'élément correspondant de pred. Plus formellement, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], où pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Pour les types quantifiés, exécute dequantize_select_quantize(pred, on_true, on_false, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) pred Tensor de type i1 (C1)
(I2) on_true Tensor ou Tensor quantifié par Tensor (C1-C2)
(I3). on_false Tensor ou Tensor quantifié par Tensor (C2)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C2)

Contraintes

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Exemples

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

Autres exemples

select_and_scatter

Sémantique

Disperse les valeurs du Tensor source à l'aide de scatter en fonction du résultat du reduce_window du Tensor input à l'aide de select et génère un Tensor result.

Le schéma suivant montre comment les éléments de result sont calculés à partir de operand et source à l'aide d'un exemple concret.

Plus formellement:

  • selected_values = reduce_window_without_init(...) avec les entrées suivantes:

    • `inputs = [opérande].
    • window_dimensions, window_strides et padding, qui sont utilisés tels quels.
    • base_dilations = windows_dilations = 1.
    • body est défini comme suit:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    E = element_type(operand) et reduce_window_without_init fonctionnent exactement comme reduce_window, sauf que le schedule du reduce sous-jacent (voir Réduire) n'inclut pas de valeurs d'initialisation. Actuellement, ce qui se passe si la fenêtre correspondante ne possède pas de valeur n'est pas spécifiée (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) où:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index si selected_values[source_index] contient l'élément operand de operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1-C4), (C6), (C8-C11)
(I2) source Tensor ou Tensor quantifié par Tensor (C1), (C2)
(I3). init_value Tensor à 0 dimension ou Tensor quantifié par Tensor (C3)
(I4). window_dimensions Constante de Tensor unidimensionnelle de type si64 (C2), (C4), (C5)
(I5). window_strides Constante de Tensor unidimensionnelle de type si64 (C2), (C6), (C7)
(I6). padding Constante de Tensor bidimensionnelle de type si64 (C2), (C8)
(I7) select function (C9)
(I8). scatter function (C10)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C11-C12)

Contraintes

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows où :
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select est de type (tensor<E>, tensor<E>) -> tensor<i1>, où E = element_type(operand).
  • (C10) scatter est de type (tensor<E>, tensor<E>) -> tensor<E>, où is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Exemples

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Autres exemples

envoyer

Sémantique

Envoie inputs à un canal channel_id et génère un jeton result.

Si is_host_transfer est défini sur true, l'opération transfère les données à l'hôte. Sinon, les données sont transférées vers un autre appareil. Cela signifie que l'implémentation est définie. Cet indicateur duplique les informations fournies dans channel_type. Nous prévoyons donc de n'en conserver qu'un seul à l'avenir (#666).

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre varié de Tensors ou de Tensors quantifiés
(I2) token token
(I3). channel_id constante de type si64
(I4). channel_type énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST (C1)
(I5). is_host_transfer constante de type i1 (C1)

Sorties

Nom Type
result token

Contraintes

  • (C1) channel_type est défini comme suit :
    • DEVICE_TO_HOST si is_host_transfer = true,
    • DEVICE_TO_DEVICE dans les autres cas.

Exemples

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

Autres exemples

shift_left

Sémantique

Effectue un décalage à gauche par élément sur le Tensor lhs en fonction d'un nombre de bits rhs et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier (C1)
(I2) rhs Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

Autres exemples

shift_right_arithmetic

Sémantique

Effectue un décalage arithmétique à droite par élément sur le Tensor lhs en fonction d'un nombre de bits de rhs, et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier (C1)
(I2) rhs Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

Autres exemples

shift_right_logical

Sémantique

Effectue un décalage logique élémentaire vers la droite sur le Tensor lhs en fonction d'un nombre de bits rhs et génère un Tensor result.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier (C1)
(I2) rhs Tensor de type entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Autres exemples

signe "=".

Sémantique

Renvoie le signe de operand par élément et génère un Tensor result. Plus formellement, pour chaque élément x, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Pour les types quantifiés, exécute dequantize_op_quantize(sign, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor d'un entier signé, d'un type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor d'un entier signé, d'un type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Autres exemples

sinus

Sémantique

Effectue une opération sinus par élément sur le Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: sin à partir de la norme IEEE-754.
  • Pour les nombres complexes: sinus complexe.
  • Pour les types quantifiés: dequantize_op_quantize(sine, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

Autres exemples

tranche

Sémantique

Extrait une tranche de operand à l'aide d'indices de départ calculés de manière statique et génère un Tensor result. start_indices contient les index de départ de la tranche pour chaque dimension, limit_indices contient les index de fin (exclusifs) de la tranche pour chaque dimension, et strides contient les progrès de chaque dimension.

Plus formellement, result[result_index] = operand[operand_index], où operand_index = start_indices + result_index * strides.

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié par Tensor (C1-C3), (C5)
(I2) start_indices Constante de Tensor unidimensionnelle de type si64 (C2), (C3), (C5)
(I3). limit_indices Constante de Tensor unidimensionnelle de type si64 (C2), (C3), (C5)
(I4). strides Constante de Tensor unidimensionnelle de type si64 (C2), (C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié par Tensor (C1), (C5)

Contraintes

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Exemples

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = dense<[1, 2]> : tensor<2xi64>,
  limit_indices = dense<[3, 4]> : tensor<2xi64>,
  strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Autres exemples

sort

Sémantique

Trie les tranches unidimensionnelles de inputs le long de la dimension dimension en fonction d'une comparator et génère results.

Contrairement aux entrées similaires dans d'autres opérations, dimension autorise les valeurs négatives, avec la sémantique décrite ci-dessous. À l'avenir, cette action peut être interdite pour des raisons de cohérence (#1377).

Si is_stable est "true", le tri est stable, c'est-à-dire que l'ordre relatif des éléments considérés comme égaux par le comparateur est conservé. Dans le cas où il y a une seule entrée, les deux éléments e1 et e2 sont considérés comme égaux par le comparateur si et seulement si comparator(e1, e2) = comparator(e2, e1) = false. Reportez-vous à la formalisation ci-dessous pour découvrir comment cela se généralise à plusieurs entrées.

Plus formellement, pour tous les result_index de index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1], où riN sont des éléments individuels dans result_index, et : est inséré à adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • sort trie une tranche unidimensionnelle dans l'ordre non décroissant en s'attendant à ce que comparator_together renvoie true si l'argument de gauche est inférieur au deuxième argument de droite.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Entrées

Libellé Nom Type Contraintes
(I1) inputs nombre varié de Tensors ou Tensors quantifiés par Tensor (C1-C5)
(I2) dimension constante de type si64 (C4)
(I3). is_stable constante de type i1
(I4). comparator function (C5)

Sorties

Nom Type Contraintes
results nombre varié de Tensors ou Tensors quantifiés par Tensor (C2), (C3)

Contraintes

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, où R = rank(inputs[0]).
  • (C5) comparator est de type (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, où Ei = element_type(inputs[i]).

Exemples

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Autres exemples

sqrt

Sémantique

Effectue une opération de racine carrée élément par élément sur le Tensor operand et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: squareRoot à partir de la norme IEEE-754.
  • Pour les nombres complexes: racine carrée complexe.
  • Pour les types quantifiés: dequantize_op_quantize(sqrt, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Autres exemples

subtract

Sémantique

Effectue une soustraction par élément de deux Tensors lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les nombres entiers: soustraction d'entiers.
  • Pour les flottants: subtraction à partir de la norme IEEE-754.
  • Pour les nombres complexes: soustraction complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)
(I2) rhs Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemples

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Autres exemples

Tanh

Sémantique

Effectue une opération de tangente hyperbolique par élément sur le Tensor operand et produit un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les flottants: tanh à partir de la norme IEEE-754.
  • Pour les nombres complexes: tangente hyperbolique complexe.
  • Pour les types quantifiés :
    • dequantize_op_quantize(tanh, operand, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_type(operand) = baseline_type(result).

Exemples

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

Autres exemples

transposer

Sémantique

Permute les dimensions du Tensor operand à l'aide de permutation et génère un Tensor result. Plus formellement, result[result_index] = operand[operand_index], où result_index[d] = operand_index[permutation[d]].

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor ou Tensor quantifié (C1-C4)
(I2) permutation Constante de Tensor unidimensionnelle de type si64 (C2-C4)

Sorties

Nom Type Contraintes
result Tensor ou Tensor quantifié (C1), (C3-C4)

Contraintes

  • (C1) element_type(result) est donné par :
    • element_type(operand), si !is_per_axis_quantized(operand).
    • element_type(operand), si ce n'est que quantization_dimension(operand) et quantization_dimension(result) peuvent différer.
  • (C2) permutation est une permutation de range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Si la valeur est is_per_axis_quantized(result), alors quantization_dimension(operand) = permutation(quantization_dimension(result)).

Exemples

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Autres exemples

triangular_solve

Sémantique

Résoudre des lots de systèmes d'équations linéaires avec des matrices triangulaires inférieures ou supérieures.

Plus formellement, avec a et b, result[i0, ..., iR-3, :, :] est la solution pour op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] lorsque left_side est true ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] lorsque left_side est false, et résout la variable xop(a) est déterminé par transpose_a, qui peut être l'un des éléments suivants:

  • NO_TRANSPOSE: effectuer l'opération en utilisant a tel quel.
  • TRANSPOSE: effectuer une opération sur la transposition de a.
  • ADJOINT: effectuer une opération sur la transposition conjuguée de a.

Les données d'entrée sont lues uniquement à partir du triangle inférieur de a. Sinon, lower correspond à true ou au triangle supérieur de a. Les données de sortie sont renvoyées dans le même triangle. Les valeurs de l'autre triangle sont définies par l'implémentation.

Si unit_diagonal est "true", l'implémentation peut supposer que les éléments diagonaux de a sont égaux à 1. Sinon, le comportement n'est pas défini.

Pour les types quantifiés, exécute dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) a Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1-C3)
(I2) b Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1-C4)
(I3). left_side constante de type i1 (C3)
(I4). lower constante de type i1
(I5). unit_diagonal constante de type i1
(I6). transpose_a énumération de NO_TRANSPOSE, TRANSPOSE et ADJOINT

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor (C1)

Contraintes

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) La relation entre shape(a) et shape(b) est définie comme suit :
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Exemples

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Sémantique

Génère un tuple result à partir des valeurs val.

Entrées

Libellé Nom Type Contraintes
(I1) val nombre varié de valeurs (C1)

Sorties

Nom Type Contraintes
result tuple (C1)

Contraintes

  • (C1) result est de type tuple<E0, ..., EN-1>, où Ei = type(val[i]).

Exemples

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Autres exemples

uniform_dequantize

Sémantique

Effectue la conversion par élément du Tensor quantifié operand en Tensor à virgule flottante result en fonction des paramètres de quantification définis par le type operand.

Plus formellement : result = dequantize(operand).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor quantifié (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor de type à virgule flottante (C1), (C2)

Contraintes

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Exemples

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Sémantique

Effectue la conversion par élément d'un Tensor à virgule flottante ou d'un Tensor quantifié operand en un Tensor quantifié result en fonction des paramètres de quantification définis par le type result.

Plus formellement,

  • Si is_float(operand) :
    • result = quantize(operand, type(result)).
  • Si is_quantized(operand) :
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Entrées

Libellé Nom Type Contraintes
(I1) operand Tensor de type à virgule flottante ou quantifié (C1), (C2)

Sorties

Nom Type Contraintes
result Tensor quantifié (C1), (C2)

Contraintes

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Exemples

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

alors que

Sémantique

Génère le résultat de l'exécution de la fonction body zéro fois ou plus, tandis que la fonction cond génère true. Plus formellement, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Le comportement d'une boucle infinie est à déterminer (#383).

Entrées

Libellé Nom Type Contraintes
(I1) operand nombre varié de Tensors, de Tensors quantifiés ou de jetons (C1-C3)
(I2) cond function (C1)
(I3). body function (C2)

Sorties

Nom Type Contraintes
results nombre varié de Tensors, de Tensors quantifiés ou de jetons (C3)

Contraintes

  • (C1) cond est de type (T0, ..., TN-1) -> tensor<i1>, où Ti = type(operand[i]).
  • (C2) body est de type (T0, ..., TN-1) -> (T0, ..., TN-1), où Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Exemples

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Autres exemples

Xor

Sémantique

Effectue l'opération XOR par élément des deux Tensors lhs et rhs, et génère un Tensor result. Selon le type d'élément, effectue les opérations suivantes:

  • Pour les valeurs booléennes: logique XOR.
  • Pour les entiers: opération XOR au niveau du bit.

Entrées

Libellé Nom Type Contraintes
(I1) lhs Tensor de type booléen ou entier (C1)
(I2) rhs Tensor de type booléen ou entier (C1)

Sorties

Nom Type Contraintes
result Tensor de type booléen ou entier (C1)

Contraintes

  • (C1) type(lhs) = type(rhs) = type(result).

Exemples

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Exécution

Exécution séquentielle

Pour exécuter un programme StableHLO, il faut fournir les valeurs d'entrée à la fonction main et calculer les valeurs de sortie. Les valeurs de sortie d'une fonction sont calculées en exécutant le graphe des opérations en mode root dans l'opération return correspondante.

L'ordre d'exécution est défini par l'implémentation, à condition qu'il soit aligné sur Dataflow, c'est-à-dire si les opérations sont exécutées avant leur utilisation. Dans StableHLO, toutes les opérations ayant des effets secondaires consomment un seul jeton (plusieurs jetons peuvent être multiplexés en un seul jeton via after_all). L'ordre d'exécution des effets secondaires est donc également aligné sur Dataflow. Les ordres d'exécution possibles de l'exemple de programme ci-dessus sont %0%1%2%3%4return ou %3%0%1%2%4return.

Plus formellement, un processus StableHLO est une combinaison des éléments suivants : 1) un programme StableHLO, 2) des états d'opération (pas encore exécutés, déjà exécutés) et 3) des valeurs intermédiaires sur lesquelles le processus travaille. Le processus commence par les valeurs d'entrée dans la fonction main, passe par le graphique des opérations mettant à jour les états des opérations et les valeurs intermédiaires, et se termine par les valeurs de sortie. D'autres formalisations sont à déterminer (#484).

Exécution parallèle

Les programmes StableHLO peuvent être exécutés en parallèle et sont organisés dans une grille de processus 2D de num_replicas par num_partitions, de type ui32.

Dans la grille de processus StableHLO, num_replicas * num_partitions des processus StableHLO s'exécutent en même temps. Chaque processus possède un process_id = (replica_id, partition_id) unique, où replica_id dans replica_ids = range(num_replicas) et partition_id dans partition_ids = range(num_partitions), qui sont tous deux de type ui32.

La taille de la grille de processus est connue de manière statique pour chaque programme (nous prévoyons d'en faire une partie explicite à l'avenir des programmes StableHLO n° 650), et la position dans la grille de processus est connue de manière statique pour chaque processus. Chaque processus a accès à sa position dans la grille de processus via les opérations replica_id et partition_id.

Dans la grille de processus, les programmes peuvent tous être identiques (dans le style "Programme unique, données multiples") ou différents (dans le style "Programme multiple, données multiples") ou intermédiaire. À l'avenir, nous prévoyons de prendre en charge d'autres idiomes permettant de définir des programmes StableHLO parallèles, y compris GSPMD (#619).

Au sein de la grille de processus, les processus sont pour la plupart indépendants les uns des autres. Ils ont des états d'opération distincts, des valeurs d'entrée/intermédiaire/sortie distinctes et la plupart des opérations sont exécutées séparément entre les processus, à l'exception d'un petit nombre d'opérations collectives décrites ci-dessous.

Étant donné que l'exécution de la plupart des opérations n'utilise que des valeurs du même processus, il est généralement clair de faire référence à ces valeurs par leur nom. Toutefois, lorsque vous décrivez la sémantique des opérations collectives, cela est insuffisant, et cela génère la notation name@process_id pour faire référence à la valeur name dans un processus particulier. (De ce point de vue, un name non qualifié peut être considéré comme un raccourci pour name@(replica_id(), partition_id()).)

L'ordre d'exécution des processus est défini par l'implémentation, à l'exception de la synchronisation introduite par la communication point à point et les opérations collectives, comme décrit ci-dessous.

Communication point à point

Les processus StableHLO peuvent communiquer entre eux via des canaux StableHLO. Un canal est représenté par un ID positif de type si64. Via différentes opérations, il est possible d'envoyer des valeurs aux canaux et de les recevoir de ces canaux.

Une formalisation supplémentaire (par exemple, la provenance de ces ID de canaux, la manière dont les processus les programmes prennent en compte et le type de synchronisation qu'ils introduit) est à déterminer (#484).

Communication en flux continu

Chaque processus StableHLO a accès à deux interfaces de streaming:

  • Infeed (Flux d'entrée) pouvant être lu.
  • OutFeed sur lequel une écriture est possible.

Contrairement aux canaux, qui sont utilisés pour communiquer entre les processus et qui ont donc des processus à leurs deux extrémités, l'implémentation des autres extrémités des flux d'entrée et de sortie est définie.

Une formalisation supplémentaire, par exemple la manière dont la communication par flux influence l'ordre d'exécution et le type de synchronisation qu'elle introduit, reste à déterminer (#484).

Opérations collectives

Il existe six opérations collectives dans StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute et reduce_scatter. Toutes ces opérations divisent les processus de la grille de processus StableHLO en groupes de processus StableHLO et exécutent un calcul conjoint dans chaque groupe de processus, indépendamment des autres groupes de processus.

Au sein de chaque groupe de processus, les opérations collectives peuvent introduire une barrière de synchronisation. D'autres formalisations, telles que le moment exact de cette synchronisation, la manière exacte dont les processus arrivent à cet obstacle et ce qui se passe s'ils ne la rencontrent pas, sont à déterminer (#484).

Si le groupe de processus implique une communication entre partitions, c'est-à-dire s'il existe des processus dans le groupe de processus dont les ID de partition sont différents, l'exécution de l'opération collective a besoin d'un canal, et l'opération collective doit fournir une valeur channel_id positive de type si64. La communication entre les instances répliquées n'a pas besoin de canaux.

Les calculs effectués par les opérations collectives sont spécifiques à des opérations individuelles et sont décrits dans les sections correspondantes ci-dessus. Toutefois, les stratégies par lesquelles la grille de processus est divisée en groupes de processus sont partagées entre ces opérations et sont décrites dans cette section. Plus formellement, StableHLO prend en charge les quatre stratégies suivantes.

cross_replica

Seules les communications entre instances répliquées ont lieu au sein de chaque groupe de processus. Cette stratégie utilise replica_groups (une liste de listes d'ID d'instances répliquées) et calcule un produit cartésien de replica_groups par partition_ids. replica_groups doit comporter des éléments uniques et couvrir tous les replica_ids. Plus formellement, en utilisant la syntaxe Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica produira [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Seules les communications entre partitions ont lieu dans chaque groupe de processus. Cette stratégie utilise partition_groups (une liste de listes d'ID de partition) et calcule un produit cartésien de partition_groups par replica_ids. partition_groups doit comporter des éléments uniques et couvrir l'intégralité des partition_ids. Plus formellement, en utilisant la syntaxe Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Par exemple, pour partition_groups = [[0, 1]] et num_replicas = 4, cross_partition produira [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

Des communications entre instances répliquées et entre partitions peuvent avoir lieu au sein de chaque groupe de processus. Cette stratégie utilise replica_groups (une liste d'ID d'instances répliquées) et calcule les produits cartésiens de chaque replica_group par partition_ids. replica_groups doit comporter des éléments uniques et couvrir tous les replica_ids. Plus formellement, en utilisant la syntaxe Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica_and_partition produira [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Cette stratégie utilise flattened_id_groups (une liste de listes d'ID de processus "aplatis" sous la forme replica_id * num_partitions + partition_id) et les transforme en ID de processus. flattened_id_groups doit comporter des éléments uniques et couvrir tous les process_ids. Plus formellement, en utilisant la syntaxe Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Par exemple, pour flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 et num_partitions = 2, flattened_ids produira [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Justesse

Pour le moment, StableHLO n'offre aucune garantie de précision numérique, mais cela est susceptible de changer à l'avenir (#1156).

Erreurs

Les programmes StableHLO sont validés par un ensemble complet de contraintes pour les opérations individuelles, ce qui permet d'exclure de nombreuses classes d'erreurs avant l'exécution. Toutefois, des conditions d'erreur restent possibles, par exemple via des dépassements d'entiers, des accès hors limites, etc. Sauf indication contraire, toutes ces erreurs entraînent un comportement défini par l'implémentation, mais cela est susceptible de changer à l'avenir (#1157).

À l'exception de cette règle, les exceptions à virgule flottante dans les programmes StableHLO ont un comportement bien défini. Les opérations qui entraînent des exceptions définies par la norme IEEE-754 (opération non valide, division par zéro, dépassement, dépassement de capacité négatif ou exceptions inexactes) génèrent des résultats par défaut (tels que définis dans la norme) et poursuivent l'exécution sans générer l'indicateur d'état correspondant (comme pour la gestion des exceptions raiseNoFlag de la norme). Les exceptions pour les opérations non standards (par exemple, les calculs arithmétiques complexes et certaines fonctions transcendantes) sont définies par l'implémentation.

Notation

Pour décrire la syntaxe, ce document utilise le type de syntaxe ISO modifié de la syntaxe EBNF (ISO/IEC 14977:1996, Wikipédia). Deux modifications sont apportées: 1) les règles sont définies à l'aide de ::= au lieu de =,

2) La concaténation est exprimée à l'aide de la juxtaposition plutôt que de ,.

Pour décrire la sémantique (c'est-à-dire dans les sections "Types", "Constantes" et "Opérations"), nous utilisons des formules basées sur la syntaxe Python étendue avec prise en charge de l'expression concise des opérations de tableau, comme décrit ci-dessous. Cela fonctionne bien pour les petits extraits de code, mais dans les rares cas où de plus grands extraits de code sont nécessaires, nous utilisons la syntaxe Python vanille, qui est toujours introduite explicitement.

Formules

Découvrons le fonctionnement des formules en nous basant sur un exemple tiré de la spécification dot_general. L'une des contraintes pour cette opération est la suivante : dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Les noms utilisés dans cette formule proviennent de deux sources: 1) les fonctions globales (par exemple, dim) ; 2) les définitions des membres de l'élément de programme correspondant (c'est-à-dire les entrées lhs, lhs_batching_dimensions, rhs et rhs_batching_dimensions) définies dans la section "Entrées" de dot_general.

Comme indiqué ci-dessus, la syntaxe de cette formule est basée sur Python avec certaines extensions orientées vers la concision. Pour donner un sens à la formule, transformons-la en syntaxe Python vanille.

A) Dans ces formules, nous utilisons = pour représenter l'égalité. La première étape pour obtenir la syntaxe Python consiste donc à remplacer = par ==, comme suit : dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) De plus, ces formules acceptent les points de suspension (...), qui transforment les expressions scalaires en expressions de Tensor. En bref, f(xs...) signifie à peu près "pour chaque x scalaire du Tensor xs, calculer un f(x) scalaire, puis renvoyer tous ces résultats scalaires sous forme de résultat de Tensor". En syntaxe Python vanille, notre exemple de formule se transforme en : [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Grâce aux ellipses, il est souvent possible d'éviter de travailler au niveau des scalaires individuels. Toutefois, dans certains cas délicats, une syntaxe semi-informelle de niveau inférieur peut être utilisée comme dans la formule start_indices[bi0, ..., :, ..., biN] de la spécification gather. Par souci de concision, nous ne fournissons pas de formalisme exact pour traduire cette syntaxe en Python vanille, en espérant qu'elle reste intuitivement compréhensible au cas par cas. Veuillez nous indiquer si certaines formules spécifiques semblent opaques et nous essaierons de les améliorer.

Vous remarquerez également que les formules utilisent des points de suspension pour développer toutes sortes de listes, y compris les Tensors, les listes de Tensors (par exemple, qui peuvent provenir d'un nombre variable de Tensors), etc. Il s'agit d'un autre domaine dans lequel nous ne fournissons pas de formalité exacte (par exemple, les listes ne font même pas partie du système de types StableHLO) et s'appuient plutôt sur la compréhension intuitive.

C) Le dernier moyen de notation notable que nous utilisons est la diffusion implicite. Bien que l'opération StableHLO ne soit pas compatible avec la diffusion implicite, les formules le sont également, à des fins de concision. En résumé, si un scalaire est utilisé dans un contexte où un Tensor est attendu, il est diffusé vers la forme attendue.

Pour continuer l'exemple dot_general, voici une autre contrainte : 0 <= lhs_batching_dimensions < rank(lhs). Comme défini dans la spécification dot_general, lhs_batching_dimensions est un Tensor, mais 0 et rank(lhs) sont tous deux scalaires. Une fois la diffusion implicite appliquée, la formule devient [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Lorsqu'elle est appliquée à une opération dot_general particulière, cette formule sera évaluée par rapport à un Tensor de valeurs booléennes. Lorsque des formules sont utilisées comme contraintes, la contrainte est valable si la formule est évaluée sur true ou sur un Tensor qui ne contient que des éléments true.

Noms

Dans les formules, la portée lexicale comprend: 1) les fonctions globales, 2) les définitions de membre,

3) définitions locales. La liste des fonctions globales est fournie ci-dessous. La liste des définitions d'éléments dépend de l'élément du programme auquel la notation est appliquée:

  • Pour les opérations, les définitions de membre incluent les noms introduits dans les sections "Entrées" et "Sorties".
  • Pour tout le reste, les définitions de membre incluent les parties structurelles de l'élément de programme, nommées d'après les non-terminaux EBNF correspondants. La plupart du temps, les noms de ces parties structurelles sont obtenus en convertissant les noms des non-terminaux en snake case (par exemple, IntegerLiteral => integer_literal). Toutefois, les noms sont parfois abrégés dans le processus (par exemple, QuantizationStorageType => storage_type). Dans ce cas, les noms sont introduits explicitement de la même manière que dans les sections "Entrées" et "Sorties" dans les opérations.
  • En outre, les définitions de membre incluent toujours self pour faire référence à l'élément de programme correspondant.

Valeurs

Lorsque les formules sont évaluées, elles fonctionnent avec les types de valeurs suivants : 1) Value (valeurs réelles, par exemple dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>) ; 2) Placeholder (valeurs futures, par exemple lhs, rhs ou result ; leurs valeurs réelles ne sont pas encore connues, seuls leurs types sont connus), 3) Type (types tels que définis dans la section "Types") ; 4) Function (fonctions globales telles que définies dans la section "Fonctions globales").

Selon le contexte, les noms peuvent faire référence à des valeurs différentes. Plus précisément, la section "Sémantique" des opérations (et des équivalents pour d'autres éléments du programme) définit la logique d'exécution, de sorte que toutes les entrées sont disponibles en tant que Value. En revanche, la section "Contraintes" des opérations (et des équivalents) définit une logique de "compilation", c'est-à-dire un élément qui est généralement exécuté avant l'exécution. Par conséquent, seules les entrées constantes sont disponibles en tant que Value et les autres entrées ne sont disponibles qu'en tant que Placeholder.

Noms Dans "Sémantique" Dans "Contraintes"
Fonctions globales Function Function
Entrées constantes Value Value
Entrées non constantes Value Placeholder
Sorties Value Placeholder
Définitions locales Dépend de la définition Dépend de la définition

Prenons un exemple d'opération transpose:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Pour cette opération, permutation est une constante. Il est donc disponible en tant que Value à la fois dans la sémantique et dans les contraintes. En revanche, operand et result sont disponibles en tant que Value dans la sémantique, mais uniquement en tant que Placeholder dans les contraintes.

Fonctions

Construction de types

Aucune fonction ne peut être utilisée pour construire des types. Nous utilisons directement la syntaxe de type, car elle est généralement plus concise. Par exemple, (tensor<E>, tensor<E>) -> (tensor<E>) au lieu de function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Fonctions sur les types

  • element_type est défini sur les types de Tensors et quantifiés, et renvoie respectivement la partie TensorElementType ou QuantizedTensorElementType du TensorType ou du QuantizedTensorType correspondant.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool vérifie si le type x peut être promu au type y. Lorsque x et y correspondent à des QuantizedTensorElementType, la promotion n'est appliquée qu'au storage_type. Cette version spécifique de la promotion est actuellement utilisée dans le contexte du calcul de la réduction (consultez le document RFC pour plus d'informations).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value est un raccourci pour is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Disponible pour tous les types. Par exemple, is_float(x) renvoie true si x est un FloatType. Si x est une valeur ou un espace réservé, cette fonction est un raccourci pour is_type_name(type(x)).

  • max_value(x: Type) -> Value renvoie la valeur maximale d'un TensorElementType. Si x n'est pas une TensorElementType, renvoie None.

  • min_value(x: Type) -> Value renvoie la valeur minimale possible d'un TensorElementType. Si x n'est pas une TensorElementType, renvoie None.

  • member_name(x: Value | Placeholder | Type) -> Any. Disponible pour toutes les définitions de membre member_name de tous types. Par exemple, tensor_element_type(x) renvoie la partie TensorElementType d'un objet TensorType correspondant. Si x est une valeur ou un espace réservé, cette fonction est un raccourci pour member_name(type(x)). Si x n'est pas un type possédant un membre approprié, ou une valeur ou un espace réservé de ce type, renvoie None.

Construction de valeurs

  • operation_name(*xs: Value | Type) -> Value. Disponible pour toutes les opérations. Par exemple, add(lhs, rhs) prend deux valeurs de Tensor lhs et rhs, et renvoie la sortie de l'évaluation de l'opération add avec ces entrées. Pour certaines opérations comme broadcast_in_dim, les types de sorties sont "portants", c'est-à-dire nécessaires pour évaluer une opération. Dans ce cas, la fonction utilise ces types comme arguments.

Fonction sur les valeurs

  • Tous les opérateurs et fonctions Python sont disponibles. Par exemple, les notations d'abonnement et de tranchement de Python peuvent être indexées dans des Tensors, des Tensors quantifiés et des tuples.

  • to_destination_type(x: Value, destination_type: Type) -> Value est défini sur des Tensors et renvoie la valeur convertie de x en fonction des type(x) et destination_type, comme suit:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Une discussion préliminaire a été abordée sur la fusion des opérations convert, uniform_quantize et uniform_dequantize (#1576). Après la fusion, nous n'avons plus besoin de la fonction ci-dessus. Nous pouvons utiliser le nom de l'opération pour convert à la place.

  • is_nan(x: Value) -> Value est défini sur les Tensors et renvoie true si tous les éléments de x sont NaN ou false dans les autres cas. Si x n'est pas un Tensor, elle renvoie None.

  • is_sorted(x: Value) -> Value est défini sur les Tensors et renvoie true si les éléments de x sont triés par ordre croissant par rapport à l'ordre lexicographique croissant de leurs indices, ou false dans le cas contraire. Si x n'est pas un Tensor, la fonction renvoie None.

  • is_unique(x: Value) -> Value est défini sur des Tensors et renvoie true si x n'a pas d'éléments en double, ou false dans le cas contraire. Si x n'est pas un Tensor, elle renvoie None.

  • member_name(x: Value) -> Any est défini pour toutes les définitions de membre member_name de l'ensemble des valeurs. Par exemple, real_part(x) renvoie la partie RealPart d'un objet ComplexConstant correspondant. Si x n'est pas une valeur associée à un membre approprié, elle renvoie None.

  • same(x: Value) -> Value est défini sur les Tensors et renvoie true si les éléments de x sont tous égaux les uns aux autres, ou false dans le cas contraire. Si le Tensor ne comporte aucun élément, il est comptabilisé comme "tous égaux les uns par rapport aux autres ", c'est-à-dire que la fonction renvoie true. Si x n'est pas un Tensor, la fonction renvoie None.

  • split(x: Value, num_results: Value, axis: Value) -> Value est défini sur des Tensors et renvoie des tranches num_results de x le long de l'axe axis. Si x n'est pas un Tensor ou dim(x, axis) % num_results != 0, renvoie None.

Calculs de formes

  • axes(x: Value | Placeholder | Type) -> Value est un raccourci pour range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value est un raccourci pour shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List est un raccourci pour list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value est défini sur des Tensors et renvoie les indices size(x) du TensorType correspondant, trié par ordre lexicographique croissant, par exemple [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Si x n'est pas un type de Tensor, un type de Tensor quantifié, une valeur ou un espace réservé de l'un de ces types, renvoie None.

  • rank(x: Value | Placeholder | Type) -> Value est un raccourci pour size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value est défini dans la section "Fonctions sur les types" via member_name.

  • size(x: Value | Placeholder | Type) -> Value est un raccourci pour reduce(lambda x, y: x * y, shape(x)).

Calculs de quantification

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type est un raccourci pour element_type(baseline_type(x)).

  • baseline_type est défini sur les types de Tensor et les types de Tensor quantifiés, et les transforme en "référence", c'est-à-dire un type ayant la même forme, mais dont les paramètres de quantification du type d'élément sont réinitialisés aux valeurs par défaut. Cette astuce pratique permet de comparer de manière uniforme les types de Tensors et de Tensors quantifiés, ce qui est assez souvent nécessaire. Pour les types quantifiés, cela permet de comparer les types sans tenir compte des paramètres de quantification. Autrement dit, shape, storage_type, expressed_type, storage_min, storage_max et quantization_dimension (pour le type quantifié par axe) doivent tous correspondre, mais scales et zero points peuvent être différents.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize est défini sur des types de Tensor quantifiés et les convertit en types de Tensor à virgule flottante. Pour ce faire, les éléments quantifiés qui représentent des valeurs entières du type de stockage sont convertis en valeurs à virgule flottante correspondantes du type exprimé, à l'aide du point zéro et de l'échelle associées au type d'élément quantifié.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize est défini sur les types de Tensors à virgule flottante et les convertit en types de Tensor quantifiés. Pour ce faire, les valeurs à virgule flottante du type exprimé sont converties en valeurs entières correspondantes du type de stockage à l'aide du point zéro et de l'échelle associées au type d'élément quantifié.
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
  • dequantize_op_quantize permet de spécifier des calculs par élément sur des Tensors quantifiés. Elle déquantifie, c'est-à-dire transforme les éléments quantifiés en types exprimés, effectue une opération, puis quantifie (en d'autres termes, reconvertit les résultats en types de stockage). Pour le moment, cette fonction ne fonctionne que pour la quantification par Tensor. La quantification par axe est en cours (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)

Calculs en mode grille

  • cross_partition(replica_groups: Value) -> Value. Consultez la section "cross_réplique" ci-dessus.

  • cross_replica(replica_groups: Value) -> Value. Consultez la section "cross_réplique" ci-dessus.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Consultez la section "cross_réplique_and_partition" ci-dessus.

  • flattened_ids(replica_groups: Value) -> Value : consultez la section "flattened_ids" ci-dessus.