StableHLO est un ensemble d'opérations pour les opérations de haut niveau (HLO) dans les modèles de machine learning (ML). StableHLO fonctionne comme une couche de portabilité entre différents frameworks de ML et compilateurs de ML : les frameworks de ML qui produisent des programmes StableHLO sont compatibles avec les compilateurs de ML qui consomment des programmes StableHLO.
Notre objectif est de simplifier et d'accélérer le développement du ML en créant une interopérabilité plus poussée entre les différents frameworks de ML (tels que TensorFlow, JAX et PyTorch) et les compilateurs de ML (tels que XLA et IREE). À cette fin, ce document fournit une spécification pour le langage de programmation StableHLO.
Cette spécification comporte trois sections principales. Tout d'abord, la section Programmes décrit la structure des programmes StableHLO, qui consistent en des fonctions StableHLO qui, à leur tour, consistent en des opérations StableHLO. Au sein de cette structure, la section Ops spécifie la sémantique des opérations individuelles. La section Execution fournit une sémantique pour toutes ces opérations exécutées ensemble dans un programme. Enfin, la section Notation décrit la notation utilisée dans l'ensemble de la spécification.
Pour afficher les spécifications d'une version précédente de StableHLO, ouvrez le dépôt à la version taguée de votre choix. Par exemple, la spécification StableHLO v0.19.0. Pour afficher les modifications apportées à chaque version mineure de StableHLO, consultez le journal des versions dans VhloDialect.td.
Programmes
Program ::= {Func}
Les programmes StableHLO se composent d'un nombre arbitraire de fonctions StableHLO.
Vous trouverez ci-dessous un exemple de programme avec une fonction @main
qui comporte trois entrées (%image
, %weights
et %bias
) et une sortie. Le corps de la fonction comporte 6 opérations.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Fonctions
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Les fonctions StableHLO (également appelées fonctions nommées) possèdent un identifiant, des entrées/sorties et un corps. À l'avenir, nous prévoyons d'introduire des métadonnées supplémentaires pour les fonctions afin d'améliorer la compatibilité avec HLO (#425, #626, #740, #744).
Identifiants
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Les identifiants StableHLO sont semblables aux identifiants de nombreux langages de programmation, avec deux particularités : 1) tous les identifiants comportent des sigils qui distinguent les différents types d'identifiants ; 2) les identifiants de valeur peuvent être entièrement numériques pour simplifier la génération de programmes StableHLO.
Types
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Les types StableHLO sont classés en types de valeurs (également appelés types de première classe) qui représentent les valeurs StableHLO et en types non de valeur qui décrivent d'autres éléments de programme. Les types StableHLO sont semblables à ceux de nombreux langages de programmation, la principale particularité étant la nature spécifique au domaine de StableHLO, qui entraîne des résultats inhabituels (par exemple, les types scalaires ne sont pas des types de valeurs).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Les types de Tensor représentent des Tensors, c'est-à-dire des tableaux multidimensionnels. Ils ont une forme et un type d'élément, où une forme représente des tailles de dimension non négatives ou inconnues dans l'ordre croissant des dimensions correspondantes (également appelées axes) numérotées de 0
à R-1
. Le nombre de dimensions R
est appelé classement. Par exemple, tensor<2x3xf32>
est un type de tenseur avec la forme 2x3
et le type d'élément f32
. Il comporte deux dimensions (ou, en d'autres termes, deux axes) : la dimension 0 et la dimension 1, dont les tailles sont respectivement 2 et 3. Son classement est de 2.
Les formes peuvent être partiellement ou totalement inconnues (dynamiques), par exemple, tensor<?x2xf64>
est partiellement inconnu et tensor<?x?xf64>
est totalement inconnu. Les tailles de dimension dynamique sont représentées à l'aide d'un ?
. Les formes ne peuvent pas être désclassées.
À l'avenir, nous prévoyons d'étendre les types de tenseurs au-delà des tailles de dimension et des types d'éléments, par exemple pour inclure les mises en page (numéro 629) et la sparsité (numéro 1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nom | Type | Contraintes |
---|---|---|
storage_type |
type entier | (C1-C3), (C8) |
storage_min |
constante entière | (C1), (C3), (C7) |
storage_max |
constante d'entier | (C2), (C3), (C7) |
expressed_type |
type à virgule flottante | (C4) |
quantization_dimension |
constante facultative (nombre entier) | (C10-C12) |
scales |
nombre variadique de constantes à virgule flottante | (C4-C6), (C9), (C10), (C13) |
zero_points |
nombre variadique de constantes entières | (C7-C9) |
Les types d'éléments quantifiés représentent les valeurs entières d'un type de stockage compris entre storage_min
et storage_max
(inclus) qui correspondent aux valeurs à virgule flottante d'un type exprimé. Pour une valeur entière donnée i
, la valeur à virgule flottante correspondante f
peut être calculée comme f = (i - zero_point) * scale
, où scale
et zero_point
sont appelés paramètres de quantification. Les éléments storage_min
et storage_max
sont facultatifs dans la grammaire, mais leurs valeurs par défaut sont respectivement min_value(storage_type)
et max_value(storage_type)
. Les types d'éléments quantifiés présentent les contraintes suivantes:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Si
is_empty(quantization_dimension)
, alorssize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Pour le moment, QuantizationScale
est une constante à virgule flottante, mais les échelles basées sur des entiers, représentées par des multiplicateurs et des décalages, suscitent un vif intérêt. Nous prévoyons d'explorer cette question prochainement (#1404).
La sémantique de QuantizationZeroPoint
est en cours de discussion, y compris concernant le type et les valeurs, et sur la possibilité de n'avoir qu'un seul ou plusieurs points zéro dans un type de Tensor quantifié. D'après les résultats de cette discussion, les spécifications concernant les points nuls peuvent changer à l'avenir (numéro 1405).
Une autre discussion en cours concerne la sémantique de QuantizationStorageMin
et QuantizationStorageMax
pour déterminer si des contraintes doivent être imposées à ces valeurs et aux valeurs des tenseurs quantifiés (numéro 1406).
Enfin, nous prévoyons d'explorer la représentation d'échelles et de points zéro inconnus, comme nous prévoyons de le faire pour la représentation de tailles de dimension inconnues (numéro 1407).
Les types de tenseurs quantifiés représentent des tenseurs dont les éléments sont quantifiés. Ces Tensors sont exactement les mêmes que les Tensors standards, sauf que leurs éléments ont des types d'éléments quantifiés, et non des types d'éléments standards.
Dans les tenseurs quantifiés, la quantification peut être par Tensor, ce qui signifie qu'elle peut avoir une valeur scale
et zero_point
pour l'ensemble du Tensor ou par axe, ce qui signifie qu'elle peut avoir plusieurs scales
et zero_points
, une paire par tranche d'une dimension quantization_dimension
particulière. Plus formellement, dans un tenseur t
avec quantification par axe, il existe dim(t, quantization_dimension)
tranches de quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
, etc. Tous les éléments de la i
e tranche utilisent scales[i]
et zero_points[i]
comme paramètres de quantification. Les types de Tensors quantifiés présentent les contraintes suivantes:
- Pour la quantification par tenseur :
- Aucune autre contrainte.
- Pour la quantification par axe :
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Les types de jetons représentent des jetons, c'est-à-dire des valeurs opaques produites et consommées par certaines opérations. Les jetons sont utilisés pour imposer l'ordre d'exécution aux opérations, comme décrit dans la section Exécution.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Les types de tuple représentent des tupels, c'est-à-dire des listes hétérogènes. Les tupels sont une ancienne fonctionnalité qui n'existe que pour la compatibilité avec HLO. Dans HLO, les tupels sont utilisés pour représenter les entrées et les sorties variadiques. Dans StableHLO, les entrées et sorties variadiques sont prises en charge en mode natif, et le seul usage des tupels dans StableHLO est de représenter de manière exhaustive l'ABI HLO, par exemple, T
, tuple<T>
et tuple<tuple<T>>
peuvent être sensiblement différents en fonction d'une implémentation particulière. À l'avenir, nous prévoyons d'apporter des modifications à l'ABI HLO, ce qui nous permettra peut-être de supprimer les types de tuple de StableHLO (numéro 598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Les types d'éléments représentent les éléments des types de Tensor. Contrairement à de nombreux langages de programmation, ces types ne sont pas de première classe dans StableHLO. Cela signifie que les programmes StableHLO ne peuvent pas représenter directement les valeurs de ces types (par conséquent, il est courant de représenter les valeurs scalaires de type T
avec des valeurs de tenseur à dimension 0 de type tensor<T>
).
- Le type booléen représente les valeurs booléennes
true
etfalse
. - Les types d'entiers peuvent être signés (
si
) ou non signés (ui
) et avoir l'une des largeurs de bits acceptées (2
,4
,8
,16
,32
ou64
). Les typessiN
signés représentent les valeurs d'entiers de-2^(N-1)
à2^(N-1)-1
inclus, et les typesuiN
non signés représentent les valeurs d'entiers de0
à2^N-1
inclus. - Les types à virgule flottante peuvent être l'un des éléments suivants :
f8E3M4
,f8E4M3
etf8E5M2
: nombres à virgule flottante de 8 bits suivant les conventions IEEE-754.- Types
f8E4M3FN
etf8E5M2
correspondant respectivement aux encodagesE4M3
etE5M2
du format FP8 décrit dans la section Formats FP8 pour l'apprentissage profond. - Types
f8E4M3FNUZ
etf8E5M2FNUZ
correspondant aux encodagesE4M3
etE5M2
des formats FP8 décrits dans la section Formats numériques 8 bits pour les réseaux de neurones profonds. - Type
f8E4M3B11FNUZ
correspondant à l'encodageE4M3
des formats FP8 décrits dans la section Entraînement et inférence à virgule flottante hybride 8 bits (HFP8) pour les réseaux de neurones profonds. - Type de
bf16
correspondant au formatbfloat16
décrit dans la section BFloat16: le secret des hautes performances sur les Cloud TPU. - Types
f16
,f32
etf64
correspondant respectivement aux formatsbinary16
("demi-précision"),binary32
("précision simple") etbinary64
("double précision") décrits dans la norme IEEE 754. - Le type
tf32
correspond au format TensorFloat32 et n'est pas entièrement compatible avec StableHLO. - Types MX (microscaling)
f4E2M1FN
,f6E2M3FN
,f6E3M2FN
etf8E8M0FNU
décrits dans la spécification des formats de microscaling OCP.
- Les types complexes représentent des valeurs complexes ayant une partie réelle et une partie imaginaire du même type d'élément. Les types complexes compatibles sont
complex<f32>
(les deux parties sont de typef32
) etcomplex<f64>
(les deux parties sont de typef64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Les types de fonctions représentent à la fois les fonctions nommées et anonymes. Ils ont des types d'entrée (la liste des types à gauche de ->
) et des types de sortie (la liste des types à droite de ->
). Dans de nombreux langages de programmation, les types de fonction sont de première classe, mais pas dans StableHLO.
StringType ::= 'string'
Le type de chaîne représente des séquences d'octets. Contrairement à de nombreux langages de programmation, le type de chaîne n'est pas de première classe dans StableHLO et n'est utilisé que pour spécifier des métadonnées statiques pour les éléments de programme.
Opérations
Les opérations StableHLO (également appelées opérations) représentent un ensemble fermé d'opérations de haut niveau dans les modèles de machine learning. Comme indiqué ci-dessus, la syntaxe StableHLO est fortement inspirée de MLIR, qui n'est pas nécessairement l'alternative la plus ergonomique, mais qui est sans doute la plus adaptée à l'objectif de StableHLO, qui est de créer une interopérabilité accrue entre les frameworks ML et les compilateurs ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Les opérations StableHLO (également appelées opérations) ont un nom, des entrées/sorties et une signature. Le nom se compose du préfixe stablehlo.
et d'un mnémonique qui identifie de manière unique l'une des opérations compatibles. Vous trouverez ci-dessous la liste complète de toutes les opérations prises en charge.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Les opérations consomment des entrées et produisent des sorties. Les entrées sont classées en valeurs d'entrée (calculées lors de l'exécution), fonctions d'entrée (fournies de manière statique, car dans StableHLO, les fonctions ne sont pas des valeurs de première classe) et attributs d'entrée (également fournis de manière statique). Le type d'entrées et de sorties consommées et produites par une opération dépend de son mnémonique. Par exemple, l'opération add
consomme deux valeurs d'entrée et produit une valeur de sortie. En comparaison, l'opération select_and_scatter
consomme 3 valeurs d'entrée, 2 fonctions d'entrée et 3 attributs d'entrée.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Les fonctions d'entrée (également appelées fonctions anonymes) sont très similaires aux fonctions nommées, à l'exception des points suivants : 1) elles n'ont pas d'identifiant (d'où le nom "anonyme"), 2) elles ne déclarent pas de types de sortie (les types de sortie sont inférés à partir de l'opération return
dans la fonction).
La syntaxe des fonctions d'entrée inclut une partie actuellement inutilisée (voir la production Unused
ci-dessus) qui est là pour la compatibilité avec MLIR. Dans MLIR, il existe un concept plus général de "régions" pouvant comporter plusieurs "blocs" d'opérations reliés entre eux via des jump ops. Ces blocs ont des ID qui correspondent à la production Unused
afin de pouvoir les distinguer les uns des autres.
StableHLO ne comporte pas d'opérations de saut. Par conséquent, la partie correspondante de la syntaxe MLIR n'est pas utilisée (mais elle est toujours là).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Les attributs d'entrée ont un nom et une valeur qui correspond à l'une des constantes acceptées. Il s'agit du principal moyen de spécifier des métadonnées statiques pour les éléments de programme. Par exemple, l'opération concatenate
utilise l'attribut dimension
pour spécifier la dimension via laquelle ses valeurs d'entrée sont concaténées. De même, l'opération slice
utilise plusieurs attributs tels que start_indices
et limit_indices
pour spécifier les limites utilisées pour découper la valeur d'entrée.
Pour le moment, les programmes StableHLO dans la nature contiennent parfois des attributs qui ne sont pas décrits dans ce document. À l'avenir, nous prévoyons d'intégrer ces attributs dans l'ensemble d'opérations StableHLO ou de les interdire dans les programmes StableHLO. En attendant, voici la liste de ces attributs:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- Métadonnées de lieu (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La signature d'opération se compose des types de toutes les valeurs d'entrée (la liste des types à gauche de ->
) et des types de toutes les valeurs de sortie (la liste des types à droite de ->
). À strictement parler, les types d'entrée sont redondants, et les types de sortie sont presque toujours redondants également (car pour la plupart des opérations StableHLO, les types de sortie peuvent être déduits des entrées). Néanmoins, la signature d'opération fait délibérément partie de la syntaxe StableHLO pour assurer la compatibilité avec MLIR.
Vous trouverez ci-dessous un exemple d'opération dont l'expression mnémotechnique est select_and_scatter
. Elle consomme trois valeurs d'entrée (%operand
, %source
et %init_value
), deux fonctions d'entrée et trois attributs d'entrée (window_dimensions
, window_strides
et padding
). Notez que la signature de l'opération n'inclut que les types de ses valeurs d'entrée (mais pas les types de fonctions et d'attributs d'entrée qui sont fournis en ligne).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Les constantes StableHLO comportent un littéral et un type qui, ensemble, représentent une valeur StableHLO. En général, le type fait partie de la syntaxe de la constante, sauf lorsqu'il est sans ambiguïté (par exemple, une constante booléenne a un type i1
sans ambiguïté, tandis qu'une constante entière peut avoir plusieurs types possibles).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Les constantes booléennes représentent les valeurs booléennes true
et false
. Les constantes booléennes sont de type i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Les constantes entières représentent des valeurs entières via des chaînes qui utilisent la notation décimale ou hexadécimale. Les autres bases, telles que les bases binaires ou octales, ne sont pas acceptées. Les constantes entières présentent les contraintes suivantes:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Les constantes à virgule flottante représentent des valeurs à virgule flottante via des chaînes qui utilisent une notation décimale ou scientifique. De plus, la notation hexadécimale peut être utilisée pour spécifier directement les bits sous-jacents au format à virgule flottante du type correspondant. Les constantes à virgule flottante présentent les contraintes suivantes:
- (C1) Si une notation non hexadécimale est utilisée,
is_wellformed(float_literal, float_type)
. - (C2) Si vous utilisez la notation hexadécimale,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Les constantes complexes représentent des valeurs complexes à l'aide de listes d'une partie réelle (qui vient en premier) et d'une partie imaginaire (qui vient en deuxième). Par exemple, (1.0, 0.0) : complex<f32>
représente 1.0 + 0.0i
et (0.0, 1.0) : complex<f32>
représente 0.0 + 1.0i
. L'ordre dans lequel ces parties sont ensuite stockées en mémoire est défini par l'implémentation. Les constantes complexes présentent les contraintes suivantes:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Les constantes de tenseur représentent les valeurs de tenseur à l'aide de listes imbriquées spécifiées via la notation NumPy. Par exemple, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
représente une valeur de Tensor avec le mappage suivant entre les index et les éléments : {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
, {1, 2} => 6
. L'ordre dans lequel ces éléments sont ensuite stockés en mémoire est défini par l'implémentation. Les constantes de tenseur sont soumises aux contraintes suivantes:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, où :has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, où :has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- Sinon,
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Les constantes de tenseur quantifié représentent les valeurs de tenseur quantifié à l'aide de la même notation que les constantes de tenseur, avec des éléments spécifiés en tant que constantes de leur type de stockage. Les constantes de tenseur quantifiées présentent les contraintes suivantes :
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Les chaînes littérales sont constituées d'octets spécifiés à l'aide de caractères ASCII et de séquences d'échappement. Ils sont indépendants de l'encodage, de sorte que l'interprétation de ces octets est définie par mise en œuvre. Les littéraux de chaîne sont de type string
.
Opérations
abs
Sémantique
Effectue une opération abs élément par élément sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les entiers signés : module d'entier.
- Pour les nombres à virgule flottante :
abs
d'IEEE-754. - Pour les nombres complexes: module complexe.
- Pour les types quantifiés:
dequantize_op_quantize(abs, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1-C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tenseur de type entier signé ou à virgule flottante, ou tenseur quantifié par tenseur | (C1-C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
est défini comme suit :complex_element_type(element_type(operand))
siis_complex(operand)
.baseline_element_type(operand)
dans les autres cas.
Exemples
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Sémantique
Effectue l'addition par élément de deux Tensors lhs
et rhs
, et génère un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les valeurs booléennes : opérateur logique "OR".
- Pour les entiers: addition d'entiers.
- Pour les nombres à virgule flottante :
addition
d'IEEE-754. - Pour les nombres complexes: addition complexe.
- Pour les types quantifiés :
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur ou un tenseur quantifié ; | (C1-C6) |
(I2) | rhs |
Tensor ou Tensor quantifié | (C1-C5), (C7) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié ; | (C1-C7) |
Contraintes
- Si l'opération utilise des tenseurs non linéarisables :
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Si l'opération utilise des tenseurs quantifiés :
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Si
is_per_axis_quantized(lhs)
, alorsquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Si
is_per_axis_quantized(rhs)
, alorsquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Sémantique
Assure que les opérations produisant le inputs
sont exécutées avant toutes les opérations qui dépendent de result
. L'exécution de cette opération n'a aucun effet. Elle n'existe que pour établir des dépendances de données entre result
et inputs
.
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
Nombre variable de token |
Sorties
Nom | Type |
---|---|
result |
token |
Exemples
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, concatène les valeurs des tenseurs operands
de chaque processus le long de all_gather_dim
et produit des tenseurs results
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est définie comme suit :
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
operands...@receiver = [operand@sender for sender in process_group]
pour tous lesreceiver
dansprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
pour l'ensemble desprocess
deprocess_group
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operands |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C1), (C6) |
(I2) | all_gather_dim |
constante de type si64 |
(C1), (C6) |
(I3) | replica_groups |
Constante de tenseur bidimensionnel de type si64 |
(C2-C4) |
(I4) | channel_id |
constante de type si64 |
(C5) |
(I5) | use_global_device_ids |
constante de type i1 |
(C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C6) |
Contraintes
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Si
use_global_device_ids = true
, alorschannel_id > 0
. - (C6)
type(results...) = type(operands...)
, sauf :dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, applique une fonction de réduction computation
aux valeurs des tenseurs operands
de chaque processus et produit des tenseurs results
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est définie comme suit :
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
results...@process[result_index] = exec(schedule)
pour un arbre binaireschedule
où :exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
est une arborescence binaire définie par l'implémentation dont le balayage dans l'ordre estto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operands |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C5), (C6) |
(I2) | replica_groups |
nombre variadique de constantes de Tensor unidimensionnelles de type si64 |
(C1-C3) |
(I3) | channel_id |
constante de type si64 |
(C4) |
(I4) | use_global_device_ids |
constante de type i1 |
(C4) |
(I5) | computation |
fonction | (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C6-C7) |
Contraintes
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Si
use_global_device_ids = true
, alorschannel_id > 0
. - (C5)
computation
est de type(tensor<E>, tensor<E>) -> (tensor<E>)
, oùis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, il divise les valeurs des Tensors operands
le long de split_dimension
en plusieurs parties, répartit les parties divisées entre les processus, concatène les parties dispersées le long de concat_dimension
et produit des Tensors results
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est définie comme suit :
cross_replica(replica_groups)
sichannel_id <= 0
.cross_partition(replica_groups)
sichannel_id > 0
.
Ensuite, dans chaque process_group
:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
pour tous lessender
dansprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
oùreceiver_index = process_group.index(receiver)
.results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operands |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1-C3), (C9) |
(I2) | split_dimension |
constante de type si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
constante de type si64 |
(C3), (C9) |
(I4) | split_count |
constante de type si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
Constante de Tensor bidimensionnelle de type si64 |
(C5-C8) |
(I6) | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C9) |
Contraintes
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
, sauf sisplit_dimension != concat_dimension
:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
et
Sémantique
Effectue une opération AND élément par élément sur deux tenseurs lhs
et rhs
, et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les valeurs booléennes : "AND" logique
- Pour les entiers: AND au niveau du bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur de type booléen ou entier ; | (C1) |
(I2) | rhs |
un tenseur de type booléen ou entier ; | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Sémantique
Effectue une opération atan2 élément par élément sur le tenseur lhs
et rhs
, et génère un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
atan2
d'IEEE-754. - Pour les nombres complexes: atan2 complexe.
- Pour les types quantifiés :
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
(I2) | rhs |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Sémantique
Calcule les gradients de plusieurs entrées de batch_norm_training
en rétropropagation à partir de grad_output
et produit les Tensors grad_operand
, grad_scale
et grad_offset
. Plus formellement, cette opération peut être exprimée comme une décomposition des opérations StableHLO existantes à l'aide de la syntaxe Python comme suit:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Pour les types quantifiés, exécute dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1-C3), (C5) |
(I2) | scale |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C2), (C4), (C5) |
(I3) | mean |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
(I4) | variance |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
(I5) | grad_output |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C2), (C3) |
(I6) | epsilon |
constante de type f32 |
|
(I7) | feature_index |
constante de type si64 |
(C1), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
grad_operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C2), (C3) |
grad_scale |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
grad_offset |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
etgrad_offset
ont le mêmebaseline_element_type
. - (C3)
operand
,grad_output
etgrad_operand
ont la même forme. - (C4)
scale
,mean
,variance
,grad_scale
etgrad_offset
ont la même forme. - (C5)
size(scale) = dim(operand, feature_index)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Sémantique
Normalise le tenseur operand
dans toutes les dimensions, à l'exception de la dimension feature_index
, et produit un tenseur result
. Plus formellement, cette opération peut être exprimée comme une décomposition des opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Pour les types quantifiés, effectue dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1-C7) |
(I2) | scale |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C2), (C3) |
(I3) | offset |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
(I4) | mean |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C5) |
(I5) | variance |
Tensor 1D de type à virgule flottante ou quantifié par tenseur | (C2), (C6) |
(I6) | epsilon |
constante de type f32 |
|
(I7) | feature_index |
constante de type si64 |
(C1), (C3-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C2), (C7) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
etresult
ont le mêmebaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Sémantique
Calcule la moyenne et la variance pour toutes les dimensions, à l'exception de la dimension feature_index
, et normalise le tenseur operand
pour produire les tenseurs output
, batch_mean
et batch_var
. Plus formellement, cette opération peut être exprimée comme une décomposition des opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Pour les types quantifiés, effectue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1) |
(I2) | scale |
Tensor 1D à virgule flottante ou quantifié par tenseur | (C2), (C3) |
(I3) | offset |
Tensor 1D à virgule flottante ou quantifié par tenseur | (C2), (C4) |
(I4) | epsilon |
constante de type f32 |
(C1), (C3-C6) |
(I5) | feature_index |
constante de type si64 |
(C1), (C3-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C7) |
batch_mean |
Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor | (C2), (C5) |
batch_var |
Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor | (C2), (C6) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
etoutput
ont le mêmebaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Sémantique
Effectue une opération de cast de bits sur le tenseur operand
et produit un tenseur result
où les bits de l'ensemble du tenseur operand
sont réinterprétés à l'aide du type du tenseur result
.
Plus formellement, étant donné E = element_type(operand)
, E' = element_type(result)
et R = rank(operand)
:
- Si
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Si
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Si
num_bits(E') = num_bits(E)
, alorsbits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
renvoie la représentation en mémoire d'une valeur donnée, et son comportement est défini par l'implémentation, car la représentation exacte des tenseurs est définie par l'implémentation, et la représentation exacte des types d'éléments est également définie par l'implémentation.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié ; | (C1-C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié ; | (C1-C2) |
Contraintes
- (C1) Étant donnés
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
etR = rank(operand)
:- Si
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Si
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
pour tous les0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Si
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
pour tous les0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Si
- (C2) Si
is_complex(operand) or is_complex(result)
, alorsis_complex(operand) and is_complex(result)
.
Exemples
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Sémantique
Développe les dimensions et/ou le rang d'un Tensor d'entrée en dupliquant les données du Tensor operand
, et produit un Tensor result
. Plus formellement, result[result_index] = operand[operand_index]
, où pour tous les d
dans axes(operand)
:
operand_index[d] = 0
sidim(operand, d) = 1
.- Sinon,
operand_index[d] = result_index[broadcast_dimensions[d]]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié ; | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1), (C3), (C5-C6) |
Contraintes
- (C1)
element_type(result)
est donné par :element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, à l'exception dequantization_dimension(operand)
,scales(operand)
etzero_points(operand)
, qui peuvent différer dequantization_dimension(result)
,scales(result)
etzero_points(result)
, respectivement, dans le cas contraire.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Pour tous les
d
deaxes(operand)
:dim(operand, d) = 1
oudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si
dim(operand, quantization_dimension(operand)) = 1
, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Exemples
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
coque
Sémantique
Génère le résultat de l'exécution d'une seule fonction de branches
en fonction de la valeur de index
. Plus formellement, result = selected_branch()
, où :
selected_branch = branches[index]
si0 <= index < size(branches)
.- Sinon,
selected_branch = branches[-1]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | index |
Tensor de dimension 0 de type si32 |
|
(I2) | branches |
nombre variable de fonctions | (C1-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C4) |
Contraintes
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Exemples
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Sémantique
Effectue une racine cubique élément par élément sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
rootn(x, 3)
d'IEEE-754. - Pour les nombres complexes: racine cubique complexe.
- Pour les types quantifiés :
dequantize_op_quantize(cbrt, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Sémantique
Effectue le plafond par élément du tenseur operand
et produit un tenseur result
.
Implémente l'opération roundToIntegralTowardPositive
de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(ceil, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Sémantique
Calcule la décomposition de Cholesky d'un lot de matrices.
Plus formellement, pour tous les i
dans index_space(result)
, result[i0, ..., iR-3, :, :]
est une décomposition Cholesky de a[i0, ..., iR-3, :, :]
, sous la forme d'une matrice triangulaire inférieure (si lower
est true
) ou triangulaire supérieure (si lower
est false
).
Les valeurs de sortie dans le triangle opposé, c'est-à-dire le triangle supérieur strict ou le triangle inférieur strict, sont définies par l'implémentation.
Si i
existe et que la matrice d'entrée n'est pas une matrice hermitienne définie positive, le comportement est indéfini.
Pour les types quantifiés, effectue dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1-C3) |
(I2) | lower |
Constante de tenseur de dimension 0 de type i1 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Exemples
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
limiter
Sémantique
Limite chaque élément du tenseur operand
entre une valeur minimale et maximale, et génère un tenseur result
. Plus formellement, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
, où min_element = rank(min) = 0 ? min[] : min[result_index]
, max_element = rank(max) = 0 ? max[] : max[result_index]
. Pour les types quantifiés, exécute dequantize_op_quantize(clamp, min, operand, max, type(result))
.
L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (numéro 560).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | min |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C3) |
(I2) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1-C4) |
(I3) | max |
un tenseur ou un tenseur quantifié par tenseur ; | (C2), (C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C4) |
Contraintes
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, envoyez la valeur du tenseur operand
du processus source aux processus cibles et générez un tenseur result
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est définie comme suit :
cross_replica(replica_groups)
sichannel_id <= 0
.cross_partition(replica_groups)
sichannel_id > 0
.
Ensuite, result@process
est donné par :
operand@process_groups[i, 0]
si uni
existe de sorte que le processus se trouve dansprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C3) |
(I2) | replica_groups |
Nombre variable de constantes de tenseur unidimensionnelles de type si64 |
(C1), (C2) |
(I3) | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C3) |
Contraintes
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, oùN
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C3)
type(result) = type(operand)
.
Exemples
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, envoie la valeur du tenseur operand
du processus source au processus cible et produit un tenseur result
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est définie comme suit :
cross_replica(source_target_pairs)
sichannel_id <= 0
.cross_partition(source_target_pairs)
sichannel_id > 0
.
Ensuite, result@process
est donné par :
operand@process_groups[i, 0]
, s'il existe uni
tel queprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C5) |
(I2) | source_target_pairs |
Constante de tenseur bidimensionnel de type si64 |
(C1-C4) |
(I3) | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
Contraintes
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, oùN
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C5)
type(result) = type(operand)
.
Exemples
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Sémantique
Effectue une comparaison par élément des tenseurs lhs
et rhs
selon comparison_direction
et compare_type
, et produit un tenseur result
.
Les valeurs de comparison_direction
et compare_type
ont la sémantique suivante :
Pour les types d'éléments booléens et entiers :
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Pour les types d'éléments à virgule flottante avec compare_type = FLOAT
, l'opération implémente les opérations IEEE-754 suivantes :
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Pour les types d'éléments à virgule flottante avec compare_type = TOTALORDER
, l'opération utilise la combinaison des opérations totalOrder
et compareQuietEqual
de la norme IEEE-754.
Pour les types d'éléments complexes, la comparaison lexicographique des paires (real, imag)
est effectuée à l'aide des méthodes comparison_direction
et compare_type
fournies.
L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes lorsque comparison_direction
est GE
, GT
, LE
ou LT
(numéro 560).
Pour les types quantifiés, effectue dequantize_compare(lhs, rhs,
comparison_direction)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1-C3) |
(I2) | rhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C1-C2) |
(I3) | comparison_direction |
énumération de EQ , NE , GE , GT , LE et LT |
|
(I4) | compare_type |
énumération de FLOAT , TOTALORDER , SIGNED et UNSIGNED |
(C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen | (C2) |
Contraintes
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
est défini comme suit :SIGNED
siis_signed_integer(element_type(lhs))
.UNSIGNED
siis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
ouTOTALORDER
siis_float(element_type(lhs))
.FLOAT
siis_complex(element_type(lhs))
.
Exemples
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complexe
Sémantique
Effectue une conversion par élément en une valeur complexe à partir d'une paire de valeurs réelles et imaginaires, lhs
et rhs
, et produit un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
tenseur de type f32 ou f64 |
(C1-C3) |
(I2) | rhs |
tenseur de type f32 ou f64 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type complexe | (C2), (C3) |
Contraintes
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
est de typecomplex<E>
oùE = element_type(lhs)
.
Exemples
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
composite
Sémantique
Encapsule une opération composée d'autres opérations StableHLO, prenant inputs
et composite_attributes
, générant ainsi results
. La sémantique de l'opération est implémentée par l'attribut decomposition
. L'opération composite
peut être remplacée par sa décomposition sans modifier la sémantique du programme. Si l'intégration de la décomposition ne fournit pas la même sémantique d'opération, privilégiez custom_call
.
Le champ version
(par défaut 0
) permet d'indiquer quand la sémantique d'un composite change.
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre de valeurs variable |
(I2) | name |
constante de type string |
(I3) | composite_attributes |
dictionnaire d'attributs |
(I4) | decomposition |
constante de type string |
(I5) | version |
constante de type si32 |
Sorties
Nom | Type |
---|---|
results |
nombre variadique de valeurs |
Contraintes
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Exemples
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Sémantique
Chaîne inputs
le long de la dimension dimension
dans le même ordre que les arguments donnés et produit un tenseur result
. Plus formellement, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, où:
id = d0 + ... + dk-1 + kd
.d
est égal àdimension
, etd0
, etc. sont les tailles de lad
e dimension deinputs
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C1-C6) |
(I2) | dimension |
constante de type si64 |
(C2), (C4), (C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C5-C6) |
Contraintes
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
, saufdim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
, sauf :dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Exemples
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Sémantique
Génère un Tensor output
à partir d'une constante value
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | value |
constante | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
un tenseur ou un tenseur quantifié ; | (C1) |
Contraintes
- (C1)
type(value) = type(output)
.
Exemples
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
d'effectuer une conversion
Sémantique
Effectue une conversion par élément d'un type d'élément à un autre sur le tenseur operand
et produit un tenseur result
.
Pour les conversions booléen vers n'importe quel type compatible, la valeur false
est convertie en zéro, et la valeur true
en un. Pour les conversions any-supported-type-to-boolean, une valeur nulle est convertie en false
, et les valeurs non nulles sont converties en true
. Pour en savoir plus sur le fonctionnement de cette fonctionnalité pour les types complexes, consultez la section ci-dessous.
Pour les conversions impliquant des nombres entier en entier, entier en virgule flottante ou de type virgule flottante en virgule flottante, si la valeur source peut être exactement représentée dans le type de destination, la valeur du résultat correspond à cette représentation exacte. Sinon, le comportement est à déterminer (numéro 180).
Pour les conversions impliquant une floating-point-to-integer, la partie fractionnaire est tronquée. Si la valeur tronquée ne peut pas être représentée dans le type de destination, le comportement est à déterminer (numéro 180).
Les conversions complexe-complexe suivent le même comportement que les conversions virgule flottante-virgule flottante pour convertir les parties réelles et imaginaires.
Pour les conversions complex-to-any-other-type et any-other-type-to-complex, la valeur imaginaire source est ignorée ou la valeur imaginaire de destination est mise à zéro, respectivement. La conversion de la partie réelle suit les conversions à virgule flottante.
En principe, cette opération pourrait exprimer la déquantisation (conversion des tenseurs quantiques en tenseurs réguliers), la quantification (conversion des tenseurs réguliers en tenseurs quantiques) et la requantisation (conversion entre tenseurs quantiques), mais pour le moment, nous disposons d'opérations dédiées à cet effet : uniform_dequantize
pour le premier cas d'utilisation et uniform_quantize
pour le deuxième et le troisième cas d'utilisation. À l'avenir, ces deux opérations pourront être fusionnées dans convert
(#1576).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur | (C1) |
Contraintes
- (C1)
shape(operand) = shape(result)
.
Exemples
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolution
Sémantique
Calcule les produits scalaires entre les fenêtres de lhs
et les tranches de rhs
, et produit result
. Le diagramme suivant montre comment les éléments de result
sont calculés à partir de lhs
et rhs
à l'aide d'un exemple concret.
Plus formellement, considérez le recadrage suivant des entrées en termes de lhs
afin de pouvoir exprimer des fenêtres de lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
lhs_padding = lhs_shape([0, 0], padding, [0, 0])
lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Ce recadrage utilise les fonctions d'assistance suivantes:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
oùj[d] = i[permutation[d]]
.
Si feature_group_count = 1
et batch_group_count = 1
, alors pour tous les output_spatial_index
dans index_space(dim(result, output_spatial_dimensions...))
, result[result_shape(:, output_spatial_index, :)] = dot_product
, où :
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Cette fonctionnalité semble inutilisée. Nous prévoyons donc de la supprimer à l'avenir (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Si feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
result = concatenate(results, output_feature_dimension)
.
Si batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Pour les types quantifiés, effectue dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Pour les types hybrides quantifiés, exécute hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
(I2) | rhs |
un tenseur ou un tenseur quantifié ; | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
(I3) | window_strides |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C3), (C25) |
(I4) | padding |
Constante de tenseur bidimensionnel de type si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Constante de tenseur à dimension 1 de type si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Constante de tenseur à dimension 1 de type si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Constante de Tensor unidimensionnelle de type i1 |
(C9) |
(I8) | input_batch_dimension |
constante de type si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
constante de type si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
constante de type si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
constante de type si64 |
(C15-C16), (C18), (C25), (C29) |
(I13) | kernel_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C17-C18), (C25) |
(I14). | output_batch_dimension |
constante de type si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
constante de type si64 |
(C20), (C25), (C30) |
(I16). | output_spatial_dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
constante de type si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
constante de type si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
Nombre variable d'énumérations de DEFAULT , HIGH et HIGHEST |
(C24) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié ; | (C25-C28), (C30), (C32-34) |
Contraintes
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Étant donné
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Compte
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
donné :is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Avec
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
est défini comme suit :dim(lhs, input_batch_dimension) / batch_group_count
siresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
siresult_dim = output_feature_dimension
.num_windows
sinon, où :output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
rhs_dim = kernel_spatial_dimensions[spatial_dim]
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si l'opération utilise des tenseurs non linéarisables :
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si l'opération utilise des tenseurs quantifiés :
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Si
is_per_axis_quantized(rhs)
, alorsquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Si
is_per_axis_quantized(result)
, alorsquantization_dimension(result) = output_feature_dimension
. - Si
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Si
is_per_tensor_quantized(rhs)
, alorsis_per_tensor_quantized(result)
. - Si
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Exemples
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
cosinus
Sémantique
Effectue une opération cosinus élément par élément sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
cos
d'IEEE-754. - Pour les nombres complexes : cosinus complexe.
- Pour les types quantifiés :
dequantize_op_quantize(cosine, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Sémantique
Effectue un comptage par élément du nombre de bits à zéro au début du tenseur operand
et produit un tenseur result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tenseur de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Sémantique
Encapsule une opération call_target_name
définie par l'implémentation qui prend inputs
et called_computations
et produit results
. has_side_effect
, backend_config
et api_version
peuvent être utilisés pour fournir des métadonnées supplémentaires définies par l'implémentation.
Pour le moment, cette opération contient une collection de métadonnées assez désorganisée qui reflète l'évolution organique de son opération équivalente dans le compilateur XLA. À l'avenir, nous prévoyons d'unifier ces métadonnées (numéro 741).
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre de valeurs variable |
(I2) | call_target_name |
constante de type string |
(I3) | has_side_effect |
constante de type i1 |
(I4) | backend_config |
constante de type string ou dictionnaire d'attributs |
(I5) | api_version |
constante de type si32 |
(I6) | called_computations |
nombre variable de constantes de type string |
Sorties
Nom | Type |
---|---|
results |
nombre de valeurs variable |
Exemples
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
diviser
Sémantique
Effectue une division élément par élément des tenseurs lhs
et rhs
du dividende et du diviseur, et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les entiers: division entière qui produit le quotient algébrique, en ignorant toute partie fractionnaire.
- Pour les nombres à virgule flottante :
division
d'IEEE-754. - Pour les nombres complexes : division complexe.
- Pour les types quantifiés :
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Sémantique
Calcule les produits scalaires entre des tranches de lhs
et des tranches de rhs
, et produit un tenseur result
.
Plus formellement, result[result_index] = dot_product
, où:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
oùsize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
etsize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Pour les types quantifiés, effectue dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Pour les types de quantification hybride, effectue hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
contrôle le compromis entre vitesse et précision pour les calculs sur les backends d'accélérateur. Il peut s'agir de l'un des éléments suivants (pour le moment, la sémantique de ces valeurs d'énumération est sous-spécifiée, mais nous prévoyons de résoudre ce problème dans la version #755):
DEFAULT
: calcul le plus rapide, mais approximation la moins précise du nombre d'origine.HIGH
: calcul plus lent, mais approximation plus précise du nombre d'origine.HIGHEST
: calcul le plus lent, mais approximation la plus précise du nombre d'origine.
Un DotAlgorithm
définit les principales propriétés de l'algorithme utilisé pour implémenter l'opération de point, qui définit également la précision. Si les champs d'attribut de l'algorithme sont définis, precision_config
doit être DEFAULT
. DotAlgorithms
n'a pas de valeur par défaut, car les paramètres par défaut sont définis par l'implémentation. Par conséquent, tous les champs d'algorithme de point peuvent être définis sur None
pour spécifier un algorithme de point vide, qui utilisera plutôt la valeur precision_config
.
Les champs DotAlgorithm
incluent les suivants :
lhs_precision_type
etrhs_precision_type
, précisions auxquelles le côté gauche et le côté droit de l'opération sont arrondis. Les types de précision sont indépendants des types de stockage des entrées et des sorties.accumulation_type
: précision utilisée pour l'accumulation.lhs_component_count
,rhs_component_count
etnum_primitive_operations
s'appliquent lorsque nous effectuons un algorithme qui décompose le membre gauche et/ou le membre droit en plusieurs composants et effectue plusieurs opérations de produit scalaire "primaires" sur ces valeurs, généralement pour émuler une précision plus élevée (par exemple, Utiliser le type de données d'IA bfloat16 pour les calculs à plus haute précision : bf16_6x tf32_3x, etc.). Pour les algorithmes sans décomposition, ces valeurs doivent être définies sur1
.allow_imprecise_accumulation
pour spécifier si l'accumulation avec une précision inférieure est autorisée pour certaines étapes (par exemple,CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Exemples d'attributs DotAlgorithm
:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
C'est aux implémentations de décider des combinaisons acceptées. En général, il n'est pas garanti que chaque algorithme soit compatible avec chaque type d'accélérateur par le consommateur de StableHLO. Si un algorithme donné n'est pas compatible, une erreur doit être générée au lieu de passer à une autre option. La validation StableHLO fournit une validation du meilleur effort, ce qui empêche les algorithmes qui ne sont pas connus pour être compatibles avec aucun matériel.
Consultez xla_data.proto > Algorithm
pour connaître certaines valeurs d'algorithme acceptées. La demande n° 2483 décrit le plan visant à créer un document centralisé sur les algorithmes compatibles par backend.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
un tenseur ou un tenseur quantifié ; | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
nombre variadique d'énumérations de DEFAULT , HIGH et HIGHEST |
(C11), (C21) |
(I8) | lhs_precision_type |
FloatType ou TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType ou TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType ou TensorFloat32 | (C21) |
(I11) | lhs_component_count |
constante de type si32 |
(C21), (C22) |
(I12) | rhs_component_count |
constante de type si32 |
(C21), (C23) |
(I13). | num_primitive_operations |
constante de type si32 |
(C21), (C24) |
(I14) | allow_imprecise_accumulation |
constante de type bool |
(C21) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié ; | (C12), (C14), (C18-C20) |
Contraintes
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Si l'opération utilise des tenseurs non linéarisables :
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Si l'opération utilise des tenseurs quantifiés :
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Si
is_per_axis_quantized(rhs)
, alorsquantization_dimension(rhs)
n'est pas dansrhs_contracting_dimensions
. - Si
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Si
is_per_tensor_quantized(rhs)
, alorsis_per_tensor_quantized(result)
. - Si
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Si
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Exemples
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Sémantique
Cette opération est fonctionnellement identique à l'opération broadcast_in_dim, mais la forme du résultat est spécifiée de manière dynamique via output_dimensions
.
L'opération accepte également les attributs facultatifs known_expanding_dimensions
et known_nonexpanding_dimensions
pour exprimer des connaissances statiques sur le comportement d'expansion des dimensions.
Si aucune valeur n'est spécifiée, toutes les dimensions sont supposées pouvoir s'étendre.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Tensor unidimensionnel de type entier | (C7) |
(I3) | broadcast_dimensions |
Tensor constante à dimension 1 de type entier | (C2-C6) |
(I4) | known_expanding_dimensions |
Tensor constante à dimension 1 de type entier | (C8-C9) |
(I5) | known_nonexpanding_dimensions |
Tensor constante à dimension 1 de type entier | (C8-C9) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié ; | (C1), (C3), (C5-C7) |
Contraintes
- (C1)
element_type(result)
est donné par :element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, à l'exception dequantization_dimension(operand)
,scales(operand)
etzero_points(operand)
, qui peuvent différer dequantization_dimension(result)
,scales(result)
etzero_points(result)
, respectivement, dans le cas contraire.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Pour tous les
d
deaxes(operand)
:dim(operand, d) = 1
oudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si la valeur est
dim(operand, quantization_dimension(operand)) = 1
, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
.
Exemples
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Sémantique
Cette opération est fonctionnellement identique à l'opération de convolution, mais le remplissage est spécifié de manière dynamique via padding
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
(I2) | rhs |
un tenseur ou un tenseur quantifié ; | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
Tensor bidimensionnel de type entier | (C4) |
(I4) | window_strides |
Constante de tenseur à dimension 1 de type si64 |
(C2-C3) |
(I5) | lhs_dilation |
Constante de tenseur à dimension 1 de type si64 |
(C5-C6) |
(I6) | rhs_dilation |
Constante de tenseur à dimension 1 de type si64 |
(C7-C8) |
(I7) | window_reversal |
Constante de Tensor unidimensionnelle de type i1 |
(C9) |
(I8) | input_batch_dimension |
constante de type si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
constante de type si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C12), (C13) |
(I11) | kernel_input_feature_dimension |
constante de type si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
constante de type si64 |
(C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C17-C18) |
(I14) | output_batch_dimension |
constante de type si64 |
(C20) |
(I15) | output_feature_dimension |
constante de type si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C19-C20) |
(I17) | feature_group_count |
constante de type si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
constante de type si64 |
(C10), (C15), (C22), (C23) |
(I19). | precision_config |
Nombre variable d'énumérations de DEFAULT , HIGH et HIGHEST |
(C24) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié ; | (C25-C27), (C29), (C31-C33) |
Contraintes
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Étant donné
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Compte
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
donné :is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Avec
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
est défini comme suit :dim(lhs, input_batch_dimension) / batch_group_count
siresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
siresult_dim = output_feature_dimension
.num_windows
sinon, où :output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
rhs_dim = kernel_spatial_dimensions[spatial_dim]
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si l'opération utilise des tenseurs non linéarisables :
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si l'opération utilise des tenseurs quantifiés :
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Si
is_per_axis_quantized(rhs)
, alorsquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Si
is_per_axis_quantized(result)
, alorsquantization_dimension(result) = output_feature_dimension
. - Si
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Si
is_per_tensor_quantized(rhs)
, alorsis_per_tensor_quantized(result)
. - Si
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Exemples
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Sémantique
Cette opération est fonctionnellement identique à l'opération gather, avec l'slice_sizes
spécifié dynamiquement en tant que valeur.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
Tensor de type entier | (C2), (C3), (C13) |
(I3) | slice_sizes |
Tensor 1D de type entier | (C8), (C11-C13) |
(I4) | offset_dims |
Constante de tenseur à dimension 1 de type si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Constante de tenseur à dimension 1 de type si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
constante de type si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
constante de type i1 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C5), (C13-C14) |
Contraintes
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
où :batch_dim_sizes = shape(start_indices)
, sauf que la taille de dimension destart_indices
correspondant àindex_vector_dim
n'est pas incluse.offset_dim_sizes = shape(slice_sizes)
, sauf que les tailles des dimensions dansslice_sizes
correspondant àcollapsed_slice_dims
ne sont pas incluses.combine
placebatch_dim_sizes
sur les axes correspondant àbatch_dims
etoffset_dim_sizes
sur les axes correspondant àoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Exemples
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Sémantique
Cette opération est fonctionnellement identique à l'opération iota, mais la forme du résultat est spécifiée de manière dynamique via output_shape
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | output_shape |
Tensor unidimensionnel de type entier | (C1), (C2) |
(I2) | iota_dimension |
si64 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C2) |
Contraintes
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Exemples
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Sémantique
Cette opération est fonctionnellement identique à l'opération pad, mais avec edge_padding_low
, edge_padding_high
et interior_padding
spécifiés dynamiquement en tant que valeurs.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor à dimension 0 ou tensor quantifié par tensor | (C1) |
(I3) | edge_padding_low |
Tensor 1D de type entier | (C1), (C4) |
(I4) | edge_padding_high |
Tensor 1D de type entier | (C1), (C4) |
(I5) | interior_padding |
Tensor unidimensionnel de type entier | (C2-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C3-C6) |
Contraintes
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Exemples
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Sémantique
Cette opération est fonctionnellement identique à l'opération reshape, mais la forme du résultat est spécifiée de manière dynamique via output_shape
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié ; | (C1-C3) |
(I2) | output_shape |
Tensor 1D de type entier | (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié ; | (C1-C4) |
Contraintes
- (C1)
element_type(result)
est donné par :element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, sauf quequantization_dimension(operand)
etquantization_dimension(result)
peuvent être différents.
- (C2)
size(operand) = size(result)
. - (C3) Si
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Sémantique
Extrait une tranche du operand
à l'aide d'index de départ calculés dynamiquement et produit un Tensor result
. start_indices
contient les indices de début de la tranche pour chaque dimension pouvant être ajustée, et slice_sizes
contient les tailles de la tranche pour chaque dimension. Plus formellement, result[result_index] = operand[operand_index]
, où :
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C2), (C4) |
(I2) | start_indices |
nombre variadique de Tensors à 0 dimensions de type entier | (C2), (C3) |
(I3) | slice_sizes |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C5) |
Contraintes
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Exemples
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Sémantique
Génère un tenseur result
égal au tenseur operand
, sauf que la tranche commençant à start_indices
est mise à jour avec les valeurs de update
.
Plus formellement, result[result_index]
est défini comme suit :
update[update_index]
si0 <= update_index < shape(update)
où :adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- Sinon,
operand[result_index]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C4), (C6) |
(I2) | update |
un tenseur ou un tenseur quantifié par tenseur ; | (C2), (C3), (C6) |
(I3) | start_indices |
nombre variadique de Tensors à 0 dimensions de type entier | (C4), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Exemples
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponentiel
Sémantique
Effectue une opération exponentielle au niveau des éléments sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
exp
d'IEEE-754. - Pour les nombres complexes: exponentiel complexe
- Pour les types quantifiés :
dequantize_op_quantize(exponential, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Sémantique
Effectue une opération exponentielle moins un par élément sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante:
expm1
d'IEEE-754. - Pour les nombres complexes : l'exponentielle complexe moins un.
- Pour les types quantifiés :
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Sémantique
Effectue les transformations de Fourier directe et inverse pour les entrées/sorties réelles et complexes.
fft_type
est l'un des éléments suivants :
FFT
: FFT de type "forward" complexe à complexe.IFFT
: FFT inverse complexe à complexe.RFFT
: FFT directe réelle-complexe.IRFFT
: FFT inverse réel-complexe (c'est-à-dire, prend complexe, renvoie réel).
Plus formellement, compte tenu de la fonction fft
, qui prend en entrée des Tensors unidimensionnels de types complexes, produit des Tensors unidimensionnels de mêmes types que la sortie et calcule la transformation de Fourier discrète:
Pour fft_type = FFT
, result
est défini comme le résultat final d'une série de calculs L où L = size(fft_length)
. Par exemple, pour L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
De plus, étant donné la fonction ifft
qui a la même signature de type et calcule l'inverse de fft
:
Pour fft_type = IFFT
, result
est défini comme l'inverse des calculs pour fft_type = FFT
. Par exemple, pour L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
result[i0, ..., :] = ifft(result2[i0, ..., :])
.
De plus, étant donné la fonction rfft
qui prend des tenseurs unidimensionnels de types à virgule flottante, produit des tenseurs unidimensionnels de types complexes de la même sémantique à virgule flottante et fonctionne comme suit :
rfft(real_operand) = truncated_result
oùcomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Lorsque la transformée de Fourier discrète est calculée pour des opérandes réels, les premiers éléments N/2 + 1
du résultat définissent de manière non ambiguë le reste du résultat. Le résultat de rfft
est donc tronqué pour éviter de calculer des éléments redondants.)
Pour fft_type = RFFT
, result
est défini comme le résultat final d'une série de calculs L où L = size(fft_length)
. Par exemple, pour L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Enfin, étant donné la fonction irfft
qui a la même signature de type et calcule l'inverse de rfft
:
Pour fft_type = IRFFT
, result
est défini comme l'inverse des calculs pour fft_type = RFFT
. Par exemple, pour L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tenseur de type à virgule flottante ou complexe | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
énumération de FFT , IFFT , RFFT et IRFFT |
(C2), (C5) |
(I3) | fft_length |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C3), (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type complexe ou à virgule flottante | (C2), (C4), (C5) |
Contraintes
- (C1)
size(fft_length) <= rank(operand)
. - (C2) La relation entre les types d'éléments
operand
etresult
varie :- Si
fft_type = FFT
,element_type(operand)
etelement_type(result)
ont le même type complexe. - Si
fft_type = IFFT
,element_type(operand)
etelement_type(result)
ont le même type complexe. - Si
fft_type = RFFT
,element_type(operand)
est un type à virgule flottante etelement_type(result)
est un type complexe de la même sémantique à virgule flottante. - Si
fft_type = IRFFT
,element_type(operand)
est un type complexe etelement_type(result)
est un type à virgule flottante de la même sémantique à virgule flottante.
- Si
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Si parmi
operand
etresult
, il existe un tenseurreal
d'un type à virgule flottante, alorsshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
, sauf :- Si
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Si
fft_type = IRFFT
, alorsdim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Si
Exemples
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
étage
Sémantique
Effectue un prix plancher par élément du Tensor operand
et produit un Tensor result
.
Met en œuvre l'opération roundToIntegralTowardNegative
de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(floor, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
rassembler
Sémantique
Regroupe les tranches du Tensor operand
à partir des décalages spécifiés dans start_indices
et génère un Tensor result
.
Le diagramme suivant montre comment les éléments de result
sont mappés sur les éléments de operand
à l'aide d'un exemple concret. Le diagramme sélectionne quelques exemples d'indices result
et explique en détail à quels indices operand
ils correspondent.
Plus formellement, result[result_index] = operand[operand_index]
, où:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
est défini comme suit :start_indices[bi0, ..., :, ..., biN]
, oùbi
sont des éléments individuels dansbatch_index
et:
est inséré à l'indiceindex_vector_dim
, siindex_vector_dim
<rank(start_indices)
.- Sinon,
[start_indices[batch_index]]
.
- Pour
d_operand
dansaxes(operand)
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
sid_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
dans les autres cas.
- Pour
d_operand
dansaxes(operand)
,full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sid_operand = operand_batching_dims[i_batching]
etd_start = start_indices_batching_dims[i_batching]
.- Sinon,
full_batching_index[d_operand] = 0
.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, oùoi
sont des éléments individuels dansoffset_index
, et0
est inséré aux indices decollapsed_slice_dims
etoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
.
Si indices_are_sorted
est true
, l'implémentation peut supposer que les start_indices
sont triées par rapport à start_index_map
. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2
à partir de indices(result)
, full_start_index(i1) <= full_start_index(i2)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
Tensor de type entier | (C2-C3), (C14), (C17), (C22) |
(I3) | offset_dims |
Constante de tenseur à dimension 1 de type si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Constante de tenseur à dimension 1 de type si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Constante de tenseur à dimension 1 de type si64 |
(C13-C17) |
(I7) | start_index_map |
Constante de tenseur à dimension 1 de type si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
constante de type si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
Constante de tenseur à dimension 1 de type si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
constante de type i1 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C5), (C22-C23) |
Contraintes
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
où :batch_dim_sizes = shape(start_indices)
, sauf que la taille de dimension destart_indices
correspondant àindex_vector_dim
n'est pas incluse.offset_dim_sizes = slice_sizes
, sauf que les tailles de dimension dansslice_sizes
correspondant àcollapsed_slice_dims
etoperand_batching_dims
ne sont pas incluses.combine
placebatch_dim_sizes
sur les axes correspondant àbatch_dims
etoffset_dim_sizes
sur les axes correspondant àoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Exemples
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Sémantique
Renvoie la taille de l'dimension
donnée de l'operand
. Plus formellement, result = dim(operand, dimension)
. La sémantique ne concerne que le composant de forme du type. Le type d'élément peut être n'importe quoi.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié ; | (C1) |
(I2) | dimension |
constante de type si64 |
(C1) |
Sorties
Nom | Type |
---|---|
result |
Tensor de dimension 0 de type si32 |
Contraintes
- (C1)
0 <= dimension < rank(operand)
.
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Sémantique
Extraction de l'élément à la position index
du tuple operand
et production d'un result
. Plus formellement, result = operand[index]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tuple | (C1), (C2) |
(I2) | index |
constante de type si32 |
(C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tout type compatible | (C2) |
Contraintes
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Exemples
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
si
Sémantique
Génère la sortie de l'exécution d'une seule fonction à partir de true_branch
ou false_branch
, en fonction de la valeur de pred
. Plus formellement, result =
pred ? true_branch() : false_branch()
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | pred |
Tensor de dimension 0 de type i1 |
|
(I2) | true_branch |
fonction | (C1-C3) |
(I3) | false_branch |
fonction | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C3) |
Contraintes
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Exemples
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
Sémantique
Extrait la partie imaginaire, par élément, de operand
et produit un tenseur result
. Plus formellement, pour chaque élément x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type complexe ou à virgule flottante | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
est défini comme suit :complex_element_type(element_type(operand))
siis_complex(operand)
.- Sinon,
element_type(operand)
.
Exemples
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
infeed
Sémantique
Lit les données du flux d'entrée et génère results
.
La sémantique de infeed_config
est définie par l'implémentation.
results
se compose de valeurs de charge utile qui viennent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour améliorer la clarté (numéro 670).
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
constante de type string |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
Nombre variable de tenseurs, de tenseurs quantifiés ou de jetons | (C1-C3) |
Contraintes
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ouis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Exemples
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Sémantique
Remplit un Tensor output
avec des valeurs dans l'ordre croissant à partir de zéro le long de la dimension iota_dimension
. Plus formellement,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
0 <= iota_dimension < rank(output)
.
Exemples
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Sémantique
Effectue une vérification par élément pour déterminer si la valeur dans x
est finie (c'est-à-dire qu'elle n'est ni +Inf, ni -Inf, ni NaN) et produit un tenseur y
. Implémente l'opération isFinite
de la spécification IEEE-754. Pour les types quantifiés, le résultat est toujours true
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | x |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
y |
tenseur de type booléen | (C1) |
Contraintes
- (C1)
shape(x) = shape(y)
.
Exemples
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Sémantique
Effectue une opération logarithmique élément par élément sur le tenseur operand
et génère un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
log
d'IEEE-754. - Pour les nombres complexes: logarithme complexe.
- Pour les types quantifiés :
dequantize_op_quantize(log, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Sémantique
Effectue le logarithme élément par élément plus une opération sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
logp1
d'IEEE-754. - Pour les nombres complexes : logarithme complexe plus un.
- Pour les types quantifiés :
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistique
Sémantique
Effectue une opération logistique au niveau des éléments sur le Tensor operand
et génère un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante :
division(1, addition(1, exp(-x)))
d'IEEE-754. - Pour les nombres complexes: logistique complexe.
- Pour les types quantifiés :
dequantize_op_quantize(logistic, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
carte
Sémantique
Applique une fonction de mappage computation
à inputs
le long de dimensions
et produit un tenseur result
.
Plus formellement, result[result_index] = computation(inputs...[result_index])
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C1-C4) |
(I2) | dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C3) |
(I3) | computation |
fonction | (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C4) |
Contraintes
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
est de type(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, oùEi = element_type(inputs[i])
etE' = element_type(result)
.
Exemples
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maximum
Sémantique
Effectue une opération de valeur maximale au niveau des éléments sur les tensors lhs
et rhs
, et génère un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les valeurs booléennes : opérateur logique "OR".
- Pour les entiers : valeur maximale de l'entier.
- Pour les nombres à virgule flottante :
maximum
d'IEEE-754. - Pour les nombres complexes : valeur maximale lexicographique pour la paire
(real, imaginary)
. L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (numéro 560). - Pour les types quantifiés :
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
(I2) | rhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Sémantique
Effectue une opération min par élément sur les tenseurs lhs
et rhs
, et génère un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les valeurs booléennes : "AND" logique
- Pour les entiers : valeur minimale de l'entier.
- Pour les nombres à virgule flottante :
minimum
d'IEEE-754. - Pour les nombres complexes : valeur minimale lexicographique pour la paire
(real, imaginary)
. L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (numéro 560). - Pour les types quantifiés :
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
(I2) | rhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
multiplier
Sémantique
Effectue le produit élément par élément de deux tenseurs lhs
et rhs
, et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes : "AND" logique
- Pour les entiers: multiplication d'entiers.
- Pour les nombres à virgule flottante:
multiplication
d'IEEE-754. - Pour les nombres complexes : multiplication complexe.
- Pour les types quantifiés :
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
(I2) | rhs |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
Sémantique
Effectue une négation par élément du Tensor operand
et produit un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les entiers signés: négation des entiers.
- Pour les entiers non signés : bitcast vers un entier signé, négation d'entier, bitcast vers un entier non signé.
- Pour les nombres à virgule flottante:
negate
d'IEEE-754. - Pour les nombres complexes : négation complexe.
- Pour les types quantifiés :
dequantize_op_quantize(negate, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
not
Sémantique
Effectue une opération NOT élément par élément du tenseur operand
et produit un tenseur result
.
En fonction du type d'élément, effectue les opérations suivantes :
- Pour les valeurs booléennes : négation logique.
- Pour les entiers: NOT (PAS) au niveau du bit.
Arguments
Nom | Type | Contraintes |
---|---|---|
operand |
Tensor de type booléen ou entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Sémantique
Assurez-vous que les opérations qui génèrent le operand
sont exécutées avant toute opération qui dépend de result
et empêche les transformations de compilation de déplacer les opérations au-delà de la barrière. En dehors de cela, l'opération est une identité, c'est-à-dire result = operand
.
Arguments
Nom | Type | Contraintes |
---|---|---|
operand |
nombre variable de tenseurs, de tenseurs quantifiés par tenseur ou de jetons | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
nombre variable de tenseurs, de tenseurs quantifiés par tenseur ou de jetons | (C1) |
Contraintes
- (C1)
type(operand...) = type(result...)
.
Exemples
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
ou
Sémantique
Effectue une opération OU élément par élément sur deux tenseurs lhs
et rhs
, et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les valeurs booléennes: opérateur logique OU.
- Pour les entiers: opération OR au niveau du bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier ou booléen | (C1) |
(I2) | rhs |
un tenseur de type entier ou booléen ; | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type entier ou booléen ; | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
flux de sortie
Sémantique
Écrit inputs
dans le flux de sortie et génère un jeton result
.
La sémantique de outfeed_config
est définie par l'implémentation.
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
Nombre variable de tenseurs ou de tenseurs quantifiés |
(I2) | token |
token |
(I3) | outfeed_config |
constante de type string |
Sorties
Nom | Type |
---|---|
result |
token |
Exemples
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
pad
Sémantique
Élargit operand
en ajoutant des marges internes autour du tenseur et entre les éléments du tenseur avec l'padding_value
donné.
edge_padding_low
et edge_padding_high
spécifient respectivement la quantité de marge intérieure ajoutée à l'extrémité inférieure (à côté de l'indice 0) et à l'extrémité supérieure (à côté de l'indice le plus élevé) de chaque dimension. La valeur de marge intérieure peut être négative, où la valeur absolue de la marge intérieure négative indique le nombre d'éléments à supprimer de la dimension spécifiée.
interior_padding
spécifie la quantité de marge intérieure ajoutée entre deux éléments de chaque dimension, qui peut ne pas être négative. La marge intérieure se produit avant la marge extérieure, de sorte que la marge extérieure négative supprime des éléments de l'opérande à marge intérieure.
Plus formellement, result[result_index]
est défini comme suit :
operand[operand_index]
siresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.- Sinon,
padding_value
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor à dimension 0 ou tensor quantifié par tensor | (C1) |
(I3) | edge_padding_low |
Constante de tenseur à dimension 1 de type si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Constante de tenseur à dimension 1 de type si64 |
(C1), (C4) |
(I5) | interior_padding |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C3-C6) |
Contraintes
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Exemples
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Sémantique
Génère partition_id
du processus en cours.
Sorties
Nom | Type |
---|---|
result |
Tensor de dimension 0 de type ui32 |
Exemples
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
Sémantique
Effectue un comptage par élément du nombre de bits définis dans le tenseur operand
et produit un tenseur result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tenseur de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
puissance
Sémantique
Effectue l'exponentiation élément par élément du tenseur lhs
par le tenseur rhs
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les entiers : exponentiation entière.
- Pour les nombres à virgule flottante :
pow
d'IEEE-754. - Pour les nombres complexes: exponentielle complexe.
- Pour les types quantifiés :
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Sémantique
Extrait la partie réelle, par élément, de la operand
et produit un tenseur result
. Plus formellement, pour chaque élément x
: real(x) = is_complex(x) ? real_part(x) : x
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type complexe ou à virgule flottante | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
est défini comme suit :complex_element_type(element_type(operand))
siis_complex(operand)
.- Sinon,
element_type(operand)
.
Exemples
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Sémantique
Reçoit les données d'un canal avec channel_id
et produit results
.
Si is_host_transfer
est true
, l'opération transfère des données depuis l'hôte. Sinon, il transfère des données depuis un autre appareil. Cela signifie que la valeur est définie par l'implémentation. Cet indicateur duplique les informations fournies dans channel_type
. Nous prévoyons donc de n'en conserver qu'un seul (#666).
results
est constitué de valeurs de charge utile qui apparaissent en premier et d'un jeton qui apparaissent en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour améliorer la clarté (numéro 670).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
constante de type si64 |
|
(I3) | channel_type |
énumération de DEVICE_TO_DEVICE et HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
constante de type i1 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
Nombre variable de tenseurs, de tenseurs quantifiés ou de jetons | (C2-C4) |
Contraintes
- (C1)
channel_type
est défini comme suit :HOST_TO_DEVICE
siis_host_transfer = true
,- Sinon,
DEVICE_TO_DEVICE
.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ouis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Exemples
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Sémantique
Applique une fonction de réduction body
à inputs
et init_values
le long de dimensions
et produit des Tensors results
.
L'ordre des réductions est défini par l'implémentation, ce qui signifie que body
et init_values
doivent former un monoide pour garantir que l'opération produit les mêmes résultats pour toutes les entrées sur toutes les implémentations. Cependant, cette condition n'est pas valable pour de nombreuses réductions populaires. Par exemple, l'addition à virgule flottante pour body
et zéro pour init_values
ne forment pas réellement un monoide, car l'addition à virgule flottante n'est pas associative.
Plus formellement, results...[j0, ..., jR-1] = reduce(input_slices_converted)
, où :
input_slices = inputs...[j0, ..., :, ..., jR-1]
, où:
sont insérés àdimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
pour un arbre binaireschedule
où :exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
est une arborescence binaire complète définie par l'implémentation dont le balayage dans l'ordre comprend les éléments suivants :- Valeurs
input_slices_converted...[index]
, pour tous lesindex
dansindex_space(input_slices_converted)
dans l'ordre lexicographique croissant deindex
. - Intercalés avec une quantité de
init_values_converted
définie par l'implémentation aux positions définies par l'implémentation.
- Valeurs
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C1-C4), (C6), (C7) |
(I2) | init_values |
nombre variadique de Tensors à 0 dimensions ou de Tensors quantifiés par Tensor | (C2), (C3) |
(I3) | dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C4), (C5), (C7) |
(I4) | body |
fonction | (C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C3), (C7), (C8) |
Contraintes
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
oùis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, à l'exception des tailles de dimension deinputs...
correspondant àdimensions
. - (C8)
element_type(results[i]) = Ei
pour tous lesi
de[0,N)
.
Exemples
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Sémantique
Effectue une conversion élément par élément de operand
vers un autre type à virgule flottante qui utilise exponent_bits
et mantissa_bits
, puis vers le type à virgule flottante d'origine et produit un tenseur output
.
Plus formellement :
- Les bits de mantisse de la valeur d'origine sont mis à jour pour arrondir la valeur d'origine à la valeur la plus proche pouvant être représentée avec
mantissa_bits
à l'aide de la sémantiqueroundToIntegralTiesToEven
. - Ensuite, si
mantissa_bits
est inférieur au nombre de bits de mantisse de la valeur d'origine, les bits de mantisse sont tronqués àmantissa_bits
. - Ensuite, si les bits d'exposant du résultat intermédiaire ne rentrent pas dans la plage fournie par
exponent_bits
, le résultat intermédiaire déborde vers l'infini à l'aide du signe d'origine ou sous-déborde vers zéro à l'aide du signe d'origine. - Pour les types quantifiés, effectue
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1) |
(I2) | exponent_bits |
constante de type si32 |
(C2) |
(I3) | mantissa_bits |
constante de type si32 |
(C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Exemples
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, effectue une réduction, à l'aide de computations
, sur les valeurs du tenseur operand
de chaque processus, divise le résultat de la réduction en parties le long de scatter_dimension
et disperse les parties fractionnées entre les processus pour produire le result
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est définie comme suit :
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
pour tous lessender
dansprocess_group
, oùreceiver_index = process_group.index(receiver)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
constante de type si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
Constante de tenseur bidimensionnel de type si64 |
(C3-C5) |
(I4) | channel_id |
constante de type si64 |
(C6) |
(I5) | use_global_device_ids |
constante de type i1 |
(C6) |
(I6) | computation |
fonction | (C7) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C8-C9) |
Contraintes
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Si
use_global_device_ids = true
, alorschannel_id > 0
. - (C7)
computation
est de type(tensor<E>, tensor<E>) -> (tensor<E>)
, oùis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
, sauf :dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Sémantique
Applique une fonction de réduction body
aux fenêtres de inputs
et init_values
, et produit results
.
Le schéma suivant montre comment les éléments de results...
sont calculés à partir de inputs...
à l'aide d'un exemple concret.
Plus formellement, results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(voir reduce) où :
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
window_end = window_start + (window_dimensions - 1) * window_dilations + 1
windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
Nombre variable de tenseurs à dimension 0 ou de tenseurs quantifiés par tenseur | (C1), (C13) |
(I3) | window_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Constante de tenseur à dimension 1 de type si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Constante de tenseur à dimension 1 de type si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Constante de tenseur à dimension 1 de type si64 |
(C10), (C11), (C15) |
(I7) | padding |
Constante de tenseur bidimensionnel de type si64 |
(C12), (C15) |
(I8) | body |
fonction | (C13) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C1), (C14-C16) |
Contraintes
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, oùis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
où :dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
pour l'ensemble desi
de[0,N)
.
Exemples
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
reste
Sémantique
Effectue le reste par élément des tensors lhs
et rhs
du dividende et du diviseur, et produit un tenseur result
.
Plus formellement, le signe du résultat est extrait du dividende, et la valeur absolue du résultat est toujours inférieure à la valeur absolue du diviseur.
Le reste est calculé comme lhs - d * rhs
, où d
est donné par :
- Pour les entiers:
stablehlo.divide(lhs, rhs)
. - Pour les nombres à virgule flottante :
division(lhs, rhs)
d'IEEE-754 avec l'attribut d'arrondiroundTowardZero
. - Pour les nombres complexes: à déterminer (#997).
- Pour les types quantifiés :
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Pour les types d'éléments à virgule flottante, cette opération contraste avec l'opération remainder
de la spécification IEEE-754, où d
est une valeur entière la plus proche de la valeur exacte de lhs/rhs
avec des liens pairs.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Sémantique
Génère replica_id
du processus actuel.
Sorties
Nom | Type |
---|---|
result |
Tensor de dimension 0 de type ui32 |
Exemples
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
reshape
Sémantique
Effectue un remodelage du Tensor operand
en un Tensor result
. Conceptuellement, cela revient à conserver la même représentation canonique, mais à modifier potentiellement la forme, par exemple de tensor<2x3xf32>
à tensor<3x2xf32>
ou tensor<6xf32>
.
Plus formellement, result[result_index] = operand[operand_index]
, où result_index
et operand_index
ont la même position dans l'ordre lexicographique de index_space(result)
et index_space(operand)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié ; | (C1-C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1-C3) |
Contraintes
- (C1)
element_type(result)
est donné par :element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, sauf quequantization_dimension(operand)
etquantization_dimension(result)
peuvent être différents.
- (C2)
size(operand) = size(result)
. - (C3) Si
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
inverser
Sémantique
Inverse l'ordre des éléments dans operand
le long de dimensions
spécifié et génère un tenseur result
. Plus formellement, result[result_index] = operand[operand_index]
, où :
operand_index[d] = dim(result, d) - result_index[d] - 1
sid
endimensions
.- Sinon,
operand_index[d] = result_index[d]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C3) |
(I2) | dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C2), (C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C3) |
Contraintes
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Exemples
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Sémantique
Génère des nombres aléatoires à l'aide de l'algorithme rng_distribution
et produit un tenseur result
d'une forme shape
donnée.
Si la valeur est rng_distribution = UNIFORM
, les nombres aléatoires sont générés en suivant la distribution uniforme sur l'intervalle [a, b)
. Si la valeur est a >= b
, le comportement n'est pas défini.
Si la valeur est rng_distribution = NORMAL
, les nombres aléatoires sont générés selon la distribution normale, avec une moyenne = a
et l'écart type = b
.
Si la valeur est b < 0
, le comportement n'est pas défini.
La méthode exacte de génération des nombres aléatoires est définie par l'implémentation. Par exemple, ils peuvent être ou non déterministes, et ils peuvent ou non utiliser un état masqué.
Lors de discussions avec de nombreux partenaires, il est apparu que cette opération était effectivement obsolète. Nous envisageons donc de la supprimer à l'avenir (numéro 597).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor à zéro dimension de type entier, booléen ou à virgule flottante | (C1), (C2) |
(I2) | b |
Tensor 0 dimensionnel de type entier, booléen ou à virgule flottante | (C1), (C2) |
(I3) | shape |
Constante de tenseur à dimension 1 de type si64 |
(C3) |
(I4) | rng_distribution |
énumération de UNIFORM et NORMAL |
(C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type entier, booléen ou à virgule flottante ; | (C1-C3) |
Contraintes
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Si
rng_distribution = NORMAL
, alorsis_float(a)
. - (C3)
shape(result) = shape
.
Exemples
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Sémantique
Renvoie un output
rempli de bits aléatoires uniformes et un état de sortie output_state
mis à jour à l'aide de l'algorithme de générateur de nombres pseudo-aléatoires rng_algorithm
, étant donné un état initial initial_state
. La sortie est garantie comme étant une fonction déterministe de initial_state
, mais elle n'est pas garantie comme étant déterministe entre les implémentations.
rng_algorithm
est l'un des éléments suivants :
DEFAULT
: algorithme défini par l'implémentation.THREE_FRY
: variante de l'algorithme Threefry définie par l'implémentation*PHILOX
: variante de l'algorithme Philox définie par l'implémentation*
* Voir Salmon et al. SC 2011. Nombres aléatoires parallèles: c'est aussi simple que 1, 2, 3.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | rng_algorithm |
énumération de DEFAULT , THREE_FRY et PHILOX |
(C2) |
(I2) | initial_state |
Tensor unidimensionnel de type ui64 |
(C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output_state |
Tensor 1D de type ui64 |
(C1) |
output |
tenseur de type entier ou à virgule flottante |
Contraintes
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
est défini comme suit :- définie par l'implémentation si
rng_algorithm = DEFAULT
. 2
sirng_algorithm = THREE_FRY
.2
ou3
sirng_algorithm = PHILOX
.
- définie par l'implémentation si
Exemples
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Sémantique
Effectue un arrondi au niveau des éléments vers l'entier le plus proche, en rompant les liens avec zéro sur le Tensor operand
et produit un Tensor result
. Implémente l'opération roundToIntegralTiesToAway
de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Sémantique
Effectue un arrondi par élément vers l'entier le plus proche, en cas d'égalité, vers l'entier pair, sur le tenseur operand
et produit un tenseur result
. Implémente l'opération roundToIntegralTiesToEven
de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Sémantique
Effectue une opération de racine carrée réciproque par élément sur le Tensor operand
et produit un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
rSqrt
d'IEEE-754. - Pour les nombres complexes : racine carrée réciproque complexe.
- Pour les types quantifiés :
dequantize_op_quantize(rsqrt, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
disperser
Sémantique
Génère des Tensors results
égaux aux Tensors inputs
, à la différence que plusieurs tranches spécifiées par scatter_indices
sont mises à jour avec les valeurs updates
à l'aide de update_computation
.
Le schéma suivant montre comment les éléments de updates...
sont mappés sur des éléments de results...
à l'aide d'un exemple concret. Le diagramme sélectionne quelques exemples d'indices updates...
et explique en détail à quels indices results...
ils correspondent.
Plus formellement, pour tous les update_index
de index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
est défini comme suit :scatter_indices[si0, ..., :, ..., siN]
, oùsi
correspond à des éléments individuels dansupdate_scatter_index
et:
est inséré à l'indexindex_vector_dim
, siindex_vector_dim
<rank(scatter_indices)
.- Sinon,
[scatter_indices[update_scatter_index]]
.
- Pour
d_input
dansaxes(inputs[0])
,full_start_index[d_input] = start_index[d_start]
sid_input = scatter_dims_to_operand_dims[d_start]
.- Sinon,
full_start_index[d_input] = 0
.
- Pour
d_input
dansaxes(inputs[0])
,full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sid_input = input_batching_dims[i_batching]
etd_start = scatter_indices_batching_dims[i_batching]
.- Sinon,
full_batching_index[d_input] = 0
.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, oùwi
sont des éléments individuels dansupdate_window_index
, et0
est inséré aux indices deinserted_window_dims
etinput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Par conséquent, results = exec(schedule, inputs)
, où:
schedule
est une permutation deindex_space(updates[0])
définie par l'implémentation.exec([update_index, ...], results) = exec([...], updated_results)
où :- Si
result_index
est dans les limites deshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
est une copie deresults
avecresults...[result_index]
défini surupdated_values...
.- Sinon, procédez comme suit :
updated_results = results
.
- Si
exec([], results) = results
.
Si indices_are_sorted
est défini sur true
, l'implémentation peut supposer que les éléments scatter_indices
sont triés par rapport à scatter_dims_to_operand_dims
. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2
de indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Si unique_indices
est true
, l'implémentation peut supposer que tous les indices result_index
vers lesquels la diffusion est effectuée sont uniques. Si unique_indices
est true
, mais que les indices vers lesquels les données sont dispersées ne sont pas uniques, le comportement est indéfini.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
tenseur de type entier | (C4), (C15), (C19), (C22) |
(I3) | updates |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C3-C6), (C8) |
(I4) | update_window_dims |
Constante de tenseur à dimension 1 de type si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Constante de tenseur à dimension 1 de type si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
Constante de tenseur à dimension 1 de type si64 |
(C19-C21) |
(I9) | index_vector_dim |
constante de type si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
constante de type i1 |
|
(I11) | unique_indices |
constante de type i1 |
|
(I12) | update_computation |
fonction | (C23) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C24-C25) |
Contraintes
- (C1)
same(shape(inputs...))
. - (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
où :update_scatter_dim_sizes = shape(scatter_indices)
, sauf que la taille de la dimensionscatter_indices
correspondant àindex_vector_dim
n'est pas incluse.update_window_dim_sizes <= shape(inputs[0])
, à l'exception des tailles de dimension dansinputs[0]
correspondant àinserted_window_dims
etinput_batching_dims
.combine
placeupdate_scatter_dim_sizes
sur les axes correspondant àupdate_scatter_dims
etupdate_window_dim_sizes
sur les axes correspondant àupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, oùis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
pour tous lesi
dans[0,N)
.
Exemples
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
sélectionner
Sémantique
Génère un tenseur result
dans lequel chaque élément est sélectionné à partir du tenseur on_true
ou on_false
en fonction de la valeur de l'élément correspondant de pred
.
Plus formellement, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, où pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Pour les types quantifiés, exécute dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | pred |
tenseur de type i1 |
(C1) |
(I2) | on_true |
un tenseur ou un tenseur quantifié par tenseur ; | (C1-C2) |
(I3) | on_false |
Tensor ou Tensor quantifié par Tensor | (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C2) |
Contraintes
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Exemples
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Sémantique
Disperse les valeurs du tenseur source
à l'aide de scatter
en fonction du résultat de reduce_window
du tenseur input
à l'aide de select
et génère un tenseur result
.
Le diagramme suivant montre comment les éléments de result
sont calculés à partir de operand
et source
à l'aide d'un exemple concret.
Plus formellement:
selected_values = reduce_window_without_init(...)
avec les entrées suivantes:inputs = [operand].
window_dimensions
,window_strides
etpadding
, qui sont utilisés tels quels.base_dilations = windows_dilations = 1
.body
est défini comme suit:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
où
E = element_type(operand)
etreduce_window_without_init
fonctionnent exactement commereduce_window
, sauf que leschedule
de l'reduce
sous-jacent (voir reduce) n'inclut pas les valeurs d'initialisation. Il n'est actuellement pas spécifié ce qui se passe si la fenêtre correspondante ne comporte pas de valeurs (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
où :source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
siselected_values[source_index]
contient l'élémentoperand
deoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1-C4), (C6), (C8-C11) |
(I2) | source |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C2) |
(I3) | init_value |
Tensor à 0 dimensions ou Tensor quantifié par Tensor | (C3) |
(I4) | window_dimensions |
Constante de tenseur à dimension 1 de type si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Constante de tenseur à dimension 1 de type si64 |
(C2), (C6), (C7) |
(I6) | padding |
Constante de tenseur bidimensionnel de type si64 |
(C2), (C8) |
(I7) | select |
fonction | (C9) |
(I8) | scatter |
fonction | (C10) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C11-C12) |
Contraintes
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
où :padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
est de type(tensor<E>, tensor<E>) -> tensor<i1>
, oùE = element_type(operand)
. - (C10)
scatter
est de type(tensor<E>, tensor<E>) -> tensor<E>
, oùis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Exemples
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
envoyer
Sémantique
Envoie inputs
à un canal channel_id
et génère un jeton result
.
Si is_host_transfer
est défini sur true
, l'opération transfère les données à l'hôte. Sinon, il transfère les données vers un autre appareil. Cela signifie que la valeur est définie par l'implémentation. Cet indicateur duplique les informations fournies dans channel_type
. Nous prévoyons donc de n'en conserver qu'un seul (#666).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
Nombre variable de tenseurs ou de tenseurs quantifiés | |
(I2) | token |
token |
|
(I3) | channel_id |
constante de type si64 |
|
(I4) | channel_type |
énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
constante de type i1 |
(C1) |
Sorties
Nom | Type |
---|---|
result |
token |
Contraintes
- (C1)
channel_type
est défini comme suit :DEVICE_TO_HOST
siis_host_transfer = true
,- Sinon,
DEVICE_TO_DEVICE
.
Exemples
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Sémantique
Effectue une opération de décalage à gauche au niveau des éléments sur le tenseur lhs
d'un nombre de bits rhs
et produit un tenseur result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
tenseur de type entier | (C1) |
(I2) | rhs |
tenseur de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Sémantique
Effectue une opération de décalage à droite arithmétique par élément sur le tenseur lhs
d'un nombre de bits rhs
et produit un tenseur result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
tenseur de type entier | (C1) |
(I2) | rhs |
tenseur de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Sémantique
Effectue une opération de décalage logique à droite par élément sur le tenseur lhs
d'un nombre de bits rhs
et produit un tenseur result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
tenseur de type entier | (C1) |
(I2) | rhs |
tenseur de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
signer
Sémantique
Renvoie le signe de l'operand
par élément et produit un tenseur result
.
Plus formellement, pour chaque élément x
, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Pour les types quantifiés, exécute dequantize_op_quantize(sign, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tenseur de type entier signé, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Sémantique
Effectue une opération sinusoïdale au niveau des éléments sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
sin
d'IEEE-754. - Pour les nombres complexes: sinus complexe.
- Pour les types quantifiés:
dequantize_op_quantize(sine, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Sémantique
Extraction d'une tranche de l'operand
à l'aide d'indices de début calculés de manière statique et production d'un tenseur result
. start_indices
contient les indices de début de la tranche pour chaque dimension, limit_indices
contient les indices de fin (exclusifs) de la tranche pour chaque dimension et strides
contient les pas pour chaque dimension.
Plus formellement, result[result_index] = operand[operand_index]
, où operand_index = start_indices + result_index * strides
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié par tenseur ; | (C1-C3), (C5) |
(I2) | start_indices |
Constante de tenseur à dimension 1 de type si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Constante de tenseur à dimension 1 de type si64 |
(C2), (C3), (C5) |
(I4) | strides |
Constante de tenseur à dimension 1 de type si64 |
(C2), (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié par tenseur ; | (C1), (C5) |
Contraintes
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Exemples
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
trier
Sémantique
Trie les tranches unidimensionnelles de inputs
le long de la dimension dimension
,
selon une valeur comparator
, et génère results
.
Contrairement aux entrées similaires d'autres opérations, dimension
autorise les valeurs négatives, avec la sémantique décrite ci-dessous. À l'avenir, cela pourrait être interdit pour des raisons de cohérence (numéro 1377).
Si is_stable
est défini sur "true", le tri est stable, c'est-à-dire que l'ordre relatif des éléments considérés comme égaux par le comparateur est préservé. Dans le cas où il n'y a qu'une seule entrée, deux éléments e1
et e2
sont considérés comme égaux par le comparateur si et seulement si comparator(e1, e2) = comparator(e2, e1) = false
. Consultez la formalisation ci-dessous pour voir comment cela se généralise à plusieurs entrées.
Plus formellement, pour tous les result_index
de index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, oùriN
sont des éléments individuels dansresult_index
et:
est inséré àadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- où
sort
trie une tranche unidimensionnelle dans l'ordre non descendant, en supposant quecomparator_together
renvoietrue
si l'argument de gauche est inférieur au deuxième argument de droite. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C1-C5) |
(I2) | dimension |
constante de type si64 |
(C4) |
(I3) | is_stable |
constante de type i1 |
|
(I4) | comparator |
fonction | (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variable de tenseurs ou tenseurs quantifiés par tenseur | (C2), (C3) |
Contraintes
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, oùR = rank(inputs[0])
. - (C5)
comparator
est de type(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, oùEi = element_type(inputs[i])
.
Exemples
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Sémantique
Effectue une racine carrée élément par élément sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante :
squareRoot
d'IEEE-754. - Pour les nombres complexes : racine carrée complexe.
- Pour les types quantifiés :
dequantize_op_quantize(sqrt, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Sémantique
Effectue la soustraction par élément de deux Tensors lhs
et rhs
, et produit un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les entiers: soustraction d'entiers.
- Pour les nombres à virgule flottante :
subtraction
d'IEEE-754. - Pour les nombres complexes : soustraction complexe.
- Pour les types quantifiés :
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type entier, à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Sémantique
Effectue une opération tangente au niveau des éléments sur le tenseur operand
et produit un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
tan
d'IEEE-754. - Pour les nombres complexes: tangente complexe.
- Pour les types quantifiés :
dequantize_op_quantize(tan, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Sémantique
Effectue une opération de tangente hyperbolique par élément sur le Tensor operand
et produit un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les nombres à virgule flottante :
tanh
d'IEEE-754. - Pour les nombres complexes: tangente hyperbolique complexe.
- Pour les types quantifiés :
dequantize_op_quantize(tanh, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transposer
Sémantique
Permute les dimensions du tenseur operand
à l'aide de permutation
et génère un tenseur result
. Plus formellement, result[result_index] = operand[operand_index]
, où result_index[d] = operand_index[permutation[d]]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
un tenseur ou un tenseur quantifié ; | (C1-C4) |
(I2) | permutation |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur ou un tenseur quantifié ; | (C1), (C3-C4) |
Contraintes
- (C1)
element_type(result)
est donné par :element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, sauf quequantization_dimension(operand)
etquantization_dimension(result)
peuvent être différents.
- (C2)
permutation
est une permutation derange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Si
is_per_axis_quantized(result)
, alorsquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Exemples
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Sémantique
Résout des lots de systèmes d'équations linéaires avec des matrices de coefficients triangulaires inférieures ou supérieures.
Plus formellement, étant donné a
et b
, result[i0, ..., iR-3, :, :]
est la solution de op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
lorsque left_side
est true
ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
lorsque left_side
est false
, en résolvant la variable x
où op(a)
est déterminée par transpose_a
, qui peut être l'une des valeurs suivantes:
NO_TRANSPOSE
: effectuez l'opération à l'aide dea
tel quel.TRANSPOSE
: effectue une opération sur la transposition dea
.ADJOINT
: effectue une opération sur la transposée conjuguée dea
.
Les données d'entrée ne sont lues que dans le triangle inférieur de a
, si lower
est true
ou dans le triangle supérieur de a
, dans le cas contraire. Les données de sortie sont renvoyées dans le même triangle. Les valeurs de l'autre triangle sont définies par l'implémentation.
Si unit_diagonal
est vrai, l'implémentation peut supposer que les éléments de la diagonale de a
sont égaux à 1, sinon le comportement est indéfini.
Pour les types quantifiés, exécute dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1-C3) |
(I2) | b |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1-C4) |
(I3) | left_side |
constante de type i1 |
(C3) |
(I4) | lower |
constante de type i1 |
|
(I5) | unit_diagonal |
constante de type i1 |
|
(I6) | transpose_a |
Énumération de NO_TRANSPOSE , TRANSPOSE et ADJOINT |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
un tenseur de type à virgule flottante ou complexe, ou un tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) La relation entre
shape(a)
etshape(b)
est définie comme suit :shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Exemples
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Sémantique
Génère un tuple result
à partir des valeurs val
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | val |
nombre de valeurs variable | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tuple | (C1) |
Contraintes
- (C1)
result
est de typetuple<E0, ..., EN-1>
, oùEi = type(val[i])
.
Exemples
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Sémantique
Effectue la conversion élément par élément du tenseur quantifié operand
en tenseur à virgule flottante result
, conformément aux paramètres de quantification définis par le type operand
.
Plus formellement, result = dequantize(operand)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tenseur quantifié | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Exemples
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Sémantique
Effectue une conversion par élément du tenseur à virgule flottante ou du tenseur quantifié operand
en tenseur quantifié result
, conformément aux paramètres de quantification définis par le type result
.
Plus formellement,
- Si
is_float(operand)
:result = quantize(operand, type(result))
.
- Si
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tenseur de type à virgule flottante ou quantifié | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tenseur quantifié | (C1), (C2) |
Contraintes
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Exemples
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
tandis que
Sémantique
Génère la sortie de l'exécution de la fonction body
au moins une fois lorsque la fonction cond
génère true
. Plus formellement, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Le comportement d'une boucle infinie est à déterminer (numéro 383).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Nombre variable de tenseurs, de tenseurs quantifiés ou de jetons | (C1-C3) |
(I2) | cond |
fonction | (C1) |
(I3) | body |
fonction | (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C3) |
Contraintes
- (C1)
cond
est de type(T0, ..., TN-1) -> tensor<i1>
, oùTi = type(operand[i])
. - (C2)
body
est de type(T0, ..., TN-1) -> (T0, ..., TN-1)
, oùTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Exemples
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Sémantique
Effectue une opération XOR élément par élément de deux tenseurs lhs
et rhs
, et génère un tenseur result
. En fonction du type d'élément, effectue les opérations suivantes :
- Pour les valeurs booléennes : XOR logique.
- Pour les entiers : XOR (OU exclusif) bit à bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
un tenseur de type booléen ou entier ; | (C1) |
(I2) | rhs |
un tenseur de type booléen ou entier ; | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interopérabilité des dialectes
Pour le moment, les programmes StableHLO dans la nature contiennent parfois des opérations qui ne sont pas définies par StableHLO.
Module, fonction, appel et retour
StableHLO utilise les opérations MLIR en amont pour ModuleOp, FuncOp, CallOp et ReturnOp. Cela a été fait pour améliorer l'interopérabilité avec le mécanisme MLIR existant, car de nombreuses passes utiles sont écrites en ciblant FuncOp et ModuleOp, et de nombreux pipelines de compilation s'attendent à ce que ces opérations soient présentes. Des garanties de compatibilité complète sont appliquées à ces opérations. Si ces opérations changent de manière incompatible (par exemple, suppression), des équivalents StableHLO seront ajoutés pour préserver la compatibilité.
CHLO
L'ensemble d'opérations CHLO contient des opérations de niveau supérieur qui se décomposent en StableHLO. Actuellement, aucune garantie de compatibilité n'est proposée pour les CHLO. Pour garantir la compatibilité, la passe clo-legalize-to-stablehlo doit être utilisée avant la sérialisation.
Opérations de forme
Dans la communauté, il est courant d'utiliser certaines opérations des dialectes MLIR de base dans des programmes StableHLO dynamiques pour effectuer des calculs de forme.
Le plus souvent, il s'agit d'opérations de dialecte shape
telles que shape_of
ou num_elements
, d'opérations de dialecte tensor
telles que dim
ou from_elements
, et du type index
intégré.
Le document Dynamism RFC > O2 les indique comme étant hors du champ d'application, mais certains types de prise en charge des types index
sont inclus à des fins d'interopérabilité. Aucune garantie de compatibilité n'est fournie pour ces opérations ou types. La passe shape-legalize-to-stablehlo peut être utilisée pour convertir ces opérations en opérations StableHLO entièrement compatibles.
Opérations obsolètes
Plusieurs opérations StableHLO héritées de MHLO sont obsolètes et en cours d'abandon de StableHLO. Pour en savoir plus sur ces suppressions, consultez la page Nettoyage de StableHLO v1.0 2283. Le problème de suivi de ces abandons est le n° 2340.
Ces opérations se répartissent en plusieurs catégories:
- Catégorie "Not in HLO" (Pas dans HLO) des opérations StableHLO : elles faisaient initialement partie de l'ensemble d'opérations StableHLO, mais ont ensuite été jugées inadaptées :
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(n° 3). - Opérations inutilisées : ces opérations ont peut-être été utiles à un moment donné, mais elles n'étaient pas suffisamment développées ou les pipelines qui les utilisaient ont été refactorisés pour ne plus en avoir besoin. Cela inclut
map
,tuple
(598),get_tuple_element
,rng
, les comparaisonscomplex
560 et la convolutionwindow_reversal
(1181).
Certaines de ces opérations peuvent être facilement supprimées, car elles peuvent être exprimées à l'aide d'opérations existantes (broadcast
, create_token
, cross-replica-sum
, dot
, unary_einsum
) et seront supprimées une fois la période de compatibilité existante écoulée (six mois). D'autres sont encore à l'étude pour être supprimées (einsum
, get_tuple_element
, map
, rng
, torch_index_select
, tuple
, complex
, comparaisons, window_reversal
). En attendant les commentaires de la communauté, ces opérations seront supprimées ou ajoutées à la spécification avec une prise en charge complète. Tant que ces futures opérations ne sont pas connues, leur compatibilité n'est garantie que pendant six mois.
Exécution
Exécution séquentielle
Un programme StableHLO est exécuté en fournissant des valeurs d'entrée à la fonction main
et en calculant les valeurs de sortie. Les valeurs de sortie d'une fonction sont calculées en exécutant le graphique des opérations enracinées dans l'opération return
correspondante.
L'ordre d'exécution est défini par l'implémentation tant qu'il est aligné sur le flux de données, c'est-à-dire si les opérations sont exécutées avant leur utilisation. Dans StableHLO, toutes les opérations à effet secondaire consomment un jeton et en produisent un (plusieurs jetons peuvent être multiplexés en un seul jeton via after_all
). L'ordre d'exécution des effets secondaires est donc également aligné sur le flux de données. Par exemple, dans le programme ci-dessous, deux ordres d'exécution sont possibles: %0
→ %1
→ %2
→ return
et %1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Plus formellement, un processus StableHLO est une combinaison de : 1) un programme StableHLO, 2) des états d'exécution (non encore exécuté, déjà exécuté) et 3) des valeurs intermédiaires sur lesquelles le processus travaille.
Le processus commence par les valeurs d'entrée de la fonction main
, passe par le graphique des opérations mettant à jour les états des opérations et les valeurs intermédiaires, puis se termine par les valeurs de sortie. À déterminer pour une formalisation plus poussée (#484)
Exécution parallèle
Les programmes StableHLO peuvent être exécutés en parallèle, organisés dans une grille de processus 2D de num_replicas
par num_partitions
, qui ont tous deux le type ui32
.
Dans la grille de processus StableHLO, num_replicas * num_partitions
des processus StableHLO s'exécutent en même temps. Chaque processus possède un process_id = (replica_id, partition_id)
unique, où replica_id
dans replica_ids = range(num_replicas)
et partition_id
dans partition_ids = range(num_partitions)
, qui sont tous deux de type ui32
.
La taille de la grille de processus est connue de manière statique pour chaque programme (à l'avenir, nous prévoyons de la rendre explicite dans les programmes StableHLO 650), et la position dans la grille de processus est connue de manière statique pour chaque processus. Chaque processus a accès à sa position dans la grille de processus via les opérations replica_id
et partition_id
.
Dans la grille de processus, les programmes peuvent tous être identiques (dans le style "Programme unique, données multiples"), tous différents (dans le style "Programmes multiples, données multiples") ou quelque chose entre les deux. À l'avenir, nous prévoyons de prendre en charge d'autres idiomes permettant de définir des programmes StableHLO parallèles, y compris GSPMD (#619).
Dans la grille de processus, les processus sont pour la plupart indépendants les uns des autres. Ils ont des états d'opération distincts, des valeurs d'entrée/intermédiaire/sortie distinctes, et la plupart des opérations sont exécutées séparément entre les processus, à l'exception d'un petit nombre d'opérations collectives décrites ci-dessous.
Étant donné que l'exécution de la plupart des opérations n'utilise que les valeurs du même processus, il est généralement clair de faire référence à ces valeurs par leur nom.
Toutefois, lorsque vous décrivez la sémantique des opérations collectives, cela est insuffisant, et la notation name@process_id
est utilisée pour faire référence à la valeur name
dans un processus particulier. (Dans cette perspective, name
non qualifié peut être considéré comme un raccourci pour name@(replica_id(), partition_id())
.)
L'ordre d'exécution entre les processus est défini par l'implémentation, à l'exception de la synchronisation introduite par la communication point à point et les opérations collectives, comme décrit ci-dessous.
Communication point à point
Les processus StableHLO peuvent communiquer entre eux via des canaux StableHLO. Un canal est représenté par un identifiant positif de type si64
. Grâce à diverses opérations, il est possible d'envoyer des valeurs aux canaux et de les recevoir de ces derniers.
Une formalisation plus poussée, par exemple sur l'origine de ces ID de canal, la façon dont les programmes de processus en prennent connaissance et le type de synchronisation qu'ils introduisent, est à définir (numéro 484).
Communication en streaming
Chaque processus StableHLO a accès à deux interfaces de streaming :
- Infeed, qui peut être lu.
- Flux de sortie dans lequel des opérations d'écriture peuvent être effectuées.
Contrairement aux canaux, qui sont utilisés pour communiquer entre les processus et qui ont donc des processus à leurs deux extrémités, les flux entrants et sortants ont leur autre extrémité définie par l'implémentation.
Une formalisation plus poussée, par exemple sur l'impact de la communication en streaming sur l'ordre d'exécution et le type de synchronisation qu'elle introduit, est à définir (numéro 484).
Opérations collectives
StableHLO propose six opérations collectives: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
et reduce_scatter
. Toutes ces opérations divisent les processus de la grille de processus StableHLO en groupes de processus StableHLO et exécutent un calcul conjoint dans chaque groupe de processus, indépendamment des autres groupes de processus.
Dans chaque groupe de processus, les opérations collectives peuvent entraîner une barrière de synchronisation. Une formalisation plus poussée, par exemple en précisant le moment exact où cette synchronisation se produit, comment les processus arrivent exactement à cette barrière et ce qui se passe s'ils ne le font pas, est à déterminer (numéro 484).
Si le groupe de processus implique une communication interpartition, c'est-à-dire que des processus du groupe de processus ont des ID de partition différents, l'exécution de l'opération collective nécessite un canal, et l'opération collective doit fournir un channel_id
positif de type si64
. La communication entre les réplications n'a pas besoin de canaux.
Les calculs effectués par les opérations collectives sont spécifiques aux opérations individuelles et sont décrits dans les sections d'opérations individuelles ci-dessus. Toutefois, les stratégies selon lesquelles la grille de processus est divisée en groupes de processus sont partagées entre ces opérations et sont décrites dans cette section. Plus formellement, StableHLO prend en charge les quatre stratégies suivantes.
cross_replica
Seules les communications interrépliques ont lieu au sein de chaque groupe de processus. Cette stratégie prend replica_groups
(une liste de listes d'ID de réplicas) et calcule un produit cartésien de replica_groups
par partition_ids
. Les replica_groups
doivent comporter des éléments uniques et couvrir tous les replica_ids
. Plus formellement, en utilisant
la syntaxe Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour replica_groups = [[0, 1], [2, 3]]
et num_partitions = 2
, cross_replica
génère [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Seules les communications interpartitions ont lieu au sein de chaque groupe de processus. Cette stratégie prend partition_groups
(une liste de listes d'ID de partition) et calcule un produit cartésien de partition_groups
par replica_ids
.
Les partition_groups
doivent comporter des éléments uniques et couvrir tous les partition_ids
.
Plus formellement, en utilisant la syntaxe Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour partition_groups = [[0, 1]]
et num_replicas = 4
, cross_partition
produit [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Des communications inter-réplicas et inter-partitions peuvent se produire dans chaque groupe de processus. Cette stratégie prend replica_groups
, une liste de listes d'ID de réplication, et calcule les produits cartésiens de chaque replica_group
par partition_ids
. Les replica_groups
doivent comporter des éléments uniques et couvrir tous les replica_ids
. Plus formellement, en utilisant la syntaxe Python :
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour replica_groups = [[0, 1], [2, 3]]
et num_partitions = 2
, cross_replica_and_partition
produit [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Cette stratégie prend flattened_id_groups
(une liste de listes d'ID de processus "aplatis" sous la forme de replica_id * num_partitions + partition_id
) et les transforme en ID de processus. flattened_id_groups
doit comporter des éléments uniques et couvrir tous les process_ids
. Plus formellement, en utilisant la syntaxe Python :
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
et num_partitions = 2
, flattened_ids
produit [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Précision
Pour le moment, StableHLO ne fournit aucune garantie concernant l'exactitude numérique, mais cela peut changer à l'avenir (numéro 1156).
Sémantique d'exécution de l'opération quantifiée
L'interprétation des opérations StableHLO quantifiées peut varier en fonction des capacités et de la configuration matérielle requise. Par exemple, certains matériels peuvent choisir d'interpréter les opérations quantiques à l'aide d'une stratégie de "déquantisation, exécution d'une opération à virgule flottante et, enfin, quantification". D'autres peuvent effectuer l'intégralité du calcul avec une méthode arithmétique d'entiers. Par conséquent, l'interprétation des opérations StableHLO échantillonnées est déterminée exclusivement par l'implémentation spécifique. L'interprétation de la quantification hybride (#1575) doit être basée sur sa sémantique, telle que prescrite dans la spécification (via 1792).
Erreurs
Les programmes StableHLO sont validés via un ensemble étendu de contraintes pour les opérations individuelles, ce qui exclut de nombreuses classes d'erreurs avant l'exécution. Toutefois, les conditions d'erreur sont toujours possibles, par exemple via des débordements d'entiers, des accès hors limites, etc. Sauf mention explicite, toutes ces erreurs entraînent un comportement défini par l'implémentation, mais cela peut changer à l'avenir (numéro 1157).
Exceptions à virgule flottante
À titre d'exception à cette règle, les exceptions à virgule flottante dans les programmes StableHLO ont un comportement bien défini. Les opérations qui génèrent des exceptions définies par la norme IEEE-754 (opération non valide, division par zéro, débordement, sous-dépassement ou exceptions inexactes) génèrent des résultats par défaut (tels que définis dans la norme) et poursuivent l'exécution sans lever le drapeau d'état correspondant, semblable à la gestion des exceptions raiseNoFlag
de la norme. Les exceptions pour les opérations non standards (par exemple, les opérations arithmétiques complexes et certaines fonctions transcendantes) sont définies par l'implémentation.
Incohérences de forme
StableHLO accepte les Tensors de forme dynamique. Toutefois, les formes doivent être cohérentes au moment de l'exécution, sinon le comportement n'est pas défini. StableHLO ne fournit pas explicitement une opération pouvant affirmer qu'un tenseur a une forme donnée au moment de l'exécution. La génération du code correct relève de la responsabilité du producteur.
Par exemple, le programme ci-dessous est valide. Toutefois, au moment de l'exécution, les formes exactes de %arg0
et %arg1
doivent être identiques, sinon le comportement du programme n'est pas défini:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Pour décrire la syntaxe, ce document utilise le type ISO modifié de la syntaxe EBNF (ISO/IEC 14977:1996, Wikipédia), avec deux modifications: 1) les règles sont définies à l'aide de ::=
plutôt que de =
,
2) La concaténation est exprimée à l'aide de la juxtaposition plutôt que de ,
.
Pour décrire la sémantique (c'est-à-dire dans les sections "Types", "Constants" et "Ops"), nous utilisons des formules basées sur la syntaxe Python étendue avec la prise en charge de l'expression concise des opérations de tableau, comme décrit ci-dessous. Cela fonctionne bien pour les petits extraits de code, mais dans de rares cas où des extraits de code plus volumineux sont nécessaires, nous utilisons la syntaxe Python standard, qui est toujours introduite explicitement.
Formules
Découvrons comment fonctionnent les formules à l'aide d'un exemple tiré des spécifications dot_general
. L'une des contraintes de cette opération se présente comme suit :
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Les noms utilisés dans cette formule proviennent de deux sources: 1) les fonctions globales, c'est-à-dire dim
, 2) les définitions des membres de l'élément de programme correspondant, c'est-à-dire les entrées lhs
, lhs_batching_dimensions
, rhs
et rhs_batching_dimensions
définies dans la section "Inputs" (Entrées) de dot_general
.
Comme indiqué ci-dessus, la syntaxe de cette formule est basée sur Python, avec certaines extensions axées sur la concision. Pour comprendre la formule, transformons-la en syntaxe Python standard.
A) Dans ces formules, nous utilisons =
pour représenter l'égalité. La première étape pour obtenir la syntaxe Python consiste donc à remplacer =
par ==
, comme suit : dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) De plus, ces formules acceptent les points de suspension (...
) qui transforment les expressions scalaires en expressions de Tensor. En résumé, f(xs...)
signifie approximativement "pour chaque scalaire x
dans le tenseur xs
, calculer un scalaire f(x)
, puis renvoyer tous ces résultats scalaires ensemble en tant que résultat de tenseur". Dans la syntaxe Python standard, notre exemple de formule se transforme en : [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Grâce aux ellipses, il est souvent possible d'éviter de travailler au niveau des scalaires individuels. Toutefois, dans certains cas délicats, une syntaxe semi-informelle de niveau inférieur peut être utilisée, comme dans la formule start_indices[bi0, ..., :, ..., biN]
de la spécification gather
. Par souci de concision, nous ne fournissons pas de formalisme exact pour traduire cette syntaxe en Python standard, dans l'espoir qu'elle reste intuitivement compréhensible au cas par cas.
N'hésitez pas à nous contacter si certaines formules spécifiques semblent opaques. Nous ferons de notre mieux pour les améliorer.
Vous remarquerez également que les formules utilisent des points de suspension pour développer toutes sortes de listes, y compris des tenseurs, des listes de tenseurs (qui peuvent par exemple provenir d'un nombre variadique de tenseurs), etc. Il s'agit d'un autre domaine dans lequel nous ne fournissons pas de formalisme exact (par exemple, les listes ne font même pas partie du système de types StableHLO) et nous nous appuyons plutôt sur une compréhension intuitive.
C) Le dernier mécanisme de notation que nous utilisons est la diffusion implicite. Bien que l'opset StableHLO n'accepte pas la diffusion implicite, les formules le font, également au service de la concision. En résumé, si un scalaire est utilisé dans un contexte où un tenseur est attendu, le scalaire est diffusé dans la forme attendue.
Pour poursuivre l'exemple dot_general
, voici une autre contrainte : 0 <= lhs_batching_dimensions < rank(lhs)
. Comme défini dans la spécification dot_general
, lhs_batching_dimensions
est un tenseur, mais 0
et rank(lhs)
sont des scalaires. Une fois la diffusion implicite appliquée, la formule devient [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Lorsqu'elle est appliquée à une opération dot_general
particulière, cette formule est évaluée selon un Tensor de valeurs booléennes. Lorsque des formules sont utilisées comme contraintes, la contrainte est valide si la formule s'évalue à true
ou à un tenseur qui ne comporte que des éléments true
.
Noms
Dans les formules, le champ lexical inclut : 1) les fonctions globales, 2) les définitions de membres,
3) définitions locales. La liste des fonctions globales est fournie ci-dessous. La liste des définitions d'éléments dépend de l'élément du programme auquel la notation est appliquée:
- Pour les opérations, les définitions des membres incluent les noms introduits dans les sections "Entrées" et "Sorties".
- Pour tout le reste, les définitions de membres incluent des parties structurelles de l'élément de programme, nommées d'après les non-terminaux EBNF correspondants. La plupart du temps, les noms de ces parties structurelles sont obtenus en convertissant les noms des non-terminaux en snake case (par exemple,
IntegerLiteral
=>integer_literal
), mais parfois les noms sont abrégés dans le processus (par exemple,QuantizationStorageType
=>storage_type
). Dans ce cas, les noms sont introduits explicitement de la même manière que les sections "Inputs" / "Outputs" dans les opérations et les sorties. - De plus, les définitions de membres incluent toujours
self
pour faire référence à l'élément de programme correspondant.
Valeurs
Lorsque les formules sont évaluées, elles fonctionnent avec les types de valeurs suivants :
1) Value
(valeurs réelles, par exemple dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
; leurs types sont toujours connus),
2) Placeholder
(valeurs futures, par exemple lhs
, rhs
ou result
; leurs valeurs réelles ne sont pas encore connues, seuls leurs types le sont),
3) Type
(types tels que définis dans la section "Types"),
4) Function
(fonctions globales telles que définies dans la section "Fonctions").
Selon le contexte, les noms peuvent faire référence à différentes valeurs. Plus précisément, la section "Sémantique" des opérations (et des équivalents pour les autres éléments de programme) définit la logique d'exécution. Toutes les entrées sont donc disponibles en tant que Value
.
En revanche, la section "Contraintes" des opérations (et des équivalents) définit une logique "au moment de la compilation", c'est-à-dire quelque chose qui est généralement exécuté avant l'exécution. Par conséquent, seules les entrées constantes sont disponibles en tant que Value
et les autres entrées ne sont disponibles qu'en tant que Placeholder
.
Noms | Dans "Sémantique" | Dans "Contraintes" |
---|---|---|
Fonctions globales | Function |
Function |
Entrées constantes | Value |
Value |
Entrées non constantes | Value |
Placeholder |
Sorties | Value |
Placeholder |
Définitions locales | Cela dépend de la définition | Cela dépend de la définition |
Prenons un exemple d'opération transpose
:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Pour cette opération, permutation
est une constante. Il est donc disponible en tant que Value
, tant au niveau de la sémantique que des contraintes. En revanche, operand
et result
sont disponibles en tant que Value
en sémantique, mais uniquement en tant que Placeholder
dans les contraintes.
Fonctions
Construction de types
Aucune fonction ne peut être utilisée pour créer des types. À la place, nous utilisons directement la syntaxe de type, car elle est généralement plus concise. Par exemple, (tensor<E>, tensor<E>) -> (tensor<E>)
plutôt que function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Fonctions sur les types
element_type
est défini sur les types de tenseurs et les types de tenseurs quantifiés, et renvoie respectivement la partieTensorElementType
ouQuantizedTensorElementType
de l'TensorType
ou de l'QuantizedTensorType
correspondant.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
vérifie si le typex
peut être promu au typey
. Lorsquex
ety
sont desQuantizedTensorElementType
, la promotion ne s'applique qu'austorage_type
. Cette version spécifique de la promotion est actuellement utilisée dans le contexte du calcul de la réduction (pour en savoir plus, consultez la RFC).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Disponible pour tous les types. Par exemple,is_float(x)
renvoietrue
six
est unFloatType
. Six
est une valeur ou un espace réservé, cette fonction est un raccourci pouris_type_name(type(x))
.max_value(x: Type) -> Value
renvoie la valeur maximale d'unTensorElementType
. Six
n'est pas unTensorElementType
, renvoieNone
.min_value(x: Type) -> Value
renvoie la valeur minimale possible d'unTensorElementType
. Six
n'est pas unTensorElementType
, renvoieNone
.member_name(x: Value | Placeholder | Type) -> Any
. Disponible pour toutes les définitions de membresmember_name
de tous types. Par exemple,tensor_element_type(x)
renvoie la partieTensorElementType
d'unTensorType
correspondant. Six
est une valeur ou un espace réservé, cette fonction est un raccourci pourmember_name(type(x))
. Six
n'est pas un type qui possède un membre approprié, ou une valeur ou un espace réservé de ce type, renvoieNone
.is_empty_algorithm(*args: Type)
vérifie si tous les champs de l'algorithme de point sont définis surNone
. Cela est nécessaire, car les algorithmes de point ont des comportements par défaut définis par l'implémentation. Il serait donc incorrect de spécifier une valeur par défaut.
Construction des valeurs
operation_name(*xs: Value | Type) -> Value
. Disponible pour toutes les opérations. Par exemple,add(lhs, rhs)
prend deux valeurs de tenseurlhs
etrhs
, puis renvoie le résultat de l'évaluation de l'opérationadd
avec ces entrées. Pour certaines opérations (par exemple,broadcast_in_dim
), leurs sorties sont de type "chargeur", c'est-à-dire nécessaires pour évaluer une opération. Dans ce cas, la fonction utilise ces types comme arguments.
Fonctions sur les valeurs
Tous les opérateurs et fonctions de Python sont disponibles. Par exemple, les notations subscription (abonnement) et slicing (tranchage) de Python sont disponibles pour indexer des tenseurs, des tenseurs quantiques et des tupels.
to_destination_type(x: Value, destination_type: Type) -> Value
est défini sur les tenseurs et renvoie la valeur convertie dex
en fonction detype(x)
etdestination_type
comme suit:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Nous discutons actuellement de la fusion des opérations convert
, uniform_quantize
et uniform_dequantize
(numéro 1576).
Après la fusion, nous n'avons plus besoin de la fonction ci-dessus et pouvons utiliser le nom de l'opération pour convert
à la place.
is_nan(x: Value) -> Value
est défini sur les Tensors et renvoietrue
si tous les éléments dex
sontNaN
oufalse
dans les autres cas. Six
n'est pas un tenseur, renvoieNone
.is_sorted(x: Value) -> Value
est défini sur les tenseurs et renvoietrue
si les éléments dex
sont triés par ordre croissant par rapport à l'ordre lexicographique croissant de leurs indices, oufalse
dans le cas contraire. Six
n'est pas un tenseur, renvoieNone
.is_unique(x: Value) -> Value
est défini sur les tenseurs et renvoietrue
six
ne comporte pas d'éléments en double, oufalse
dans le cas contraire. Six
n'est pas un tenseur, renvoieNone
.member_name(x: Value) -> Any
est défini pour toutes les définitions de membresmember_name
de toutes les valeurs. Par exemple,real_part(x)
renvoie la partieRealPart
d'unComplexConstant
correspondant. Six
n'est pas une valeur associée à un membre approprié, renvoieNone
.same(x: Value) -> Value
est défini sur les Tensors et renvoietrue
si les éléments dex
sont tous égaux les uns aux autres, oufalse
dans le cas contraire. Si le tenseur ne comporte pas d'éléments, il est considéré comme "tous égaux les uns aux autres", c'est-à-dire que la fonction renvoietrue
. Six
n'est pas un Tensor, la fonction renvoieNone
.split(x: Value, num_results: Value, axis: Value) -> Value
est défini sur les tenseurs et renvoie des tranchesnum_results
dex
le long de l'axeaxis
. Six
n'est pas un tenseur oudim(x, axis) % num_results != 0
, renvoieNone
.is_defined_in_parent_scope(x: Value) -> Value
est défini sur des chaînes et renvoietrue
six
est le nom d'une fonction définie dans le même champ d'application que la fonction parente de l'opération concernée.is_namespaced_op_name(x: Value) -> Value
est défini sur des chaînes et renvoietrue
six
est un nom d'opération valide, c'est-à-dire qu'il respecte l'expression régulière suivante :[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Calculs de forme
axes(x: Value | Placeholder | Type) -> Value
est un raccourci pourrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
est un raccourci pourshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
est un raccourci pourlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
est défini sur les Tensors et renvoie les indexsize(x)
pour leTensorType
correspondant, trié par ordre lexicographique croissant, c'est-à-dire[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Six
n'est pas un type de tenseur, un type de tenseur quantifié, une valeur ou un espace réservé de l'un de ces types,None
est renvoyé.rank(x: Value | Placeholder | Type) -> Value
est un raccourci poursize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
est défini dans la section "Fonctions sur les types" viamember_name
.size(x: Value | Placeholder | Type) -> Value
est un raccourci pourreduce(lambda x, y: x * y, shape(x))
.
Calculs de quantification
def baseline_element_type(x: Value | Placeholder | Type) -> Type
est un raccourci pourelement_type(baseline_type(x))
.baseline_type
est défini sur les types de tenseurs et les types de tenseurs quantifiés, et les transforme en "ligne de base", c'est-à-dire en un type de la même forme, mais avec les paramètres de quantification du type d'élément réinitialisés aux valeurs par défaut. Il s'agit d'une astuce pratique pour comparer les types de Tensors et quantifiés de manière uniforme, ce qui est assez souvent nécessaire. Pour les types quantifiés, cela permet de comparer les types en ignorant les paramètres de quantification, c'est-à-dire queshape
,storage_type
,expressed_type
,storage_min
,storage_max
etquantization_dimension
(pour le type quantifié par axe) doivent tous correspondre, maisscales
etzero points
peuvent différer.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
est défini sur les types de tenseurs quantifiés et les convertit en types de tenseurs à virgule flottante. Pour ce faire, les éléments quantifiés, qui représentent les valeurs entières du type de stockage, sont convertis en valeurs à virgule flottante correspondantes du type exprimé à l'aide du point zéro et de la mise à l'échelle associés au type d'élément quantifié.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
est défini sur les types de tenseurs à virgule flottante et les transforme en types de tenseurs quantifiés. Cela se produit en convertissant les valeurs à virgule flottante du type exprimé en valeurs entières correspondantes du type de stockage en utilisant le point zéro et l'échelle associées au type d'élément quantifié.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
permet de spécifier des calculs par élément sur des tenseurs quantifiés. Il déquantifie, c'est-à-dire qu'il transforme les éléments quantifiés en types exprimés, puis effectue une opération, puis quantifie, c'est-à-dire qu'il transforme les résultats en types de stockage. Pour le moment, cette fonction ne fonctionne que pour la quantification par tenseur. La quantification par axe est en cours (numéro 1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
permet de spécifier une quantification uniquement pour les poids pour une opération hybride qui accepte la valeur gauche en virgule flottante et la valeur droite en types quantifiés. Il déquantifie les entrées quantifiées en types exprimés et effectue des calculs avec float. Le type d'élément du tenseur de gauche à virgule flottante et le type exprimé du tenseur de droite quantifié doivent être identiques.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Calculs de grille
cross_partition(replica_groups: Value) -> Value
. Consultez la section "cross_replica" ci-dessus.cross_replica(replica_groups: Value) -> Value
. Consultez la section "cross_replica" ci-dessus.cross_replica_and_partition(replica_groups: Value) -> Value
. Consultez la section "cross_replica_and_partition" ci-dessus.flattened_ids(replica_groups: Value) -> Value
. Consultez la section "flattened_ids" ci-dessus.
Dynamique
Les valeurs StableHLO peuvent avoir des tailles de dimension dynamiques, par exemple tensor<?xi64>
.
Toutefois, les valeurs StableHLO ne peuvent pas avoir un nombre dynamique de dimensions (dynamisme non classé, par exemple tensor<*xi64>
). Les opérandes et les résultats sont autorisés à utiliser des tailles de dimension dynamiques, même si des contraintes s'appliquent à ces tailles. Les contraintes sont vérifiées de manière statique si possible. Sinon, elles sont différées jusqu'à l'exécution, et les incohérences entraînent un comportement non défini. Vous trouverez des exemples ci-dessous.
Incohérences de forme pour les opérations unaires par élément
Prenons l'exemple du programme de jouets suivant:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Un tel programme est inhabituel, car il n'est pas courant de connaître la forme du résultat, mais pas celle de l'entrée. Néanmoins, il s'agit d'un programme StableHLO valide. Il n'est pas possible de valider de manière statique l'opération abs
dans ce programme, car la forme exacte de l'opérande est inconnue. Toutefois, les formes sont certainement compatibles, et cela peut être vérifié de manière statique : ?
peut s'avérer être 2
au moment de l'exécution, et il n'y aura aucun problème. Toutefois, ?
peut également s'avérer être un autre entier, auquel cas le comportement est indéfini.
Notez que si la taille d'une dimension est dynamique dans le résultat, il ne peut pas y avoir de comportement indéfini. En effet, il n'existe pas de taille "attendue", il ne peut donc pas y avoir de non-concordance.
Incohérences de forme pour les opérations binaires par élément
Prenons l'exemple de programme suivant :
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
En ce qui concerne les opérations binaires par élément, les formes des entrées et du résultat doivent correspondre au moment de l'exécution. Au moment de la compilation, les dimensions statiques doivent être égales, sinon elles doivent simplement être compatibles. Si n'importe quelle dimension est dynamique dans les entrées, il peut y avoir un comportement non défini au moment de l'exécution, car la taille dynamique peut ne pas correspondre à la taille correspondante dans l'autre opérande (statique ou dynamique). Si toutes les entrées sont statiques, le fait que le résultat soit dynamique ou non n'a pas d'importance: les dimensions connues de manière statique seront vérifiées de manière statique, et les dimensions dynamiques n'imposent aucune contrainte.
Incohérences de forme pour les opérations qui utilisent leur forme de sortie comme opérande
Prenons l'exemple du programme de jouets suivant:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Les valeurs de l'opérande de forme au moment de l'exécution doivent correspondre à la forme du résultat, sinon le comportement n'est pas défini. Autrement dit, au moment de l'exécution, %arg0
doit avoir une valeur de dense<[3, 4]> : tensor<2xi32>
. Si l'opérande de forme est une constante, cela peut être vérifié de manière statique. Si la forme du résultat est entièrement dynamique, il ne peut pas y avoir de divergence.