StableHLO est un ensemble d'opérations pour les opérations de haut niveau (HLO) dans les modèles de machine learning (ML). StableHLO sert de couche de portabilité entre différents frameworks et compilateurs de ML : les frameworks de ML qui produisent des programmes StableHLO sont compatibles avec les compilateurs de ML qui consomment des programmes StableHLO.
Notre objectif est de simplifier et d'accélérer le développement du ML en créant une meilleure interopérabilité entre les différents frameworks de ML (tels que TensorFlow, JAX et PyTorch) et les compilateurs de ML (tels que XLA et IREE). À cette fin, ce document fournit une spécification pour le langage de programmation StableHLO.
Cette spécification comporte trois sections principales. Tout d'abord, la section Programmes décrit la structure des programmes StableHLO, qui se composent de fonctions StableHLO, elles-mêmes composées d'opérations StableHLO. Dans cette structure, la section Ops spécifie la sémantique des opérations individuelles. La section Exécution fournit la sémantique pour toutes ces opérations exécutées ensemble dans un programme. Enfin, la section Notation aborde la notation utilisée tout au long de la spécification.
Pour afficher les spécifications d'une version précédente de StableHLO, ouvrez le dépôt au niveau de la version taguée qui vous intéresse. Par exemple, la spécification StableHLO v0.19.0. Pour afficher les modifications apportées à chaque mise à jour mineure de StableHLO, consultez le journal des versions dans VhloDialect.td.
Programmes
Program ::= {Func}
Les programmes StableHLO se composent d'un nombre arbitraire de fonctions StableHLO.
Vous trouverez ci-dessous un exemple de programme avec une fonction @main qui comporte trois entrées (%image, %weights et %bias) et une sortie. Le corps de la fonction comporte six opérations.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Fonctions
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Les fonctions StableHLO (également appelées fonctions nommées) comportent un identifiant, des entrées/sorties et un corps. À l'avenir, nous prévoyons d'introduire des métadonnées supplémentaires pour les fonctions afin d'améliorer la compatibilité avec HLO (#425, #626, #740, #744).
Identifiants
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Les identifiants StableHLO sont semblables aux identifiants de nombreux langages de programmation, avec deux particularités : 1) tous les identifiants ont des sigles qui distinguent les différents types d'identifiants, 2) les identifiants de valeur peuvent être entièrement numériques pour simplifier la génération de programmes StableHLO.
Types
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Les types StableHLO sont classés dans les types de valeurs (également appelés types de première classe) qui représentent les valeurs StableHLO et les types non-valeurs qui décrivent d'autres éléments du programme. Les types StableHLO sont semblables à ceux de nombreux langages de programmation, la principale particularité étant la nature spécifique au domaine de StableHLO, qui entraîne des résultats inhabituels (par exemple, les types scalaires ne sont pas des types de valeurs).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Les types Tensor représentent des tenseurs, c'est-à-dire des tableaux multidimensionnels. Ils ont une forme et un type d'élément, où une forme représente des tailles de dimensions non négatives ou inconnues dans l'ordre croissant des dimensions correspondantes (également appelées axes) numérotées de 0 à R-1. Le nombre de dimensions R est appelé rang. Par exemple, tensor<2x3xf32> est un type de Tensor avec une forme 2x3 et un type d'élément f32. Il comporte deux dimensions (ou, en d'autres termes, deux axes) : la dimension 0 et la dimension 1, dont les tailles sont respectivement de 2 et 3. Son rang est 2.
Les formes peuvent être partiellement ou complètement inconnues (dynamiques), par exemple tensor<?x2xf64> est partiellement inconnue et tensor<?x?xf64> est complètement inconnue. Les tailles de dimensions dynamiques sont représentées à l'aide d'un ?. Il n'est pas possible de supprimer le classement des formes.
À l'avenir, nous prévoyons d'étendre les types de Tensor au-delà des tailles de dimension et des types d'éléments, par exemple pour inclure les mises en page (#629) et la parcimonie (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| Nom | Type | Contraintes |
|---|---|---|
storage_type |
type entier | (C1-C3), (C8) |
storage_min |
constante entière | (C1), (C3), (C7) |
storage_max |
constante entière | (C2), (C3), (C7) |
expressed_type |
type à virgule flottante | (C4) |
quantization_dimension |
constante entière facultative | (C10-C12) |
scales |
Nombre variadique de constantes à virgule flottante | (C4-C6), (C9), (C10), (C13) |
zero_points |
nombre variable de constantes entières | (C7-C9) |
Les types d'éléments quantifiés représentent des valeurs entières d'un type de stockage dans la plage allant de storage_min à storage_max (inclus) qui correspondent à des valeurs à virgule flottante d'un type exprimé. Pour une valeur entière donnée i, la valeur à virgule flottante correspondante f peut être calculée sous la forme f = (i - zero_point) * scale, où scale et zero_point sont appelés paramètres de quantification. Les éléments storage_min et storage_max sont facultatifs dans la grammaire, mais ont respectivement les valeurs par défaut min_value(storage_type) et max_value(storage_type). Les types d'éléments quantifiés sont soumis aux contraintes suivantes :
- (C1)
type(storage_min) = storage_type. - (C2)
type(storage_max) = storage_type. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type). - (C4)
type(scales...) = expressed_type. - (C5)
0 < scales. - (C6)
is_finite(scales...). - (C7)
storage_min <= zero_points <= storage_max. - (C8)
type(zero_points...) = storage_type. - (C9)
size(scales) = size(zero_points). - (C10) Si
is_empty(quantization_dimension), alorssize(scales) = 1. - (C11)
0 <= quantization_dimension.
Pour le moment, QuantizationScale est une constante à virgule flottante, mais il existe un fort intérêt pour les échelles basées sur des nombres entiers, représentées par des multiplicateurs et des décalages. Nous prévoyons d'étudier cette possibilité dans un avenir proche (#1404).
Une discussion est en cours sur la sémantique de QuantizationZeroPoint, y compris le type, les valeurs et la possibilité d'avoir un ou plusieurs points zéro dans un type de Tensor quantifié. En fonction des résultats de cette discussion, la spécification concernant les zéro points pourra être modifiée à l'avenir (#1405).
Une autre discussion en cours porte sur la sémantique de QuantizationStorageMin et QuantizationStorageMax afin de déterminer si des contraintes doivent être imposées à ces valeurs et à celles des Tensors quantifiés (#1406).
Enfin, nous prévoyons d'explorer la représentation des échelles inconnues et des points zéro, de la même manière que nous prévoyons d'explorer la représentation des tailles de dimensions inconnues (#1407).
Les types de tenseurs quantifiés représentent des tenseurs avec des éléments quantifiés. Ces Tensors sont exactement les mêmes que les Tensors standards, sauf que leurs éléments ont des types d'éléments quantifiés au lieu de types d'éléments standards.
Dans les Tensors quantifiés, la quantification peut être par Tensor, c'est-à-dire avec un scale et un zero_point pour l'ensemble du Tensor, ou par axe, c'est-à-dire avec plusieurs scales et zero_points, une paire par tranche d'une dimension particulière quantization_dimension. Plus formellement, dans un Tensor t avec quantification par axe, il existe des tranches dim(t, quantization_dimension) de quantization_dimension : t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], etc. Tous les éléments de la tranche i utilisent scales[i] et zero_points[i] comme paramètres de quantification. Les types de Tensor quantifiés sont soumis aux contraintes suivantes :
- Pour la quantification par Tensor :
- Aucune autre contrainte.
- Pour la quantification par axe :
- (C12)
quantization_dimension < rank(self). - (C13)
dim(self, quantization_dimension) = size(scales).
- (C12)
TokenType ::= 'token'
Les types de jetons représentent des jetons, c'est-à-dire des valeurs opaques produites et consommées par certaines opérations. Les jetons sont utilisés pour imposer un ordre d'exécution aux opérations, comme décrit dans la section Exécution.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Les types de tampon représentent des tampons. Par exemple, dans XLA, les tampons sont des tableaux multidimensionnels avec un stockage cohérent. Comme les types Tensor, les types Buffer ont une forme et un type d'élément, où une forme représente des tailles de dimension non négatives ou inconnues dans l'ordre croissant des dimensions correspondantes (également appelées axes) numérotées de 0 à R-1. Le nombre de dimensions R est appelé rang. Par exemple, memref<2x3xf32> est un type de tampon avec une forme 2x3 et un type d'élément f32. Il comporte deux dimensions (ou, en d'autres termes, deux axes) : la dimension 0 et la dimension 1, dont les tailles sont respectivement de 2 et 3. Son rang est 2.
Les tampons peuvent être alloués à l'aide d'un custom_call à CreateBuffer ou Pin, et désalloués à l'aide d'un custom_call à Unpin. Seules les opérations custom_call peuvent lire et écrire le contenu des tampons. Pour en savoir plus, consultez custom_call.
Les types de tuples représentent des tuples, c'est-à-dire des listes hétérogènes. Les tuples sont une fonctionnalité ancienne qui n'existe que pour la compatibilité avec HLO. Dans HLO, les tuples sont utilisés pour représenter les entrées et sorties variadiques. Dans StableHLO, les entrées et sorties variadiques sont prises en charge de manière native. La seule utilisation des tuples dans StableHLO est de représenter de manière exhaustive l'ABI HLO où, par exemple, T, tuple<T> et tuple<tuple<T>> peuvent être très différents selon une implémentation particulière. À l'avenir, nous prévoyons de modifier l'ABI HLO, ce qui pourrait nous permettre de supprimer les types de tuples de StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Les types d'éléments représentent les éléments des types Tensor. Contrairement à de nombreux langages de programmation, ces types ne sont pas de première classe dans StableHLO. Cela signifie que les programmes StableHLO ne peuvent pas représenter directement les valeurs de ces types (par conséquent, il est idiomatique de représenter les valeurs scalaires de type T avec des valeurs de tenseur de dimension 0 de type tensor<T>).
- Le type booléen représente les valeurs booléennes
trueetfalse. - Les types entiers peuvent être signés (
si) ou non signés (ui) et avoir l'une des largeurs de bits acceptées (2,4,8,16,32ou64). Les typessiNsignés représentent des valeurs entières allant de-2^(N-1)à2^(N-1)-1inclus, et les typesuiNnon signés représentent des valeurs entières allant de0à2^N-1inclus. - Les types à virgule flottante peuvent être l'un des suivants :
f8E3M4,f8E4M3etf8E5M2sont des nombres à virgule flottante de 8 bits suivant les conventions IEEE-754.- Types
f8E4M3FNetf8E5M2correspondant respectivement aux encodagesE4M3etE5M2du format FP8 décrits dans Formats FP8 pour le deep learning. - Types
f8E4M3FNUZetf8E5M2FNUZcorrespondant aux encodagesE4M3etE5M2des formats FP8 décrits dans Formats numériques 8 bits pour les réseaux de neurones profonds. - Type
f8E4M3B11FNUZcorrespondant à l'encodageE4M3des formats FP8 décrits dans Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks. - Type
bf16correspondant au formatbfloat16décrit dans BFloat16 : le secret des hautes performances sur les Cloud TPU. - Les types
f16,f32etf64correspondent respectivement aux formatsbinary16("demi-précision"),binary32("simple précision") etbinary64("double précision") décrits dans la norme IEEE 754. - Le type
tf32correspond au format TensorFloat32 et est limité dans StableHLO. - Types MX (microsizing)
f4E2M1FN,f6E2M3FN,f6E3M2FNetf8E8M0FNUdécrits dans la spécification des formats de microsizing OCP.
- Les types complexes représentent des valeurs complexes qui ont une partie réelle et une partie imaginaire du même type d'élément. Les types complexes acceptés sont
complex<f32>(les deux parties sont de typef32) etcomplex<f64>(les deux parties sont de typef64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Les types de fonction représentent les fonctions nommées et anonymes. Ils ont des types d'entrée (la liste des types à gauche de ->) et des types de sortie (la liste des types à droite de ->). Dans de nombreux langages de programmation, les types de fonctions sont de première classe, mais pas dans StableHLO.
StringType ::= 'string'
Le type String représente des séquences d'octets. Contrairement à de nombreux langages de programmation, le type de chaîne n'est pas de première classe dans StableHLO et n'est utilisé que pour spécifier des métadonnées statiques pour les éléments de programme.
Opérations
Les opérations StableHLO (également appelées ops) représentent un ensemble fermé d'opérations de haut niveau dans les modèles de machine learning. Comme indiqué ci-dessus, la syntaxe StableHLO s'inspire fortement de MLIR, qui n'est pas forcément l'alternative la plus ergonomique, mais qui est sans doute la mieux adaptée à l'objectif de StableHLO, qui est de créer une meilleure interopérabilité entre les frameworks de ML et les compilateurs de ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Les opérations StableHLO (également appelées ops) ont un nom, des entrées/sorties et une signature. Le nom se compose du préfixe stablehlo. et d'un mnémonique qui identifie de manière unique l'une des opérations compatibles. Vous trouverez ci-dessous la liste complète des opérations compatibles.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Les opérations consomment des entrées et produisent des sorties. Les entrées sont classées dans les catégories suivantes : valeurs d'entrée (calculées lors de l'exécution), fonctions d'entrée (fournies de manière statique, car dans StableHLO, les fonctions ne sont pas des valeurs de première classe) et attributs d'entrée (également fournis de manière statique). Le type d'entrées et de sorties consommées et produites par une opération dépend de son mnémonique. Par exemple, l'opération add consomme deux valeurs d'entrée et produit une valeur de sortie. En comparaison, l'opération select_and_scatter consomme trois valeurs d'entrée, deux fonctions d'entrée et trois attributs d'entrée.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Les fonctions d'entrée (également appelées fonctions anonymes) sont très semblables aux fonctions nommées, sauf qu'elles n'ont pas d'identifiant (d'où le nom "anonyme") et qu'elles ne déclarent pas de types de sortie (les types de sortie sont déduits de l'opération return dans la fonction).
La syntaxe des fonctions d'entrée inclut une partie actuellement inutilisée (voir la production Unused ci-dessus), qui est là pour la compatibilité avec MLIR. Dans MLIR, il existe un concept plus général de "régions" qui peuvent comporter plusieurs "blocs" d'opérations connectés entre eux par des opérations de saut. Ces blocs ont des ID qui correspondent à la production Unused, ce qui permet de les distinguer les uns des autres.
StableHLO ne comporte pas d'opérations de saut. La partie correspondante de la syntaxe MLIR n'est donc pas utilisée (mais elle est toujours présente).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Les attributs d'entrée ont un nom et une valeur qui correspond à l'une des constantes acceptées. Il s'agit du principal moyen de spécifier des métadonnées statiques pour les éléments de programme. Par exemple, l'opération concatenate utilise l'attribut dimension pour spécifier la dimension le long de laquelle ses valeurs d'entrée sont concaténées. De même, l'opération slice utilise plusieurs attributs tels que start_indices et limit_indices pour spécifier les limites utilisées pour segmenter la valeur d'entrée.
Pour le moment, les programmes StableHLO en circulation contiennent parfois des attributs qui ne sont pas décrits dans ce document. À l'avenir, nous prévoyons d'intégrer ces attributs à l'opset StableHLO ou de leur interdire d'apparaître dans les programmes StableHLO. En attendant, voici la liste de ces attributs :
layout(#629).mhlo.frontend_attributes(#628).mhlo.sharding(#619).output_operand_aliases(#740).- Métadonnées de localisation (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La signature d'opération se compose des types de toutes les valeurs d'entrée (la liste des types à gauche de ->) et des types de toutes les valeurs de sortie (la liste des types à droite de ->). À proprement parler, les types d'entrée sont redondants, et les types de sortie le sont presque toujours également (car pour la plupart des opérations StableHLO, les types de sortie peuvent être déduits des entrées). Néanmoins, la signature op fait délibérément partie de la syntaxe StableHLO pour assurer la compatibilité avec MLIR.
Vous trouverez ci-dessous un exemple d'op dont le mnémonique est select_and_scatter. Il consomme trois valeurs d'entrée (%operand, %source et %init_value), deux fonctions d'entrée et trois attributs d'entrée (window_dimensions, window_strides et padding). Notez que la signature de l'opération n'inclut que les types de ses valeurs d'entrée (mais pas les types de fonctions et d'attributs d'entrée qui sont fournis en ligne).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Les constantes StableHLO ont un littéral et un type qui représentent ensemble une valeur StableHLO. En général, le type fait partie de la syntaxe constante, sauf lorsqu'il n'y a pas d'ambiguïté (par exemple, une constante booléenne a sans ambiguïté le type i1, tandis qu'une constante entière peut avoir plusieurs types possibles).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Les constantes booléennes représentent les valeurs booléennes true et false. Les constantes booléennes sont de type i1.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Les constantes entières représentent des valeurs entières sous forme de chaînes utilisant la notation décimale ou hexadécimale. Les autres bases, par exemple binaire ou octale, ne sont pas acceptées. Les constantes entières sont soumises aux contraintes suivantes :
- (C1)
is_wellformed(integer_literal, integer_type).
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Les constantes à virgule flottante représentent des valeurs à virgule flottante via des chaînes qui utilisent la notation décimale ou scientifique. De plus, la notation hexadécimale peut être utilisée pour spécifier directement les bits sous-jacents dans le format à virgule flottante du type correspondant. Les constantes à virgule flottante sont soumises aux contraintes suivantes :
- (C1) Si une notation non hexadécimale est utilisée,
is_wellformed(float_literal, float_type). - (C2) Si la notation hexadécimale est utilisée,
size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Les constantes complexes représentent des valeurs complexes à l'aide de listes composées d'une partie réelle (en premier) et d'une partie imaginaire (en second). Par exemple, (1.0, 0.0) : complex<f32> représente 1.0 + 0.0i et (0.0, 1.0) : complex<f32> représente 0.0 + 1.0i. L'ordre dans lequel ces parties sont ensuite stockées en mémoire est défini par l'implémentation. Les constantes complexes sont soumises aux contraintes suivantes :
- (C1)
is_wellformed(real_part, complex_element_type(complex_type)). - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Les constantes Tensor représentent les valeurs Tensor à l'aide de listes imbriquées spécifiées à l'aide de la notation NumPy. Par exemple, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> représente une valeur de Tensor avec le mappage suivant des index aux éléments : {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. L'ordre dans lequel ces éléments sont ensuite stockés en mémoire est défini par l'implémentation. Les constantes Tensor sont soumises aux contraintes suivantes :
- (C1)
has_syntax(tensor_literal, element_type(tensor_type)), où :has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
- (C2)
has_shape(tensor_literal, shape(tensor_type)), où :has_shape(element_literal: Syntax, []) = true.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).- Sinon,
false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Les constantes de tenseur quantifié représentent des valeurs de tenseur quantifié en utilisant la même notation que les constantes de tenseur, avec des éléments spécifiés comme constantes de leur type de stockage. Les constantes de Tensor quantifié sont soumises aux contraintes suivantes :
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)). - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Les littéraux de chaîne sont constitués d'octets spécifiés à l'aide de caractères ASCII et de séquences d'échappement. Elles sont indépendantes de l'encodage. L'interprétation de ces octets est donc définie par l'implémentation. Les littéraux de chaîne sont de type string.
Opérations
abs
Sémantique
Effectue une opération abs au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les entiers signés : module entier.
- Pour les valeurs float :
absà partir de la norme IEEE-754. - Pour les nombres complexes : module complexe.
- Pour les types quantifiés :
dequantize_op_quantize(abs, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type entier signé, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1-C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type entier signé ou à virgule flottante, ou tenseur quantifié par tenseur | (C1-C2) |
Contraintes
- (C1)
shape(result) = shape(operand). - (C2)
baseline_element_type(result)est défini comme suit :complex_element_type(element_type(operand))siis_complex(operand).- Sinon,
baseline_element_type(operand).
Exemples
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]
add
Sémantique
Effectue l'addition élément par élément de deux Tensors lhs et rhs, et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs booléennes : OR logique.
- Pour les nombres entiers : addition d'entiers.
- Pour les valeurs float :
additionà partir de la norme IEEE-754. - Pour les nombres complexes : addition complexe.
- Pour les types quantifiés :
dequantize_op_quantize(add, lhs, rhs, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur ou tenseur quantifié | (C1-C6) |
| (I2) | rhs |
tenseur ou tenseur quantifié | (C1-C5), (C7) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C1-C7) |
Contraintes
- Si l'opération utilise des Tensors non quantifiés :
- (C1)
type(lhs) = type(rhs) = type(result).
- (C1)
- Si l'opération utilise des Tensors quantifiés :
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result). - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result). - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result). - (C6) Si
is_per_axis_quantized(lhs), alorsquantization_dimension(lhs) = quantization_dimension(result). - (C7) Si
is_per_axis_quantized(rhs), alorsquantization_dimension(rhs) = quantization_dimension(result).
- (C2)
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]
after_all
Sémantique
Garantit que les opérations produisant le inputs sont exécutées avant toute opération dépendant de result. L'exécution de cette opération n'a aucun effet. Elle n'existe que pour établir des dépendances de données entre result et inputs.
Entrées
| Libellé | Nom | Type |
|---|---|---|
| (I1) | inputs |
Nombre variadique de token |
Sorties
| Nom | Type |
|---|---|
result |
token |
Exemples
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token
all_gather
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, concatène les valeurs des Tensors operands de chaque processus le long de all_gather_dim et produit des Tensors results.
L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :
cross_replica(replica_groups)sichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sichannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sichannel_id > 0 and use_global_device_ids = true.
Ensuite, dans chaque process_group :
operands...@receiver = [operand@sender for sender in process_group]pour tous lesreceiverdansprocess_group.results...@process = concatenate(operands...@process, all_gather_dim)pour tous lesprocessdansprocess_group.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operands |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1), (C6) |
| (I2) | all_gather_dim |
constante de type si64 |
(C1), (C6) |
| (I3) | replica_groups |
Constante Tensor à deux dimensions de type si64 |
(C2-C4) |
| (I4) | channel_id |
constante de type si64 |
(C5) |
| (I5) | use_global_device_ids |
constante de type i1 |
(C5) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C6) |
Contraintes
- (C1)
0 <= all_gather_dim < rank(operands...). - (C2)
is_unique(replica_groups). - (C3)
size(replica_groups)est défini comme suit :num_replicassicross_replicaest utilisé.num_replicassicross_replica_and_partitionest utilisé.num_processessiflattened_idsest utilisé.
- (C4)
0 <= replica_groups < size(replica_groups). - (C5) Si
use_global_device_ids = true, alorschannel_id > 0. - (C6)
type(results...) = type(operands...)sauf :dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
// channel_id = 0
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
// use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, applique une fonction de réduction computation aux valeurs des Tensors operands de chaque processus et produit des Tensors results.
L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :
cross_replica(replica_groups)sichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sichannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sichannel_id > 0 and use_global_device_ids = true.
Ensuite, dans chaque process_group :
results...@process[result_index] = exec(schedule)pour un arbre binaire donnéscheduleoù :exec(node)=computation(exec(node.left), exec(node.right)).exec(leaf)=leaf.value.
scheduleest un arbre binaire défini par l'implémentation dont la traversée dans l'ordre estto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operands |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C5), (C6) |
| (I2) | replica_groups |
Nombre variable de constantes Tensor à une dimension de type si64 |
(C1-C3) |
| (I3) | channel_id |
constante de type si64 |
(C4) |
| (I4) | use_global_device_ids |
constante de type i1 |
(C4) |
| (I5) | computation |
fonction | (C5) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C6-C7) |
Contraintes
- (C1)
is_unique(replica_groups). - (C2)
size(replica_groups)est défini comme suit :num_replicassicross_replicaest utilisé.num_replicassicross_replica_and_partitionest utilisé.num_processessiflattened_idsest utilisé.
- (C3)
0 <= replica_groups < size(replica_groups). - (C4) Si
use_global_device_ids = true, alorschannel_id > 0. - (C5)
computationest de type(tensor<E>, tensor<E>) -> (tensor<E>)oùis_promotable(element_type(operand), E). - (C6)
shape(results...) = shape(operands...). - (C7)
element_type(results...) = E.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
// channel_id = 0
channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
// use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, divise les valeurs des tenseurs operands le long de split_dimension en parties, répartit les parties divisées entre les processus, concatène les parties réparties le long de concat_dimension et produit des tenseurs results.
L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :
cross_replica(replica_groups)sichannel_id <= 0.cross_partition(replica_groups)sichannel_id > 0.
Ensuite, dans chaque process_group :
split_parts...@sender = split(operands...@sender, split_count, split_dimension)pour tous lessenderdansprocess_group.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]oùreceiver_index = process_group.index(receiver).results...@process = concatenate(scattered_parts...@process, concat_dimension).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operands |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1-C3), (C9) |
| (I2) | split_dimension |
constante de type si64 |
(C1), (C2), (C9) |
| (I3) | concat_dimension |
constante de type si64 |
(C3), (C9) |
| (I4) | split_count |
constante de type si64 |
(C2), (C4), (C8), (C9) |
| (I5) | replica_groups |
Constante Tensor à deux dimensions de type si64 |
(C5-C8) |
| (I6) | channel_id |
constante de type si64 |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C9) |
Contraintes
- (C1)
0 <= split_dimension < rank(operands...). - (C2)
dim(operands..., split_dimension) % split_count = 0. - (C3)
0 <= concat_dimension < rank(operands...). - (C4)
0 < split_count. - (C5)
is_unique(replica_groups). - (C6)
size(replica_groups)est défini comme suit :num_replicassicross_replicaest utilisé.num_partitionssicross_partitionest utilisé.
- (C7)
0 <= replica_groups < size(replica_groups). - (C8)
dim(replica_groups, 1) = split_count. - (C9)
type(results...) = type(operands...)sauf sisplit_dimension != concat_dimension:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
// channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
et
Sémantique
Effectue un AND élément par élément de deux Tensors lhs et rhs, et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs booléennes : AND logique.
- Pour les nombres entiers : AND au niveau du bit.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tensor de type booléen ou entier | (C1) |
| (I2) | rhs |
Tensor de type booléen ou entier | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result).
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]
atan2
Sémantique
Effectue une opération atan2 au niveau des éléments sur les Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
atan2à partir de la norme IEEE-754. - Pour les nombres complexes : atan2 complexe.
- Pour les types quantifiés :
dequantize_op_quantize(atan2, lhs, rhs, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
| (I2) | rhs |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemples
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Sémantique
Calcule les gradients de plusieurs entrées de batch_norm_training en rétropropagation à partir de grad_output et produit les Tensors grad_operand, grad_scale et grad_offset. Plus formellement, cette opération peut être exprimée comme une décomposition en opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Pour les types quantifiés, effectue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1-C3), (C5) |
| (I2) | scale |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C2), (C4), (C5) |
| (I3) | mean |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
| (I4) | variance |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
| (I5) | grad_output |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C2), (C3) |
| (I6) | epsilon |
constante de type f32 |
|
| (I7) | feature_index |
constante de type si64 |
(C1), (C5) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
grad_operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C2), (C3) |
grad_scale |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
grad_offset |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
Contraintes
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,mean,variance,grad_output,grad_operand,grad_scaleetgrad_offsetont le mêmebaseline_element_type. - (C3)
operand,grad_outputetgrad_operandont la même forme. - (C4)
scale,mean,variance,grad_scaleetgrad_offsetont la même forme. - (C5)
size(scale) = dim(operand, feature_index).
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
< tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Sémantique
Normalise le tenseur operand dans toutes les dimensions, à l'exception de la dimension feature_index, et produit un tenseur result. Plus formellement, cette opération peut être exprimée comme une décomposition en opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Pour les types quantifiés, effectue dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1-C7) |
| (I2) | scale |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C2), (C3) |
| (I3) | offset |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C2), (C4) |
| (I4) | mean |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C5) |
| (I5) | variance |
Tenseur à une dimension de type à virgule flottante ou quantifié par tenseur | (C2), (C6) |
| (I6) | epsilon |
constante de type f32 |
|
| (I7) | feature_index |
constante de type si64 |
(C1), (C3-C6) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C2), (C7) |
Contraintes
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,mean,varianceetresultont le mêmebaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(mean) = dim(operand, feature_index). - (C6)
size(variance) = dim(operand, feature_index). - (C7)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Sémantique
Calcule la moyenne et la variance pour toutes les dimensions, à l'exception de la dimension feature_index, et normalise le Tensor operand pour produire les Tensors output, batch_mean et batch_var. Plus formellement, cette opération peut être exprimée comme une décomposition des opérations StableHLO existantes à l'aide de la syntaxe Python comme suit :
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Pour les types quantifiés, effectue dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
| (I2) | scale |
Tenseur à une dimension de valeurs à virgule flottante ou quantifiées par tenseur | (C2), (C3) |
| (I3) | offset |
Tenseur à une dimension de valeurs à virgule flottante ou quantifiées par tenseur | (C2), (C4) |
| (I4) | epsilon |
constante de type f32 |
(C1), (C3-C6) |
| (I5) | feature_index |
constante de type si64 |
(C1), (C3-C6) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
output |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C7) |
batch_mean |
Tenseur à une dimension de valeurs à virgule flottante ou quantifiées par tenseur | (C2), (C5) |
batch_var |
Tenseur à une dimension de valeurs à virgule flottante ou quantifiées par tenseur | (C2), (C6) |
Contraintes
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,batch_mean,batch_varetoutputont le mêmebaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(batch_mean) = dim(operand, feature_index). - (C6)
size(batch_var) = dim(operand, feature_index). - (C7)
baseline_type(output) = baseline_type(operand).
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
< (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Sémantique
Effectue une opération bitcast sur le Tensor operand et produit un Tensor result où les bits de l'ensemble du Tensor operand sont réinterprétés à l'aide du type du Tensor result.
Plus formellement, étant donné E = element_type(operand), E' = element_type(result) et R = rank(operand) :
- Si
num_bits(E') < num_bits(E),bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]). - Si
num_bits(E') > num_bits(E),bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]). - Si
num_bits(E') = num_bits(E),bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).
bits renvoie la représentation en mémoire d'une valeur donnée. Son comportement est défini par l'implémentation, car la représentation exacte des Tensors et des types d'éléments est également définie par l'implémentation.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié | (C1-C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C1-C2) |
Contraintes
- (C1) Étant donné
E = is_quantized(operand) ? storage_type(operand) : element_type(operand),E' = is_quantized(result) ? storage_type(result) : element_type(result)etR = rank(operand):- Si
num_bits(E') = num_bits(E),shape(result) = shape(operand). - Si
num_bits(E') < num_bits(E): rank(result) = R + 1.dim(result, i) = dim(operand, i)pour tous les0 <= i < R.dim(result, R) * num_bits(E') = num_bits(E).- Si
num_bits(E') > num_bits(E): rank(result) = R - 1.dim(result, i) = dim(operand, i)pour tous les0 <= i < R.dim(operand, R - 1) * num_bits(E) = num_bits(E').
- Si
- (C2) Si
is_complex(operand) or is_complex(result), alorsis_complex(operand) and is_complex(result).
Exemples
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Sémantique
Développe les dimensions et/ou le rang d'un Tensor d'entrée en dupliquant les données dans le Tensor operand et produit un Tensor result. Plus formellement,
result[result_index] = operand[operand_index] où pour tous les d dans
axes(operand) :
operand_index[d] = 0sidim(operand, d) = 1.- Sinon,
operand_index[d] = result_index[broadcast_dimensions[d]].
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié | (C1-C2), (C5-C6) |
| (I2) | broadcast_dimensions |
Constante Tensor de type si64 à une dimension |
(C2-C6) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C1), (C3), (C5-C6) |
Contraintes
- (C1)
element_type(result)est donné par :element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), sauf siquantization_dimension(operand),scales(operand)etzero_points(operand)diffèrent dequantization_dimension(result),scales(result)etzero_points(result), respectivement.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Pour tous les
ddansaxes(operand):dim(operand, d) = 1oudim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Si
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Si
dim(operand, quantization_dimension(operand)) = 1, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
Exemples
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
coque
Sémantique
Génère le résultat de l'exécution d'une seule fonction à partir de branches en fonction de la valeur de index. Plus précisément, result = selected_branch(), où :
selected_branch = branches[index]si0 <= index < size(branches).- Sinon,
selected_branch = branches[-1].
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | index |
Tensor de dimension 0 de type si32 |
|
| (I2) | branches |
nombre variable de fonctions | (C1-C4) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de Tensors, de Tensors quantifiés ou de jetons. | (C4) |
Contraintes
- (C1)
0 < size(branches). - (C2)
input_types(branches...) = []. - (C3)
same(output_types(branches...)). - (C4)
type(results...) = output_types(branches[0]).
Exemples
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
"stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Sémantique
Effectue une opération de racine cubique élément par élément sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
rootn(x, 3)à partir de la norme IEEE-754. - Pour les nombres complexes : racine cubique complexe.
- Pour les types quantifiés :
dequantize_op_quantize(cbrt, operand, type(result))
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Sémantique
Effectue un plafond élément par élément du Tensor operand et produit un Tensor result.
Implémente l'opération roundToIntegralTowardPositive à partir de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(ceil, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Sémantique
Calcule la décomposition de Cholesky d'un lot de matrices.
Plus formellement, pour tous les i dans index_space(result), result[i0, ..., iR-3, :, :] est une décomposition de Cholesky de a[i0, ..., iR-3, :, :], sous la forme d'une matrice triangulaire inférieure (si lower est true) ou supérieure (si lower est false).
Les valeurs de sortie dans le triangle opposé, c'est-à-dire le triangle supérieur strict ou le triangle inférieur strict, sont définies par l'implémentation.
Si i existe et que la matrice d'entrée n'est pas une matrice hermitienne définie positive, le comportement n'est pas défini.
Pour les types quantifiés, effectue dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | a |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1-C3) |
| (I2) | lower |
constante de type i1 |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(a) = baseline_type(result). - (C2)
2 <= rank(a). - (C3)
dim(a, -2) = dim(a, -1).
Exemples
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
limiter
Sémantique
Limite chaque élément du Tensor operand entre une valeur minimale et maximale, et produit un Tensor result. Plus formellement, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element), où min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Pour les types quantifiés, effectue dequantize_op_quantize(clamp, min, operand, max, type(result)).
L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (#560).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | min |
tenseur ou tenseur quantifié par tenseur | (C1), (C3) |
| (I2) | operand |
tenseur ou tenseur quantifié par tenseur | (C1-C4) |
| (I3) | max |
tenseur ou tenseur quantifié par tenseur | (C2), (C3) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C4) |
Contraintes
- (C1)
rank(min) = 0 or shape(min) = shape(operand). - (C2)
rank(max) = 0 or shape(max) = shape(operand). - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max). - (C4)
baseline_type(operand) = baseline_type(result).
Exemples
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]
collective_broadcast
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, envoyez la valeur du tenseur operand du processus source aux processus cibles et générez un tenseur result.
L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :
cross_replica(replica_groups)sichannel_id <= 0.cross_partition(replica_groups)sichannel_id > 0.
result@process est ensuite donné par :
operand@process_groups[i, 0]s'il existe unitel que le processus soit dansprocess_groups[i].broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))dans le cas contraire.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C3) |
| (I2) | replica_groups |
Nombre variable de constantes Tensor à une dimension de type si64 |
(C1), (C2) |
| (I3) | channel_id |
constante de type si64 |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C3) |
Contraintes
- (C1)
is_unique(replica_groups). - (C2)
0 <= replica_groups < NoùNest défini comme suit :num_replicassicross_replicaest utilisé.num_partitionssicross_partitionest utilisé.
- (C3)
type(result) = type(operand).
Exemples
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, envoie la valeur du tenseur operand du processus source au processus cible et produit un tenseur result.
L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :
cross_replica(source_target_pairs)sichannel_id <= 0.cross_partition(source_target_pairs)sichannel_id > 0.
result@process est ensuite donné par :
operand@process_groups[i, 0], s'il existe unitel queprocess_groups[i, 1] = process.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))dans le cas contraire.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C5) |
| (I2) | source_target_pairs |
Constante Tensor à deux dimensions de type si64 |
(C1-C4) |
| (I3) | channel_id |
constante de type si64 |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, oùNest défini comme suit :num_replicassicross_replicaest utilisé.num_partitionssicross_partitionest utilisé.
- (C5)
type(result) = type(operand).
Exemples
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Sémantique
Effectue une comparaison élément par élément des Tensors lhs et rhs selon comparison_direction et compare_type, et produit un Tensor result.
Les valeurs de comparison_direction et compare_type ont la sémantique suivante :
Pour les types d'éléments booléens et entiers :
EQ:lhs = rhs.NE:lhs != rhs.GE:lhs >= rhs.GT:lhs > rhs.LE:lhs <= rhs.LT:lhs < rhs.
Pour les types d'éléments à virgule flottante avec compare_type = FLOAT, l'opération implémente les opérations IEEE-754 suivantes :
EQ:compareQuietEqual.NE:compareQuietNotEqual.GE:compareQuietGreaterEqual.GT:compareQuietGreater.LE:compareQuietLessEqual.LT:compareQuietLess.
Pour les types d'éléments à virgule flottante avec compare_type = TOTALORDER, l'opération utilise la combinaison des opérations totalOrder et compareQuietEqual de la norme IEEE-754.
Pour les types d'éléments complexes, la comparaison lexicographique des paires (real, imag) est effectuée à l'aide des comparison_direction et compare_type fournis.
L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes lorsque comparison_direction est GE, GT, LE ou LT (#560).
Pour les types quantifiés, effectue dequantize_compare(lhs, rhs,
comparison_direction).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur ou tenseur quantifié par tenseur | (C1-C3) |
| (I2) | rhs |
tenseur ou tenseur quantifié par tenseur | (C1-C2) |
| (I3) | comparison_direction |
Énumération de EQ, NE, GE, GT, LE et LT |
|
| (I4) | compare_type |
Énumération de FLOAT, TOTALORDER, SIGNED et UNSIGNED |
(C3) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type booléen | (C2) |
Contraintes
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs). - (C2)
shape(lhs) = shape(rhs) = shape(result). - (C3)
compare_typeest défini comme suit :SIGNEDsiis_signed_integer(element_type(lhs)).UNSIGNEDsiis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).FLOATouTOTALORDERsiis_float(element_type(lhs)).FLOATsiis_complex(element_type(lhs)).
Exemples
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = <#stablehlocomparison_di>rection LT,
compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]
complexe
Sémantique
Effectue une conversion élément par élément en valeur complexe à partir d'une paire de valeurs réelles et imaginaires, lhs et rhs, et produit un Tensor result.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tensor de type f32 ou f64 |
(C1-C3) |
| (I2) | rhs |
Tensor de type f32 ou f64 |
(C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type complexe | (C2), (C3) |
Contraintes
- (C1)
type(lhs) = type(rhs). - (C2)
shape(result) = shape(lhs). - (C3)
element_type(result)est de typecomplex<E>oùE = element_type(lhs).
Exemples
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]
composite
Sémantique
Encapsule une opération composée d'autres opérations StableHLO, en prenant inputs et composite_attributes et en produisant results. La sémantique de l'opération est implémentée par l'attribut decomposition. L'opération composite peut être remplacée par sa décomposition sans modifier la sémantique du programme. Dans les cas où l'intégration de la décomposition ne fournit pas la même sémantique d'op, préférez utiliser custom_call.
Le champ version (par défaut 0) est utilisé pour indiquer quand la sémantique d'un composite change.
Entrées
| Libellé | Nom | Type |
|---|---|---|
| (I1) | inputs |
nombre variable de valeurs |
| (I2) | name |
constante de type string |
| (I3) | composite_attributes |
dictionnaire d'attributs |
| (I4) | decomposition |
constante de type string |
| (I5) | version |
constante de type si32 |
Sorties
| Nom | Type |
|---|---|
results |
nombre variable de valeurs |
Contraintes
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
Exemples
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
< ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32
concatenate
Sémantique
Concatène inputs le long de la dimension dimension dans le même ordre que les arguments fournis et produit un Tensor result. Plus formellement,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], où :
id = d0 + ... + dk-1 + kd.dest égal àdimension, etd0, ... sont les tailles de lade dimension deinputs.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | inputs |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1-C6) |
| (I2) | dimension |
constante de type si64 |
(C2), (C4), (C6) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C5-C6) |
Contraintes
- (C1)
same(element_type(inputs...)). - (C2)
same(shape(inputs...)), saufdim(inputs..., dimension). - (C3)
0 < size(inputs). - (C4)
0 <= dimension < rank(inputs[0]). - (C5)
element_type(result) = element_type(inputs[0]). - (C6)
shape(result) = shape(inputs[0]), sauf :dim(result, dimension) = dim(inputs[0], dimension) + ....
Exemples
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Sémantique
Génère un Tensor output à partir d'une constante value.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | value |
constante | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
output |
tenseur ou tenseur quantifié | (C1) |
Contraintes
- (C1)
type(value) = type(output).
Exemples
%output = "stablehlo.constant"() {
val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]
d'effectuer une conversion
Sémantique
Effectue une conversion élément par élément d'un type d'élément à un autre sur le Tensor operand et produit un Tensor result.
Pour les conversions boolean-to-any-supported-type, la valeur false est convertie en zéro et la valeur true est convertie en un. Pour les conversions any-supported-type-to-boolean, une valeur nulle est convertie en false et les valeurs non nulles sont converties en true. Vous trouverez ci-dessous des informations sur le fonctionnement de cette fonctionnalité pour les types complexes.
Pour les conversions d'entier à entier, d'entier à virgule flottante ou de virgule flottante à virgule flottante, si la valeur source peut être représentée exactement dans le type de destination, la valeur résultante est cette représentation exacte. Sinon, le comportement est à déterminer (#180).
Pour les conversions floating-point-to-integer, la partie fractionnaire est tronquée. Si la valeur tronquée ne peut pas être représentée dans le type de destination, le comportement est à déterminer (#180).
Les conversions complexe vers complexe suivent le même comportement que les conversions virgule flottante vers virgule flottante pour convertir les parties réelles et imaginaires.
Pour les conversions complex-to-any-other-type et any-other-type-to-complex, la valeur imaginaire source est ignorée ou la valeur imaginaire de destination est définie sur zéro, respectivement. La conversion de la partie réelle suit les conversions à virgule flottante.
En principe, cette opération pourrait exprimer la déquantification (conversion de Tensors quantifiés en Tensors réguliers), la quantification (conversion de Tensors réguliers en Tensors quantifiés) et la requantification (conversion entre Tensors quantifiés), mais pour le moment, nous avons des opérations dédiées pour cela : uniform_dequantize pour le premier cas d'utilisation et uniform_quantize pour les deuxième et troisième cas d'utilisation. À l'avenir, ces deux opérations pourront être fusionnées dans convert (#1576).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur | (C1) |
Contraintes
- (C1)
shape(operand) = shape(result).
Exemples
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolution
Sémantique
Calcule les produits scalaires entre les fenêtres de lhs et les tranches de rhs, et génère result. Le diagramme suivant montre comment les éléments de result sont calculés à partir de lhs et rhs à l'aide d'un exemple concret.
Plus formellement, considérez la reformulation suivante des entrées en termes de lhs afin de pouvoir exprimer des fenêtres de lhs :
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).lhs_window_strides = lhs_shape(1, window_strides, 1)lhs_padding = lhs_shape([0, 0], padding, [0, 0])lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).
Cette reformulation utilise les fonctions d'assistance suivantes :
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]oùj[d] = i[permutation[d]].
Si feature_group_count = 1 et batch_group_count = 1, alors pour tous les output_spatial_index dans index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product où :
padding_value = constant(0, element_type(lhs)).padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strideslhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Cette fonctionnalité semble inutilisée. Nous prévoyons donc de la supprimer à l'avenir (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).
Si feature_group_count > 1 :
lhses = split(lhs, feature_group_count, input_feature_dimension).rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)result = concatenate(results, output_feature_dimension).
Si batch_group_count > 1 :
lhses = split(lhs, batch_group_count, input_batch_dimension).rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Pour les types quantifiés, effectue dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result)).
Pour les types quantifiés hybrides, effectue hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur ou tenseur quantifié par tenseur | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
| (I2) | rhs |
tenseur ou tenseur quantifié | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
| (I3) | window_strides |
Constante Tensor de type si64 à une dimension |
(C2-C3), (C25) |
| (I4) | padding |
Constante Tensor à deux dimensions de type si64 |
(C4), (C25) |
| (I5) | lhs_dilation |
Constante Tensor de type si64 à une dimension |
(C5-C6), (C25) |
| (I6) | rhs_dilation |
Constante Tensor de type si64 à une dimension |
(C7-C8), (C25) |
| (I7) | window_reversal |
Constante Tensor de type i1 à une dimension |
(C9) |
| (I8) | input_batch_dimension |
constante de type si64 |
(C10), (C13), (C25) |
| (I9) | input_feature_dimension |
constante de type si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Constante Tensor de type si64 à une dimension |
(C12), (C13), (C25) |
| (I11) | kernel_input_feature_dimension |
constante de type si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
constante de type si64 |
(C15-C16), (C18), (C25), (C29) |
| (I13) | kernel_spatial_dimensions |
Constante Tensor de type si64 à une dimension |
(C17-C18), (C25) |
| (I14) | output_batch_dimension |
constante de type si64 |
(C20), (C25) |
| (I15) | output_feature_dimension |
constante de type si64 |
(C20), (C25), (C30) |
| (I16) | output_spatial_dimensions |
Constante Tensor de type si64 à une dimension |
(C19-C20), (C25) |
| (I17) | feature_group_count |
constante de type si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
constante de type si64 |
(C10), (C15), (C22), (C23), (C25) |
| (I19) | precision_config |
Nombre variable d'énums de DEFAULT, HIGH et HIGHEST |
(C24) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C25-C28), (C30), (C32-34) |
Contraintes
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Étant donné
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Étant donné
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Étant donné
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)est défini comme suit :dim(lhs, input_batch_dimension) / batch_group_countsiresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)siresult_dim = output_feature_dimension.num_windowssinon, où :output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim]rhs_dim = kernel_spatial_dimensions[spatial_dim]dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Si l'opération utilise des Tensors non quantifiés :
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Si l'opération utilise des Tensors quantifiés :
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Si
is_per_axis_quantized(rhs), alorsquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Si
is_per_axis_quantized(result), alorsquantization_dimension(result) = output_feature_dimension. - Si
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Si
is_per_tensor_quantized(rhs), alorsis_per_tensor_quantized(result). - Si
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Exemples
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strid<es = arra>yi64: 4, 4,
paddi<n>g = dense<0 : ten>sor2x2xi64,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
// In the StableHLO dialect, dimension numbers are encoded vi<a:
// `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" a<re spatial dimensions.
d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
batch_group_count = 1 : i64,
fea<ture_group_count >= 1 : i64,
< precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
cosinus
Sémantique
Effectue une opération de cosinus au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
cosà partir de la norme IEEE-754. - Pour les nombres complexes : cosinus complexe.
- Pour les types quantifiés :
dequantize_op_quantize(cosine, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Sémantique
Effectue un décompte élément par élément du nombre de bits zéro de début dans le Tensor operand et produit un Tensor result.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur de type entier | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result).
Exemples
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]
custom_call
Sémantique
Encapsule une opération call_target_name définie par l'implémentation qui prend inputs et called_computations et produit results. has_side_effect, backend_config et api_version peuvent être utilisés pour fournir des métadonnées supplémentaires définies par l'implémentation.
Pour le moment, cette opération contient une collection de métadonnées assez désorganisée qui reflète l'évolution organique de son opération homologue dans le compilateur XLA. À l'avenir, nous prévoyons d'unifier ces métadonnées (#741).
Entrées
| Libellé | Nom | Type |
|---|---|---|
| (I1) | inputs |
nombre variable de valeurs |
| (I2) | call_target_name |
constante de type string |
| (I3) | has_side_effect |
constante de type i1 |
| (I4) | backend_config |
constante de type string ou dictionnaire d'attributs |
| (I5) | api_version |
constante de type si32 |
| (I6) | called_computations |
nombre variadique de constantes de type string |
| (I7) | output_operand_aliases |
spécifier les parties d'alias dans les sorties et les opérandes ; |
Sorties
| Nom | Type |
|---|---|
results |
nombre variable de valeurs |
(Compatibilité XLA avec les GPU) Cibles custom_call spéciales
Il existe trois call_target_name spéciaux liés aux types buffer : CreateBuffer crée un buffer non initialisé, Pin crée un buffer initialisé et Unpin libère un buffer et renvoie le contenu du buffer.
%uninitialized_buffer = "stablehlo.custom_call"() {
call_target_name = "CreateBuffer",
api_version> = 4 : <i32,
>} : () - memref4xf64
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
call_target_name = "Pin&quo<t;,
> ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64
Alias
Certaines opérations custom_call peuvent nécessiter qu'une partie des sorties et une partie des opérandes partagent la même mémoire. Cela peut être exprimé via output_operand_aliases. Une représentation de paire d'alias se compose d'une liste d'indices de tuples de sortie représentant la partie de sortie, et d'un operand_index ainsi que d'une liste d'indices de tuples d'opérandes représentant la partie d'opérande. La liste des indices de tuples de sortie ou d'opérandes est vide si le type correspondant n'est pas un type tuple et peut être arbitrairement longue pour un type de tuple arbitrairement imbriqué. C'est semblable à la représentation de l'alias XLA.
La partie de sortie et la partie d'entrée d'une paire d'alias doivent être du même type. Pour les opérations custom_call qui ne sont pas des appels à CreateBuffer, Pin et Unpin, un opérande buffer peut apparaître dans une seule paire d'alias, et une sortie buffer doit apparaître dans une paire d'alias.
Exemples
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64
%updated_buffer = "stablehlo.custom_call"(%buffer) {
call_target_name = "Update",
api_version = 4 : i32,
output_operand_aliases< = [
#stablehlo.output_operand_aliasoutput_tuple_indices = [],
operand_ind>ex = 0,
< oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64
diviser
Sémantique
Effectue une division élément par élément des Tensors de dividende lhs et de diviseur rhs, et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les nombres entiers : division entière qui produit le quotient algébrique en supprimant toute partie fractionnaire.
- Pour les valeurs float :
divisionà partir de la norme IEEE-754. - Pour les nombres complexes : division complexe.
- Pour les types quantifiés :
dequantize_op_quantize(divide, lhs, rhs, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
| (I2) | rhs |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemples
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Sémantique
Calcule les produits scalaires entre les tranches de lhs et celles de rhs, et génère un Tensor result.
Plus précisément, result[result_index] = dot_product, où :
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].result_batching_index + result_lhs_index + result_rhs_index = result_indexoùsize(result_batching_index) = size(lhs_batching_dimensions),size(result_lhs_index) = size(lhs_result_dimensions)etsize(result_rhs_index) = size(rhs_result_dimensions).transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).
Pour les types quantifiés, effectue dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).
Pour les types quantifiés hybrides, effectue hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs).
precision_config contrôle le compromis entre vitesse et précision pour les calculs sur les backends d'accélérateur. Il peut s'agir de l'une des valeurs suivantes (pour le moment, la sémantique de ces valeurs d'énumération n'est pas suffisamment spécifiée, mais nous prévoyons de résoudre ce problème dans #755) :
DEFAULT: calcul le plus rapide, mais approximation la moins précise du nombre d'origine.HIGH: calcul plus lent, mais approximation plus précise du nombre d'origine.HIGHEST: le calcul le plus lent, mais l'approximation la plus précise du nombre d'origine.
Un DotAlgorithm définit les propriétés principales de l'algorithme utilisé pour implémenter l'opération de produit scalaire, qui définit également la précision. Si les champs d'attribut d'algorithme sont définis, precision_config doit être DEFAULT. DotAlgorithms n'ont pas de valeur par défaut, car les paramètres par défaut sont définis par l'implémentation. Par conséquent, tous les champs de l'algorithme de points peuvent être définis sur None pour spécifier un algorithme de points vide, qui utilisera plutôt la valeur precision_config.
Les champs DotAlgorithm incluent les suivants :
lhs_precision_typeetrhs_precision_type, les précisions auxquelles les côtés gauche et droit de l'opération sont arrondis. Les types de précision sont indépendants des types de stockage des entrées et de la sortie.accumulation_type: précision utilisée pour l'accumulation.lhs_component_count,rhs_component_countetnum_primitive_operationss'appliquent lorsque nous effectuons un algorithme qui décompose les côtés gauche et/ou droit en plusieurs composants et effectue plusieurs opérations de produit scalaire "primitives" sur ces valeurs, généralement pour émuler une précision plus élevée (par exemple, Utiliser le type de données d'intelligence artificielle bfloat16 pour des calculs plus précis : bf16_6x tf32_3x, etc.). Pour les algorithmes sans décomposition, ces valeurs doivent être définies sur1.allow_imprecise_accumulationpour spécifier si l'accumulation dans une précision inférieure est autorisée pour certaines étapes (par exemple,CUBLASLT_MATMUL_DESC_FAST_ACCUM).
Exemples d'attributs DotAlgorithm :
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Il appartient aux implémentations de décider quelles combinaisons sont prises en charge. En général, il n'est pas garanti que chaque algorithme soit compatible avec chaque type d'accélérateur par le consommateur de StableHLO. Si un algorithme donné n'est pas pris en charge, une erreur doit être générée au lieu de revenir à une alternative. La validation StableHLO fournira la meilleure vérification possible, en empêchant les algorithmes qui ne sont pas connus pour être compatibles avec aucun matériel.
Pour obtenir des exemples de valeurs d'algorithme acceptées, consultez xla_data.proto > Algorithm. La demande 2483 décrit le plan de création d'un document centralisé sur les algorithmes compatibles par backend.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur ou tenseur quantifié par tenseur | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
| (I2) | rhs |
tenseur ou tenseur quantifié | (C7-C10), (C12-C20) |
| (I3) | lhs_batching_dimensions |
Constante Tensor de type si64 à une dimension |
(C1), (C3), (C5), (C9), (C12) |
| (I4) | rhs_batching_dimensions |
Constante Tensor de type si64 à une dimension |
(C1), (C4), (C7), (C9) |
| (I5) | lhs_contracting_dimensions |
Constante Tensor de type si64 à une dimension |
(C2), (C3), (C6), (C10) |
| (I6) | rhs_contracting_dimensions |
Constante Tensor de type si64 à une dimension |
(C2), (C4), (C8), (C10), (C16) |
| (I7) | precision_config |
Nombre variable d'énums de DEFAULT, HIGH et HIGHEST |
(C11), (C21) |
| (I8) | lhs_precision_type |
FloatType ou TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType ou TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType ou TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
constante de type si32 |
(C21), (C22) |
| (I12) | rhs_component_count |
constante de type si32 |
(C21), (C23) |
| (I13) | num_primitive_operations |
constante de type si32 |
(C21), (C24) |
| (I14) | allow_imprecise_accumulation |
constante de type bool |
(C21) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C12), (C14), (C18-C20) |
Contraintes
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions). - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions). - (C5)
0 <= lhs_batching_dimensions < rank(lhs). - (C6)
0 <= lhs_contracting_dimensions < rank(lhs). - (C7)
0 <= rhs_batching_dimensions < rank(rhs). - (C8)
0 <= rhs_contracting_dimensions < rank(rhs). - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...). - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...). - (C11)
size(precision_config) = 2. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions). - Si l'opération utilise des Tensors non quantifiés :
- (C13)
element_type(lhs) = element_type(rhs).
- (C13)
- Si l'opération utilise des Tensors quantifiés :
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C15)
zero_points(rhs) = 0. - (C16) Si
is_per_axis_quantized(rhs), alorsquantization_dimension(rhs)n'est pas dansrhs_contracting_dimensions. - Si
is_quantized(lhs): - (C17)
storage_type(lhs) = storage_type(rhs). - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C19) Si
is_per_tensor_quantized(rhs), alorsis_per_tensor_quantized(result). - Si
!is_quantized(lhs): - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C14)
- Si
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):- (C21)
precision_config... = DEFAULT. - (C22)
0 < lhs_component_count. - (C23)
0 < rhs_component_count. - (C24)
0 < num_primitive_operations.
- (C21)
Exemples
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #sta<blehlo.dot
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimension>s = [1]
,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
algorithm = #stablehlo.dot<_algorithm
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation >= false
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Sémantique
Cette opération est fonctionnellement identique à l'opération broadcast_in_dim, mais la forme du résultat est spécifiée de manière dynamique via output_dimensions.
L'opération accepte également les attributs facultatifs known_expanding_dimensions et known_nonexpanding_dimensions pour exprimer des connaissances statiques sur le comportement d'expansion des dimensions.
Si aucune n'est spécifiée, toutes les dimensions sont considérées comme pouvant être développées.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié | (C1-C2), (C5-C6), (C9) |
| (I2) | output_dimensions |
Tensor unidimensionnel de type entier | (C7) |
| (I3) | broadcast_dimensions |
Tensor constant unidimensionnel de type entier | (C2-C6) |
| (I4) | known_expanding_dimensions |
Tensor constant unidimensionnel de type entier | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
Tensor constant unidimensionnel de type entier | (C8-C9) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C1), (C3), (C5-C7) |
Contraintes
- (C1)
element_type(result)est donné par :element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), sauf siquantization_dimension(operand),scales(operand)etzero_points(operand)diffèrent dequantization_dimension(result),scales(result)etzero_points(result), respectivement.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Pour tous les
ddansaxes(operand):dim(operand, d) = 1oudim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Si
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Si
dim(operand, quantization_dimension(operand)) = 1, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
- (C7)
size(output_dimensions) = rank(result). - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions). - (C9)
0 <= known_expanding_dimensions < rank(operand). - (C10)
0 <= known_nonexpanding_dimensions < rank(operand).
Exemples
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensio<ns = arra>yi64: 2, 1,
known_expanding_dimensio<ns = a>rrayi64: 0,
known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Sémantique
Cette opération est fonctionnellement identique à l'opération convolution, mais le remplissage est spécifié de manière dynamique via padding.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur ou tenseur quantifié par tenseur | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
| (I2) | rhs |
tenseur ou tenseur quantifié | (C1), (C14-C16), (C26-C28), (C30-C33) |
| (I3) | padding |
Tensor bidimensionnel de type entier | (C4) |
| (I4) | window_strides |
Constante Tensor de type si64 à une dimension |
(C2-C3) |
| (I5) | lhs_dilation |
Constante Tensor de type si64 à une dimension |
(C5-C6) |
| (I6) | rhs_dilation |
Constante Tensor de type si64 à une dimension |
(C7-C8) |
| (I7) | window_reversal |
Constante Tensor de type i1 à une dimension |
(C9) |
| (I8) | input_batch_dimension |
constante de type si64 |
(C10), (C13) |
| (I9) | input_feature_dimension |
constante de type si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Constante Tensor de type si64 à une dimension |
(C12), (C13) |
| (I11) | kernel_input_feature_dimension |
constante de type si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
constante de type si64 |
(C15-C16), (C18), (C28) |
| (I13) | kernel_spatial_dimensions |
Constante Tensor de type si64 à une dimension |
(C17-C18) |
| (I14) | output_batch_dimension |
constante de type si64 |
(C20) |
| (I15) | output_feature_dimension |
constante de type si64 |
(C20), (C29) |
| (I16) | output_spatial_dimensions |
Constante Tensor de type si64 à une dimension |
(C19-C20) |
| (I17) | feature_group_count |
constante de type si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
constante de type si64 |
(C10), (C15), (C22), (C23) |
| (I19) | precision_config |
Nombre variable d'énums de DEFAULT, HIGH et HIGHEST |
(C24) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C25-C27), (C29), (C31-C33) |
Contraintes
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Étant donné
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Étant donné
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Étant donné
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)est défini comme suit :dim(lhs, input_batch_dimension) / batch_group_countsiresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)siresult_dim = output_feature_dimension.num_windowssinon, où :output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim]rhs_dim = kernel_spatial_dimensions[spatial_dim]dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Si l'opération utilise des Tensors non quantifiés :
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Si l'opération utilise des Tensors quantifiés :
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Si
is_per_axis_quantized(rhs), alorsquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Si
is_per_axis_quantized(result), alorsquantization_dimension(result) = output_feature_dimension. - Si
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Si
is_per_tensor_quantized(rhs), alorsis_per_tensor_quantized(result). - Si
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Exemples
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strid<es = arra>yi64: 4, 4,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
dimension_numbers = #stab<lehlo.convraw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions => [1, 2]
,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Sémantique
Cette opération est fonctionnellement identique à l'opération gather, avec slice_sizes spécifié de manière dynamique en tant que valeur.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1), (C7), (C10-C12), (C14) |
| (I2) | start_indices |
Tensor de type entier | (C2), (C3), (C13) |
| (I3) | slice_sizes |
Tensor unidimensionnel de type entier | (C8), (C11-C13) |
| (I4) | offset_dims |
Constante Tensor de type si64 à une dimension |
(C1), (C4-C5), (C13) |
| (I5) | collapsed_slice_dims |
Constante Tensor de type si64 à une dimension |
(C1), (C6-C8), (C13) |
| (I6) | start_index_map |
Constante Tensor de type si64 à une dimension |
(C3), (C9), (C10) |
| (I7) | index_vector_dim |
constante de type si64 |
(C2), (C3), (C13) |
| (I8) | indices_are_sorted |
constante de type i1 |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C5), (C13-C14) |
Contraintes
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims). - (C7)
0 <= collapsed_slice_dims < rank(operand). - (C8)
slice_sizes[collapsed_slice_dims...] <= 1. - (C9)
is_unique(start_index_map). - (C10)
0 <= start_index_map < rank(operand). - (C11)
size(slice_sizes) = rank(operand). - (C12)
0 <= slice_sizes <= shape(operand). - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)où :batch_dim_sizes = shape(start_indices), sauf que la taille de la dimensionstart_indicescorrespondant àindex_vector_dimn'est pas incluse.offset_dim_sizes = shape(slice_sizes), sauf que les tailles de dimension dansslice_sizescorrespondant àcollapsed_slice_dimsne sont pas incluses.combineplacebatch_dim_sizessur les axes correspondant àbatch_dimsetoffset_dim_sizessur les axes correspondant àoffset_dims.
- (C14)
element_type(operand) = element_type(result).
Exemples
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stable<hlo.gather
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vect>or_dim = 2,
indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Sémantique
Cette opération est fonctionnellement identique à l'opération iota, mais la forme du résultat est spécifiée de manière dynamique via output_shape.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | output_shape |
Tensor unidimensionnel de type entier | (C1), (C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C2) |
Contraintes
- (C1)
0 <= iota_dimension < size(output_shape). - (C2)
rank(result) = size(output_shape).
Exemples
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Sémantique
Cette opération est fonctionnellement identique à l'opération pad, mais avec edge_padding_low, edge_padding_high et interior_padding spécifiés de manière dynamique en tant que valeurs.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1), (C2), (C4) |
| (I2) | padding_value |
Tenseur de dimension 0 ou tenseur quantifié par tenseur | (C1) |
| (I3) | edge_padding_low |
Tensor unidimensionnel de type entier | (C1), (C4) |
| (I4) | edge_padding_high |
Tensor unidimensionnel de type entier | (C1), (C4) |
| (I5) | interior_padding |
Tensor unidimensionnel de type entier | (C2-C4) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C3-C6) |
Contraintes
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Exemples
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Sémantique
Cette opération est fonctionnellement identique à l'opération reshape, mais la forme du résultat est spécifiée de manière dynamique via output_shape.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié | (C1-C3) |
| (I2) | output_shape |
Tensor unidimensionnel de type entier | (C4) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C1-C4) |
Contraintes
- (C1)
element_type(result)est donné par :element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), sauf quequantization_dimension(operand)etquantization_dimension(result)peuvent être différents.
- (C2)
size(operand) = size(result). - (C3) Si
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
- (C4)
size(output_shape) = rank(result).
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Sémantique
Extrait une tranche de operand à l'aide d'indices de début calculés de manière dynamique et produit un Tensor result. start_indices contient les indices de début de la tranche pour chaque dimension susceptible d'être ajustée, et slice_sizes contient la taille de la tranche pour chaque dimension. Plus formellement,
result[result_index] = operand[operand_index] où :
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).operand_index = adjusted_start_indices + result_index.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1), (C2), (C4) |
| (I2) | start_indices |
Nombre variable de tenseurs de dimension 0 de type entier | (C2), (C3) |
| (I3) | slice_sizes |
Constante Tensor de type si64 à une dimension |
(C2), (C4), (C5) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1), (C5) |
Contraintes
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(slice_sizes) = rank(operand). - (C3)
same(type(start_indices...)). - (C4)
0 <= slice_sizes <= shape(operand). - (C5)
shape(result) = slice_sizes.
Exemples
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Sémantique
Produit un Tensor result qui est égal au Tensor operand, sauf que la tranche commençant à start_indices est mise à jour avec les valeurs de update.
Plus formellement, result[result_index] est défini comme suit :
update[update_index]si0 <= update_index < shape(update)où :adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).update_index = result_index - adjusted_start_indices.
- Sinon,
operand[result_index].
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1-C4), (C6) |
| (I2) | update |
tenseur ou tenseur quantifié par tenseur | (C2), (C3), (C6) |
| (I3) | start_indices |
Nombre variable de tenseurs de dimension 0 de type entier | (C4), (C5) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
type(operand) = type(result). - (C2)
element_type(update) = element_type(operand). - (C3)
rank(update) = rank(operand). - (C4)
size(start_indices) = rank(operand). - (C5)
same(type(start_indices...)). - (C6)
0 <= shape(update) <= shape(operand).
Exemples
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
< : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponentiel
Sémantique
Effectue une opération exponentielle au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
expà partir de la norme IEEE-754. - Pour les nombres complexes : exponentielle complexe.
- Pour les types quantifiés :
dequantize_op_quantize(exponential, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Sémantique
Effectue une opération exponentielle moins un au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
expm1à partir de la norme IEEE-754. - Pour les nombres complexes : exponentielle complexe moins un.
- Pour les types quantifiés :
dequantize_op_quantize(exponential_minus_one, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]
fft
Sémantique
Effectue les transformées de Fourier directes et inverses pour les entrées/sorties réelles et complexes.
fft_type est l'un des éléments suivants :
FFT: FFT complexe-complexe directe.IFFT: FFT complexe-complexe inverse.RFFT: FFT de réel à complexe.IRFFT: FFT inverse du réel au complexe (c'est-à-dire prend un complexe et renvoie un réel).
Plus formellement, étant donné la fonction fft qui prend en entrée des Tensors unidimensionnels de types complexes, produit en sortie des Tensors unidimensionnels de mêmes types et calcule la transformée de Fourier discrète :
Pour fft_type = FFT, result est défini comme le résultat final d'une série de calculs L où L = size(fft_length). Par exemple, pour L = 3 :
result1[i0, ..., :] = fft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
De plus, étant donné la fonction ifft qui a la même signature de type et calcule l'inverse de fft :
Pour fft_type = IFFT, result est défini comme l'inverse des calculs pour fft_type = FFT. Par exemple, pour L = 3 :
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])result[i0, ..., :] = ifft(result2[i0, ..., :]).
De plus, étant donné la fonction rfft qui prend des tenseurs unidimensionnels de types à virgule flottante, produit des tenseurs unidimensionnels de types complexes de la même sémantique à virgule flottante et fonctionne comme suit :
rfft(real_operand) = truncated_resultoùcomplex_operand... = (real_operand..., 0.0).complex_result = fft(complex_operand)truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].
(Lorsque la transformée de Fourier discrète est calculée pour des opérandes réels, les N/2 + 1 premiers éléments du résultat définissent sans ambiguïté le reste du résultat. Le résultat de rfft est donc tronqué pour éviter de calculer des éléments redondants.)
Pour fft_type = RFFT, result est défini comme le résultat final d'une série de calculs L où L = size(fft_length). Par exemple, pour L = 3 :
result1[i0, ..., :] = rfft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Enfin, étant donné la fonction irfft qui a la même signature de type et qui calcule l'inverse de rfft :
Pour fft_type = IRFFT, result est défini comme l'inverse des calculs pour fft_type = RFFT. Par exemple, pour L = 3 :
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])result[i0, ..., :] = irfft(result2[i0, ..., :]).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tensor de type à virgule flottante ou complexe | (C1), (C2), (C4), (C5) |
| (I2) | fft_type |
Énumération de FFT, IFFT, RFFT et IRFFT |
(C2), (C5) |
| (I3) | fft_length |
Constante Tensor de type si64 à une dimension |
(C1), (C3), (C4) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type à virgule flottante ou complexe | (C2), (C4), (C5) |
Contraintes
- (C1)
size(fft_length) <= rank(operand). - (C2) La relation entre les types d'éléments
operandetresultvarie :- Si
fft_type = FFT,element_type(operand)etelement_type(result)ont le même type complexe. - Si
fft_type = IFFT,element_type(operand)etelement_type(result)ont le même type complexe. - Si
fft_type = RFFT,element_type(operand)est un type à virgule flottante etelement_type(result)est un type complexe de la même sémantique à virgule flottante. - Si
fft_type = IRFFT,element_type(operand)est un type complexe etelement_type(result)est un type à virgule flottante de la même sémantique à virgule flottante.
- Si
- (C3)
1 <= size(fft_length) <= 3. - (C4) Si, parmi
operandetresult, il existe un Tensorrealde type à virgule flottante, alorsshape(real)[-size(fft_length):] = fft_length. - (C5)
shape(result) = shape(operand)sauf :- Si
fft_type = RFFT,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1. - Si
fft_type = IRFFT,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.
- Si
Exemples
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = <#stablehloff>t_type FFT,
fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
étage
Sémantique
Effectue le plancher élément par élément du Tensor operand et produit un Tensor result.
Implémente l'opération roundToIntegralTowardNegative à partir de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(floor, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
gather
Sémantique
Collecte les tranches du Tensor operand à partir des décalages spécifiés dans start_indices et produit un Tensor result.
Le schéma suivant montre comment les éléments de result sont mappés sur les éléments de operand à l'aide d'un exemple concret. Le diagramme choisit quelques exemples d'indices result et explique en détail à quels indices operand ils correspondent.
Plus précisément, result[result_index] = operand[operand_index] où :
batch_dims = [d for d in axes(result) and d not in offset_dims].batch_index = result_index[batch_dims...].start_indexest défini comme suit :start_indices[bi0, ..., :, ..., biN]oùbisont des éléments individuels dansbatch_indexet:est inséré à l'indexindex_vector_dim, siindex_vector_dim<rank(start_indices).- Sinon,
[start_indices[batch_index]].
- Pour
d_operanddansaxes(operand),full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])sid_operand = start_index_map[d_start].- Sinon,
full_start_index[d_operand] = 0.
- Pour
d_operanddansaxes(operand),full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]sid_operand = operand_batching_dims[i_batching]etd_start = start_indices_batching_dims[i_batching].- Sinon,
full_batching_index[d_operand] = 0.
offset_index = result_index[offset_dims...].full_offset_index = [oi0, ..., 0, ..., oiN]oùoisont des éléments individuels dansoffset_index, et0est inséré aux indices decollapsed_slice_dimsetoperand_batching_dims.operand_index = full_start_index + full_batching_index + full_offset_index.
Si indices_are_sorted est true, l'implémentation peut supposer que start_indices est trié par rapport à start_index_map. Sinon, le comportement n'est pas défini. Plus précisément, pour tous les i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
| (I2) | start_indices |
tenseur de type entier | (C2-C3), (C14), (C17), (C22) |
| (I3) | offset_dims |
Constante Tensor de type si64 à une dimension |
(C1), (C4-C5), (C22) |
| (I4) | collapsed_slice_dims |
Constante Tensor de type si64 à une dimension |
(C1), (C6-C9), (C22) |
| (I5) | operand_batching_dims |
Constante Tensor de type si64 à une dimension |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
| (I6) | start_indices_batching_dims |
Constante Tensor de type si64 à une dimension |
(C13-C17) |
| (I7) | start_index_map |
Constante Tensor de type si64 à une dimension |
(C3), (C18-C19) |
| (I8) | index_vector_dim |
constante de type si64 |
(C2-C3), (C15), (C22) |
| (I9) | slice_sizes |
Constante Tensor de type si64 à une dimension |
(C9), (C12), (C20-C22) |
| (I10) | indices_are_sorted |
constante de type i1 |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C5), (C22-C23) |
Contraintes
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims). - (C8)
0 <= collapsed_slice_dims < rank(operand). - (C9)
slice_sizes[collapsed_slice_dims...] <= 1. - (C10)
is_sorted(operand_batching_dims). - (C11)
0 <= operand_batching_dims < rank(operand). - (C12)
slice_sizes[operand_batching_dims...] <= 1. - (C13)
is_unique(start_indices_batching_dims). - (C14)
0 <= start_indices_batching_dims < rank(start_indices). - (C15)
index_vector_dim not in start_indices_batching_dims. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims). - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...). - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims)). - (C19)
0 <= start_index_map < rank(operand). - (C20)
size(slice_sizes) = rank(operand). - (C21)
0 <= slice_sizes <= shape(operand). - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)où :batch_dim_sizes = shape(start_indices), sauf que la taille de la dimensionstart_indicescorrespondant àindex_vector_dimn'est pas incluse.offset_dim_sizes = slice_sizes, sauf que les tailles de dimension dansslice_sizescorrespondant àcollapsed_slice_dimsetoperand_batching_dimsne sont pas incluses.combineplacebatch_dim_sizessur les axes correspondant àbatch_dimsetoffset_dim_sizessur les axes correspondant àoffset_dims.
- (C23)
element_type(operand) = element_type(result).
Exemples
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stable<hlo.gather
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vect>or_dim = 3,
slice_siz<es = arrayi64: >1, 1, 2, 2,
indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Sémantique
Produit la taille de l'dimension donné de l'operand. Plus précisément, result = dim(operand, dimension). La sémantique ne concerne que le composant de forme du type. Le type d'élément peut être n'importe quoi.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié | (C1) |
| (I2) | dimension |
constante de type si64 |
(C1) |
Sorties
| Nom | Type |
|---|---|
result |
Tensor de dimension 0 de type si32 |
Contraintes
- (C1)
0 <= dimension < rank(operand).
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3
get_tuple_element
Sémantique
Extrait l'élément à la position index du tuple operand et génère un result. Plus précisément, result = operand[index].
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tuple | (C1), (C2) |
| (I2) | index |
constante de type si32 |
(C1), (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
toute valeur | (C2) |
Contraintes
- (C1)
0 <= index < size(operand). - (C2)
type(result) = tuple_element_types(operand)[index].
Exemples
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]
si
Sémantique
Génère le résultat de l'exécution d'une seule fonction à partir de true_branch ou false_branch en fonction de la valeur de pred. Plus précisément, result =
pred ? true_branch() : false_branch().
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | pred |
Tensor de dimension 0 de type i1 |
|
| (I2) | true_branch |
fonction | (C1-C3) |
| (I3) | false_branch |
fonction | (C1), (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de Tensors, de Tensors quantifiés ou de jetons. | (C3) |
Contraintes
- (C1)
input_types(true_branch) = input_types(false_branch) = []. - (C2)
output_types(true_branch) = output_types(false_branch). - (C3)
type(results...) = output_types(true_branch).
Exemples
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
"stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10
imag
Sémantique
Extrait la partie imaginaire, élément par élément, de operand et produit un tenseur result. Plus précisément, pour chaque élément x :
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tensor de type à virgule flottante ou complexe | (C1), (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)est défini comme suit :complex_element_type(element_type(operand))siis_complex(operand).- Sinon,
element_type(operand).
Exemples
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]
in-feed
Sémantique
Lit les données du flux In-Feed et génère results.
La sémantique de infeed_config est définie par l'implémentation.
results se compose de valeurs de charge utile qui viennent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour plus de clarté (#670).
Entrées
| Libellé | Nom | Type |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
constante de type string |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de Tensors, de Tensors quantifiés ou de jetons. | (C1-C3) |
Contraintes
- (C1)
0 < size(results). - (C2)
is_empty(result[:-1])ouis_tensor(type(results[:-1])). - (C3)
is_token(type(results[-1])).
Exemples
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Sémantique
Remplit un Tensor output avec des valeurs dans l'ordre croissant à partir de zéro le long de la dimension iota_dimension. Plus précisément,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
output |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
0 <= iota_dimension < rank(output).
Exemples
%output = "stablehlo.iota"() {
iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Sémantique
Effectue une vérification élément par élément pour déterminer si la valeur dans x est finie (c'est-à-dire ni +Inf, ni -Inf, ni NaN) et produit un Tensor y. Implémente l'opération isFinite de la spécification IEEE-754. Pour les types quantifiés, le résultat est toujours true.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | x |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
y |
Tensor de type booléen | (C1) |
Contraintes
- (C1)
shape(x) = shape(y).
Exemples
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]
log
Sémantique
Effectue une opération logarithmique au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
logà partir de la norme IEEE-754. - Pour les nombres complexes : logarithme complexe.
- Pour les types quantifiés :
dequantize_op_quantize(log, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Sémantique
Effectue une opération de logarithme plus un au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
logp1à partir de la norme IEEE-754. - Pour les nombres complexes :
complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1)) - Pour les types quantifiés :
dequantize_op_quantize(log_plus_one, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistique
Sémantique
Effectue une opération logistique au niveau des éléments sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
division(1, addition(1, exp(-x)))à partir de la norme IEEE-754. - Pour les nombres complexes : logistique complexe.
- Pour les types quantifiés :
dequantize_op_quantize(logistic, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
carte
Sémantique
Applique une fonction de mappage computation à inputs le long de dimensions et produit un Tensor result.
Plus précisément, result[result_index] = computation(inputs...[result_index]).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | inputs |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1-C4) |
| (I2) | dimensions |
Constante Tensor de type si64 à une dimension |
(C3) |
| (I3) | computation |
fonction | (C4) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1), (C4) |
Contraintes
- (C1)
shape(inputs...) = shape(result). - (C2)
0 < size(inputs) = N. - (C3)
dimensions = range(rank(inputs[0])). - (C4)
computationest de type(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>oùEi = element_type(inputs[i])etE' = element_type(result).
Exemples
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
stablehlo.return %<0 :> tensori64
}) {
dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]
maximum
Sémantique
Effectue une opération max au niveau des éléments sur les Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs booléennes : OR logique.
- Pour les nombres entiers : nombre entier maximal.
- Pour les valeurs float :
maximumà partir de la norme IEEE-754. - Pour les nombres complexes : maximum lexicographique pour la paire
(real, imaginary). L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (#560). - Pour les types quantifiés :
dequantize_op_quantize(maximum, lhs, rhs, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur ou tenseur quantifié par tenseur | (C1) |
| (I2) | rhs |
tenseur ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemples
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]
minimum
Sémantique
Effectue une opération min élément par élément sur les Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs booléennes : AND logique.
- Pour les nombres entiers : nombre entier minimal.
- Pour les valeurs float :
minimumà partir de la norme IEEE-754. - Pour les nombres complexes : minimum lexicographique pour la paire
(real, imaginary). L'imposition d'un ordre sur les nombres complexes implique une sémantique surprenante. Nous prévoyons donc de supprimer la prise en charge des nombres complexes pour cette opération à l'avenir (#560). - Pour les types quantifiés :
dequantize_op_quantize(minimum, lhs, rhs, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur ou tenseur quantifié par tenseur | (C1) |
| (I2) | rhs |
tenseur ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemples
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]
multiplier
Sémantique
Effectue le produit élément par élément de deux tenseurs lhs et rhs, et produit un tenseur result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs booléennes : AND logique.
- Pour les nombres entiers : multiplication d'entiers.
- Pour les valeurs float :
multiplicationà partir de la norme IEEE-754. - Pour les nombres complexes : multiplication complexe.
- Pour les types quantifiés :
dequantize_op_quantize(multiply, lhs, rhs, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur ou tenseur quantifié par tenseur | (C1) |
| (I2) | rhs |
tenseur ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]
negate
Sémantique
Effectue la négation élément par élément du tenseur operand et produit un tenseur result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les entiers signés : négation d'un entier.
- Pour les entiers non signés : bitcast vers un entier signé, négation de l'entier, bitcast vers un entier non signé.
- Pour les valeurs float :
negateà partir de la norme IEEE-754. - Pour les nombres complexes : négation complexe.
- Pour les types quantifiés :
dequantize_op_quantize(negate, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]
not
Sémantique
Effectue un NOT élément par élément du Tensor operand et produit un Tensor result.
En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs booléennes : NOT logique.
- Pour les entiers : NOT au niveau du bit.
Arguments
| Nom | Type | Contraintes |
|---|---|---|
operand |
Tensor de type booléen ou entier | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result).
Exemples
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]
optimization_barrier
Sémantique
Garantit que les opérations qui produisent le operand sont exécutées avant toute opération qui dépend du result et empêche les transformations du compilateur de déplacer des opérations au-delà de la barrière. À part cela, l'opération est une identité, c'est-à-dire result = operand.
Arguments
| Nom | Type | Contraintes |
|---|---|---|
operand |
Nombre variable de tenseurs, de tenseurs ou de jetons quantifiés par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Nombre variable de tenseurs, de tenseurs ou de jetons quantifiés par tenseur | (C1) |
Contraintes
- (C1)
type(operand...) = type(result...).
Exemples
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0
ou
Sémantique
Effectue un OR élément par élément de deux Tensors lhs et rhs et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs booléennes : OR logique.
- Pour les entiers : opérateur OR (OU) au niveau du bit.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tensor de type entier ou booléen | (C1) |
| (I2) | rhs |
Tensor de type entier ou booléen | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type entier ou booléen | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result).
Exemples
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]
sortie
Sémantique
Écrit inputs dans le flux de sortie et génère un jeton result.
La sémantique de outfeed_config est définie par l'implémentation.
Entrées
| Libellé | Nom | Type |
|---|---|---|
| (I1) | inputs |
un nombre variable de Tensors ou de Tensors quantifiés. |
| (I2) | token |
token |
| (I3) | outfeed_config |
constante de type string |
Sorties
| Nom | Type |
|---|---|
result |
token |
Exemples
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token
pad
Sémantique
Développe operand en ajoutant une marge intérieure autour du Tensor ainsi qu'entre les éléments du Tensor avec le padding_value donné.
edge_padding_low et edge_padding_high spécifient la quantité de marge intérieure ajoutée respectivement à l'extrémité inférieure (à côté de l'index 0) et à l'extrémité supérieure (à côté de l'index le plus élevé) de chaque dimension. La quantité de marge intérieure peut être négative, où la valeur absolue de la marge intérieure négative indique le nombre d'éléments à supprimer de la dimension spécifiée.
interior_padding spécifie la marge intérieure ajoutée entre deux éléments dans chaque dimension, qui ne peut pas être négative. La marge intérieure se produit avant la marge extérieure, de sorte qu'une marge extérieure négative supprimera les éléments de l'opérande avec marge intérieure.
Plus formellement, result[result_index] est défini comme suit :
operand[operand_index]siresult_index = edge_padding_low + operand_index * (interior_padding + 1).- Sinon,
padding_value.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1), (C2), (C4) |
| (I2) | padding_value |
Tenseur de dimension 0 ou tenseur quantifié par tenseur | (C1) |
| (I3) | edge_padding_low |
Constante Tensor de type si64 à une dimension |
(C1), (C4) |
| (I4) | edge_padding_high |
Constante Tensor de type si64 à une dimension |
(C1), (C4) |
| (I5) | interior_padding |
Constante Tensor de type si64 à une dimension |
(C2-C4) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C3-C6) |
Contraintes
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Exemples
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_l<ow = arra>yi64: 0, 1,
edge_padding_hi<gh = arra>yi64: 2, 1,
interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Sémantique
Produit partition_id du processus actuel.
Sorties
| Nom | Type |
|---|---|
result |
Tensor de dimension 0 de type ui32 |
Exemples
%result = "stablehlo.partition_id">;() : (<) - >tensorui32
popcnt
Sémantique
Effectue un décompte élément par élément du nombre de bits définis dans le Tensor operand et produit un Tensor result.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur de type entier | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result).
Exemples
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]
puissance
Sémantique
Effectue une exponentiation élément par élément du Tensor lhs par le Tensor rhs et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les nombres entiers : exponentiation d'entiers.
- Pour les valeurs float :
powà partir de la norme IEEE-754. - Pour les nombres complexes : exponentiation complexe.
- Pour les types quantifiés :
dequantize_op_quantize(power, lhs, rhs, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
| (I2) | rhs |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Sémantique
Extrait la partie réelle, élément par élément, de operand et produit un Tensor result. Plus précisément, pour chaque élément x :
real(x) = is_complex(x) ? real_part(x) : x.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tensor de type à virgule flottante ou complexe | (C1), (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)est défini comme suit :complex_element_type(element_type(operand))siis_complex(operand).- Sinon,
element_type(operand).
Exemples
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]
recv
Sémantique
Reçoit des données d'un canal avec channel_id et produit results.
Si is_host_transfer est défini sur true, l'opération transfère les données depuis l'hôte. Sinon, il transfère les données depuis un autre appareil en fonction des valeurs de source_target_pairs. Cette option duplique les informations fournies dans channel_type. À l'avenir, nous prévoyons de n'en conserver qu'une seule (#666). Si is_host_transfer= false et que source_target_pairs est None ou vide, le comportement est considéré comme indéfini.
results se compose de valeurs de charge utile qui viennent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour plus de clarté (#670).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | token |
token |
|
| (I2) | channel_id |
constante de type si64 |
|
| (I3) | channel_type |
Énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST |
(C5) |
| (I4) | is_host_transfer |
constante de type i1 |
(C5-C6) |
| (I5) | source_target_pairs |
Constante Tensor à deux dimensions de type si64 |
(C1-C4), (C6) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de Tensors, de Tensors quantifiés ou de jetons. | (C2-C4) |
Contraintes
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, oùNest défini comme suit :num_replicassicross_replicaest utilisé.num_partitionssicross_partitionest utilisé.
- (C5)
channel_typeest défini comme suit :DEVICE_TO_HOSTsiis_host_transfer = true,- Sinon,
DEVICE_TO_DEVICE.
Exemples
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)
reduce
Sémantique
Applique une fonction de réduction body à inputs et init_values le long de dimensions et produit des Tensors results.
L'ordre des réductions est défini par l'implémentation, ce qui signifie que body et init_values doivent former un monoïde pour garantir que l'opération produit les mêmes résultats pour toutes les entrées sur toutes les implémentations. Toutefois, cette condition ne s'applique pas à de nombreuses réductions populaires. Par exemple, l'addition à virgule flottante pour body et zéro pour init_values ne forment pas réellement un monoïde, car l'addition à virgule flottante n'est pas associative.
Plus précisément, results...[j0, ..., jR-1] = reduce(input_slices_converted) où :
input_slices = inputs...[j0, ..., :, ..., jR-1], où:sont insérés àdimensions.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).reduce(input_slices_converted) = exec(schedule)pour un arbre binaire donnéscheduleoù :exec(node) = body(exec(node.left), exec(node.right)).exec(leaf) = leaf.value.
scheduleest un arbre binaire complet défini par l'implémentation dont la traversée dans l'ordre se compose de :- Valeurs
input_slices_converted...[index], pour tous lesindexdansindex_space(input_slices_converted)dans l'ordre lexicographique croissant deindex. - Entrecroisé avec une quantité de
init_values_converteddéfinie par l'implémentation à des positions définies par l'implémentation.
- Valeurs
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | inputs |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1-C4), (C6), (C7) |
| (I2) | init_values |
nombre variable de Tensors de dimension 0 ou de Tensors quantifiés par Tensor | (C2), (C3) |
| (I3) | dimensions |
Constante Tensor de type si64 à une dimension |
(C4), (C5), (C7) |
| (I4) | body |
fonction | (C6) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C3), (C7), (C8) |
Contraintes
- (C1)
same(shape(inputs...)). - (C2)
element_type(inputs...) = element_type(init_values...). - (C3)
0 < size(inputs) = size(init_values) = size(results) = N. - (C4)
0 <= dimensions < rank(inputs[0]). - (C5)
is_unique(dimensions). - (C6)
bodyest de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)oùis_promotable(element_type(inputs[i]), Ei). - (C7)
shape(results...) = shape(inputs...), sauf que les tailles de dimension deinputs...correspondant àdimensionsne sont pas incluses. - (C8)
element_type(results[i]) = Eipour tous lesidans[0,N).
Exemples
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]
reduce_precision
Sémantique
Effectue la conversion élément par élément de operand vers un autre type à virgule flottante qui utilise exponent_bits et mantissa_bits, puis de nouveau vers le type à virgule flottante d'origine, et produit un Tensor output.
Plus précisément :
- Les bits de mantisse de la valeur d'origine sont mis à jour pour arrondir la valeur d'origine à la valeur la plus proche pouvant être représentée avec
mantissa_bitsà l'aide de la sémantiqueroundToIntegralTiesToEven. - Ensuite, si
mantissa_bitsest inférieur au nombre de bits de mantisse de la valeur d'origine, les bits de mantisse sont tronqués àmantissa_bits. - Ensuite, si les bits d'exposant du résultat intermédiaire ne tiennent pas dans la plage fournie par
exponent_bits, le résultat intermédiaire déborde à l'infini en utilisant le signe d'origine ou déborde à zéro en utilisant le signe d'origine. - Pour les types quantifiés, effectue
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
| (I2) | exponent_bits |
constante de type si32 |
(C2) |
| (I3) | mantissa_bits |
constante de type si32 |
(C3) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
output |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(output). - (C2)
1 <= exponent_bits. - (C3)
0 <= mantissa_bits.
Exemples
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, effectue une réduction, à l'aide de computations, sur les valeurs du Tensor operand de chaque processus, divise le résultat de la réduction le long de scatter_dimension en parties et distribue les parties divisées entre les processus pour produire result.
L'opération divise la grille de processus StableHLO en process_groups, qui est défini comme suit :
cross_replica(replica_groups)sichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sichannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sichannel_id > 0 and use_global_device_ids = true.
Ensuite, dans chaque process_group :
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).result@receiver = parts@sender[receiver_index]pour tous lessenderdansprocess_group, oùreceiver_index = process_group.index(receiver).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1), (C2), (C7), (C8) |
| (I2) | scatter_dimension |
constante de type si64 |
(C1), (C2), (C8) |
| (I3) | replica_groups |
Constante Tensor à deux dimensions de type si64 |
(C3-C5) |
| (I4) | channel_id |
constante de type si64 |
(C6) |
| (I5) | use_global_device_ids |
constante de type i1 |
(C6) |
| (I6) | computation |
fonction | (C7) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C8-C9) |
Contraintes
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0. - (C2)
0 <= scatter_dimension < rank(operand). - (C3)
is_unique(replica_groups). - (C4)
size(replica_groups)est défini comme suit :num_replicassicross_replicaest utilisé.num_replicassicross_replica_and_partitionest utilisé.num_processessiflattened_idsest utilisé.
- (C5)
0 <= replica_groups < size(replica_groups). - (C6) Si
use_global_device_ids = true, alorschannel_id > 0. - (C7)
computationest de type(tensor<E>, tensor<E>) -> (tensor<E>)oùis_promotable(element_type(operand), E). - (C8)
shape(result) = shape(operand)sauf :dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
- (C9)
element_type(result) = E.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimension = 1 :< i64,
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Sémantique
Applique une fonction de réduction body aux fenêtres inputs et init_values et génère results.
Le schéma suivant montre comment les éléments de results... sont calculés à partir de inputs... à l'aide d'un exemple concret.
Plus formellement, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (voir reduce) où :
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).window_start = result_index * window_strideswindow_end = window_start + (window_dimensions - 1) * window_dilations + 1windows = slice(padded_inputs..., window_start, window_end, window_dilations).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | inputs |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
| (I2) | init_values |
nombre variable de Tensors de dimension 0 ou de Tensors quantifiés par Tensor | (C1), (C13) |
| (I3) | window_dimensions |
Constante Tensor de type si64 à une dimension |
(C4), (C5), (C15) |
| (I4) | window_strides |
Constante Tensor de type si64 à une dimension |
(C6), (C7), (C15) |
| (I5) | base_dilations |
Constante Tensor de type si64 à une dimension |
(C8), (C9), (C15) |
| (I6) | window_dilations |
Constante Tensor de type si64 à une dimension |
(C10), (C11), (C15) |
| (I7) | padding |
Constante Tensor à deux dimensions de type si64 |
(C12), (C15) |
| (I8) | body |
fonction | (C13) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1), (C14-C16) |
Contraintes
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N. - (C2)
same(shape(inputs...)). - (C3)
element_type(inputs...) = element_type(init_values...). - (C4)
size(window_dimensions) = rank(inputs[0]). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(inputs[0]). - (C7)
0 < window_strides. - (C8)
size(base_dilations) = rank(inputs[0]). - (C9)
0 < base_dilations. - (C10)
size(window_dilations) = rank(inputs[0]). - (C11)
0 < window_dilations. - (C12)
shape(padding) = [rank(inputs[0]), 2]. - (C13)
bodyest de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)oùis_promotable(element_type(inputs[i]), Ei). - (C14)
same(shape(results...)). - (C15)
shape(results[0]) = num_windowsoù :dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]dilated_window_shape = (window_dimensions - 1) * window_dilations + 1is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shapenum_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
- (C16)
element_type(results[i]) = Eipour tous lesidans[0,N).
Exemples
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
wind>ow_dimensions = arrayi64: <2, 1,
w>indow_strides = arrayi64: <4, 1,
b>ase_dilations = arrayi64: 2,< 1,
win>dow_dilations = arr<ayi64: 3, 1,
p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]
reste
Sémantique
Effectue le reste élément par élément des Tensors de dividende lhs et de diviseur rhs, et produit un Tensor result.
Plus formellement, le signe du résultat est tiré du dividende, et la valeur absolue du résultat est toujours inférieure à la valeur absolue du diviseur.
Le reste est calculé comme suit : lhs - d * rhs, où d est donné par :
- Pour les nombres entiers :
stablehlo.divide(lhs, rhs). - Pour les valeurs flottantes :
division(lhs, rhs)à partir de IEEE-754 avec l'attribut d'arrondiroundTowardZero. - Pour les nombres complexes : à déterminer (#997).
- Pour les types quantifiés :
dequantize_op_quantize(remainder, lhs, rhs, type(result)).
Pour les types d'éléments à virgule flottante, cette opération est différente de l'opération remainder de la spécification IEEE-754, où d est une valeur entière la plus proche de la valeur exacte de lhs/rhs avec des liens vers le nombre pair.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
| (I2) | rhs |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]
replica_id
Sémantique
Produit replica_id du processus actuel.
Sorties
| Nom | Type |
|---|---|
result |
Tensor de dimension 0 de type ui32 |
Exemples
%result = "stablehlo.replica_id">;() : (<) - >tensorui32
reshape
Sémantique
Remodele le Tensor operand en un Tensor result. Conceptuellement, cela revient à conserver la même représentation canonique, mais à modifier potentiellement la forme, par exemple de tensor<2x3xf32> à tensor<3x2xf32> ou tensor<6xf32>.
Plus précisément, result[result_index] = operand[operand_index] où result_index et operand_index ont la même position dans l'ordre lexicographique de index_space(result) et index_space(operand).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié | (C1-C3) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C1-C3) |
Contraintes
- (C1)
element_type(result)est donné par :element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), sauf quequantization_dimension(operand)etquantization_dimension(result)peuvent être différents.
- (C2)
size(operand) = size(result). - (C3) Si
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]
inverser
Sémantique
Inverse l'ordre des éléments dans operand le long de la dimensions spécifiée et génère un Tensor result. Plus formellement,
result[result_index] = operand[operand_index] où :
operand_index[d] = dim(result, d) - result_index[d] - 1siddansdimensions.- Sinon,
operand_index[d] = result_index[d].
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1), (C3) |
| (I2) | dimensions |
Constante Tensor de type si64 à une dimension |
(C2), (C3) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1), (C3) |
Contraintes
- (C1)
type(operand) = type(result). - (C2)
is_unique(dimensions). - (C3)
0 <= dimensions < rank(result).
Exemples
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Sémantique
Génère des nombres aléatoires à l'aide de l'algorithme rng_distribution et produit un Tensor result d'une forme donnée shape.
Si la valeur est rng_distribution = UNIFORM, les nombres aléatoires sont générés selon la distribution uniforme sur l'intervalle [a, b). Si a >= b, le comportement n'est pas défini.
Si rng_distribution = NORMAL, les nombres aléatoires sont générés selon la distribution normale avec la moyenne = a et l'écart-type = b.
Si la valeur est b < 0, le comportement n'est pas défini.
La méthode exacte de génération des nombres aléatoires est définie par l'implémentation. Par exemple, ils peuvent être déterministes ou non, et utiliser ou non un état caché.
Lors de conversations avec de nombreuses parties prenantes, cette opération a été considérée comme obsolète. Nous prévoyons donc de l'explorer à l'avenir (#597).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | a |
Tensor de type entier, booléen ou à virgule flottante de dimension 0 | (C1), (C2) |
| (I2) | b |
Tensor de type entier, booléen ou à virgule flottante de dimension 0 | (C1), (C2) |
| (I3) | shape |
Constante Tensor de type si64 à une dimension |
(C3) |
| (I4) | rng_distribution |
Énumération de UNIFORM et NORMAL |
(C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type entier, booléen ou à virgule flottante | (C1-C3) |
Contraintes
- (C1)
element_type(a) = element_type(b) = element_type(result). - (C2) Si
rng_distribution = NORMAL, alorsis_float(a). - (C3)
shape(result) = shape.
Exemples
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Sémantique
Renvoie un output rempli de bits aléatoires uniformes et un état de sortie mis à jour output_state à l'aide de l'algorithme de générateur de nombres pseudo-aléatoires rng_algorithm, étant donné un état initial initial_state. La sortie est garantie comme étant une fonction déterministe de initial_state, mais elle n'est pas garantie comme étant déterministe entre les implémentations.
rng_algorithm est l'un des éléments suivants :
DEFAULT: algorithme défini par l'implémentation.THREE_FRY: variante de l'algorithme Threefry définie par l'implémentation*.PHILOX: variante de l'algorithme Philox définie par l'implémentation*.
* Voir Salmon et al. SC 2011. Nombres aléatoires parallèles : un jeu d'enfant.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | rng_algorithm |
Énumération de DEFAULT, THREE_FRY et PHILOX |
(C2) |
| (I2) | initial_state |
Tensor de type ui64 à une dimension |
(C1), (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
output_state |
Tensor de type ui64 à une dimension |
(C1) |
output |
Tensor de type entier ou à virgule flottante |
Contraintes
- (C1)
type(initial_state) = type(output_state). - (C2)
size(initial_state)est défini comme suit :- défini par l'implémentation si
rng_algorithm = DEFAULT. 2sirng_algorithm = THREE_FRY.2ou3sirng_algorithm = PHILOX.
- défini par l'implémentation si
Exemples
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Sémantique
Effectue un arrondi élément par élément vers l'entier le plus proche, en cas d'égalité, l'arrondi s'éloigne de zéro, sur le Tensor operand et produit un Tensor result. Implémente l'opération roundToIntegralTiesToAway de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(round_nearest_afz, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Sémantique
Effectue un arrondi élément par élément vers l'entier le plus proche, en cas d'égalité vers l'entier pair, sur le Tensor operand et produit un Tensor result. Implémente l'opération roundToIntegralTiesToEven à partir de la spécification IEEE-754. Pour les types quantifiés, effectue dequantize_op_quantize(round_nearest_even, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Sémantique
Effectue une opération de racine carrée réciproque au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
rSqrtà partir de la norme IEEE-754. - Pour les nombres complexes : racine carrée réciproque complexe.
- Pour les types quantifiés :
dequantize_op_quantize(rsqrt, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
disperser
Sémantique
Produit des Tensors results égaux aux Tensors inputs, sauf que plusieurs tranches spécifiées par scatter_indices sont mises à jour avec les valeurs updates à l'aide de update_computation.
Le schéma suivant montre comment les éléments de updates... sont mappés sur les éléments de results... à l'aide d'un exemple concret. Le diagramme choisit quelques exemples d'indices updates... et explique en détail à quels indices results... ils correspondent.
Plus formellement, pour tout update_index dans index_space(updates[0]) :
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].update_scatter_index = update_index[update_scatter_dims...].start_indexest défini comme suit :scatter_indices[si0, ..., :, ..., siN]oùsisont des éléments individuels dansupdate_scatter_indexet:est inséré à l'indexindex_vector_dim, siindex_vector_dim<rank(scatter_indices).- Sinon,
[scatter_indices[update_scatter_index]].
- Pour
d_inputdansaxes(inputs[0]),full_start_index[d_input] = start_index[d_start]sid_input = scatter_dims_to_operand_dims[d_start].- Sinon,
full_start_index[d_input] = 0.
- Pour
d_inputdansaxes(inputs[0]),full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]sid_input = input_batching_dims[i_batching]etd_start = scatter_indices_batching_dims[i_batching].- Sinon,
full_batching_index[d_input] = 0.
update_window_index = update_index[update_window_dims...].full_window_index = [wi0, ..., 0, ..., wiN]oùwisont des éléments individuels dansupdate_window_index, et0est inséré aux indices deinserted_window_dimsetinput_batching_dims.result_index = full_start_index + full_batching_index + full_window_index.
Dans ce cas, results = exec(schedule, inputs), où :
scheduleest une permutation deindex_space(updates[0])définie par l'implémentation.exec([update_index, ...], results) = exec([...], updated_results)où :- Si
result_indexest dans les limites deshape(results...) updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)updated_resultsest une copie deresultsavecresults...[result_index]défini surupdated_values....- Sinon, procédez comme suit :
updated_results = results.
- Si
exec([], results) = results.
Si indices_are_sorted est true, l'implémentation peut supposer que scatter_indices est trié par rapport à scatter_dims_to_operand_dims. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).
Si unique_indices est true, l'implémentation peut supposer que tous les indices result_index dispersés sont uniques. Si unique_indices est true, mais que les indices vers lesquels la dispersion est effectuée ne sont pas uniques, le comportement n'est pas défini.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | inputs |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
| (I2) | scatter_indices |
tenseur de type entier | (C4), (C15), (C19), (C22) |
| (I3) | updates |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C3-C6), (C8) |
| (I4) | update_window_dims |
Constante Tensor de type si64 à une dimension |
(C2), (C4), (C7-C8) |
| (I5) | inserted_window_dims |
Constante Tensor de type si64 à une dimension |
(C2), (C4), (C9-C11) |
| (I6) | input_batching_dims |
Constante Tensor de type si64 à une dimension |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
| (I7) | scatter_indices_batching_dims |
Constante Tensor de type si64 à une dimension |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
Constante Tensor de type si64 à une dimension |
(C19-C21) |
| (I9) | index_vector_dim |
constante de type si64 |
(C4), (C16), (C19), (C22) |
| (I10) | indices_are_sorted |
constante de type i1 |
|
| (I11) | unique_indices |
constante de type i1 |
|
| (I12) | update_computation |
fonction | (C23) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C24-C25) |
Contraintes
- (C1)
same(shape(inputs...)). - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims). - (C3)
same(shape(updates...)). - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)où :update_scatter_dim_sizes = shape(scatter_indices), sauf que la taille de la dimensionscatter_indicescorrespondant àindex_vector_dimn'est pas incluse.update_window_dim_sizes <= shape(inputs[0]), sauf que les tailles de dimension dansinputs[0]correspondant àinserted_window_dimsetinput_batching_dimsne sont pas incluses.combineplaceupdate_scatter_dim_sizessur les axes correspondant àupdate_scatter_dimsetupdate_window_dim_sizessur les axes correspondant àupdate_window_dims.
- (C5)
0 < size(inputs) = size(updates) = N. - (C6)
element_type(updates...) = element_type(inputs...). - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims). - (C8)
0 <= update_window_dims < rank(updates[0]). - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims). - (C11)
0 <= inserted_window_dims < rank(inputs[0]). - (C12)
is_sorted(input_batching_dims). - (C13)
0 <= input_batching_dims < rank(inputs[0])). - (C14)
is_unique(scatter_indices_batching_dims). - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices). - (C16)
index_vector_dim not in scatter_indices_batching_dims. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims). - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...). - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)). - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0]). - (C22)
0 <= index_vector_dim <= rank(scatter_indices). - (C23)
update_computationest de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), oùis_promotable(element_type(inputs[i]), Ei). - (C24)
shape(inputs...) = shape(results...). - (C25)
element_type(results[i]) = Eipour tous lesidans[0,N).
Exemples
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ],
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimensio<n_numbers = #stablehlo.scatter
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2>, 1],
index_vector_dim = 3,
indices_are_sorted = false,
uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
sélectionner
Sémantique
Génère un Tensor result où chaque élément est sélectionné à partir du Tensor on_true ou on_false en fonction de la valeur de l'élément pred correspondant.
Plus formellement, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index], où pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]. Pour les types quantifiés, effectue dequantize_select_quantize(pred, on_true, on_false, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | pred |
tenseur de type i1 |
(C1) |
| (I2) | on_true |
tenseur ou tenseur quantifié par tenseur | (C1-C2) |
| (I3) | on_false |
tenseur ou tenseur quantifié par tenseur | (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C2) |
Contraintes
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true). - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).
Exemples
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]
select_and_scatter
Sémantique
Disperse les valeurs du Tensor source à l'aide de scatter en fonction du résultat de reduce_window du Tensor input à l'aide de select et produit un Tensor result.
Le diagramme suivant montre comment les éléments de result sont calculés à partir de operand et source à l'aide d'un exemple concret.
Plus précisément :
selected_values = reduce_window_without_init(...)avec les entrées suivantes :inputs = [operand].window_dimensions,window_stridesetpaddingsont utilisés tels quels.base_dilations = windows_dilations = 1.bodyest défini comme suit :
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;où
E = element_type(operand)etreduce_window_without_initfonctionnent exactement commereduce_window, sauf que lescheduledureducesous-jacent (voir reduce) n'inclut pas les valeurs d'initialisation. Le comportement de la fonction lorsque la fenêtre correspondante ne contient pas de valeurs n'est pas spécifié pour le moment (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)où :source_values = [source[source_index] for source_index in source_indices].selected_index(source_index) = operand_indexsiselected_values[source_index]comporte l'élémentoperanddeoperand_index.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1-C4), (C6), (C8-C11) |
| (I2) | source |
tenseur ou tenseur quantifié par tenseur | (C1), (C2) |
| (I3) | init_value |
Tenseur de dimension 0 ou tenseur quantifié par tenseur | (C3) |
| (I4) | window_dimensions |
Constante Tensor de type si64 à une dimension |
(C2), (C4), (C5) |
| (I5) | window_strides |
Constante Tensor de type si64 à une dimension |
(C2), (C6), (C7) |
| (I6) | padding |
Constante Tensor à deux dimensions de type si64 |
(C2), (C8) |
| (I7) | select |
fonction | (C9) |
| (I8) | scatter |
fonction | (C10) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C11-C12) |
Contraintes
- (C1)
element_type(operand) = element_type(source). - (C2)
shape(source) = num_windowsoù :padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shapenum_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
- (C3)
element_type(init_value) = element_type(operand). - (C4)
size(window_dimensions) = rank(operand). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(operand). - (C7)
0 < window_strides. - (C8)
shape(padding) = [rank(operand), 2]. - (C9)
selectest de type(tensor<E>, tensor<E>) -> tensor<i1>oùE = element_type(operand). - (C10)
scatterest de type(tensor<E>, tensor<E>) -> tensor<E>oùis_promotable(element_type(operand), E). - (C11)
shape(operand) = shape(result). - (C12)
element_type(result) = E.
Exemples
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>E
} <: (>ten>sori64,< t>ensori64) - tensori1
"stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
^bb0(%<arg>0: tensori64, %arg1: tensori64):
%0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
> "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
window_dim<ensions => arrayi64: 3, 1,
<window_strides => arrayi64<: 2, 1,>
padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
envoyer
Sémantique
Envoie inputs à la chaîne channel_id. Les entrées sont ensuite envoyées aux autres appareils dans l'ordre spécifié par source_target_pairs. L'opération génère un jeton result.
Si is_host_transfer est défini sur true, l'opération transfère les données vers l'hôte. Sinon, il transfère les données vers un autre appareil en fonction des valeurs de source_target_pairs. Cette option duplique les informations fournies dans channel_type. À l'avenir, nous prévoyons de n'en conserver qu'une seule (#666). Si is_host_transfer= false et que source_target_pairs est None ou vide, le comportement est considéré comme indéfini.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | inputs |
un nombre variable de Tensors ou de Tensors quantifiés. | |
| (I2) | token |
token |
|
| (I3) | channel_id |
constante de type si64 |
|
| (I4) | channel_type |
Énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST |
(C5) |
| (I5) | is_host_transfer |
constante de type i1 |
(C5-C6) |
| (I6) | source_target_pairs |
Constante Tensor à deux dimensions de type si64 |
(C1-C4), (C6) |
Sorties
| Nom | Type |
|---|---|
result |
token |
Contraintes
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, oùNest défini comme suit :num_replicassicross_replicaest utilisé.num_partitionssicross_partitionest utilisé.
- (C5)
channel_typeest défini comme suit :DEVICE_TO_HOSTsiis_host_transfer = true,- Sinon,
DEVICE_TO_DEVICE.
Exemples
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token
shift_left
Sémantique
Effectue une opération de décalage à gauche au niveau des éléments sur le Tensor lhs par le nombre de bits rhs et génère un Tensor result.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur de type entier | (C1) |
| (I2) | rhs |
tenseur de type entier | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result).
Exemples
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]
shift_right_arithmetic
Sémantique
Effectue une opération de décalage arithmétique vers la droite au niveau des éléments sur le Tensor lhs par rhs bits et génère un Tensor result.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur de type entier | (C1) |
| (I2) | rhs |
tenseur de type entier | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result).
Exemples
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]
shift_right_logical
Sémantique
Effectue une opération de décalage logique vers la droite au niveau des éléments sur le Tensor lhs par rhs bits et génère un Tensor result.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
tenseur de type entier | (C1) |
| (I2) | rhs |
tenseur de type entier | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result).
Exemples
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]
signer
Sémantique
Renvoie le signe de l'élément operand et produit un Tensor result.
Plus formellement, pour chaque élément x, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit :
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Pour les types quantifiés, effectue dequantize_op_quantize(sign, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type entier signé, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type entier signé, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Sémantique
Effectue une opération sinus au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
sinà partir de la norme IEEE-754. - Pour les nombres complexes : sinus complexe.
- Pour les types quantifiés :
dequantize_op_quantize(sine, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Sémantique
Extrait une tranche de operand à l'aide d'indices de début calculés de manière statique et produit un Tensor result. start_indices contient les index de début de la tranche pour chaque dimension, limit_indices contient les index de fin (exclusifs) de la tranche pour chaque dimension et strides contient les foulées pour chaque dimension.
Plus précisément, result[result_index] = operand[operand_index] où operand_index = start_indices + result_index * strides.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié par tenseur | (C1-C3), (C5) |
| (I2) | start_indices |
Constante Tensor de type si64 à une dimension |
(C2), (C3), (C5) |
| (I3) | limit_indices |
Constante Tensor de type si64 à une dimension |
(C2), (C3), (C5) |
| (I4) | strides |
Constante Tensor de type si64 à une dimension |
(C2), (C4) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié par tenseur | (C1), (C5) |
Contraintes
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand). - (C3)
0 <= start_indices <= limit_indices <= shape(operand). - (C4)
0 < strides. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides).
Exemples
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indic<es = arra>yi64: 1, 2,
limit_indic<es = arra>yi64: 3, 4,
strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
// [1, 1],
// [1, 1]
// ]
trier
Sémantique
Trie ensemble les tranches à une dimension de inputs le long de la dimension dimension, selon un comparator, et produit results.
Contrairement aux entrées similaires dans d'autres opérations, dimension autorise les valeurs négatives, avec la sémantique décrite ci-dessous. À l'avenir, cela pourra être interdit pour des raisons de cohérence (#1377).
Si is_stable est défini sur "true", le tri est stable, c'est-à-dire que l'ordre relatif des éléments considérés comme égaux par le comparateur est conservé. Dans le cas où il n'y a qu'une seule entrée, deux éléments e1 et e2 sont considérés comme égaux par le comparateur si et seulement si comparator(e1, e2) = comparator(e2, e1) = false. Consultez la formalisation ci-dessous pour savoir comment cela se généralise à plusieurs entrées.
Plus formellement, pour tout result_index dans index_space(results[0]) :
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.result_slice = [ri0, ..., :, ..., riR-1]oùriNsont des éléments individuels dansresult_index, et:est inséré àadjusted_dimension.inputs_together = (inputs[0]..., ..., inputs[N-1]...).results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).- où
sorttrie une tranche unidimensionnelle par ordre croissant, en s'attendant à ce quecomparator_togetherrenvoietruesi l'argument de gauche est inférieur au second argument de droite. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | inputs |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C1-C5) |
| (I2) | dimension |
constante de type si64 |
(C4) |
| (I3) | is_stable |
constante de type i1 |
|
| (I4) | comparator |
fonction | (C5) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
un nombre variable de tenseurs ou des tenseurs quantifiés par tenseur. | (C2), (C3) |
Contraintes
- (C1)
0 < size(inputs). - (C2)
type(inputs...) = type(results...). - (C3)
same(shape(inputs...) + shape(results...)). - (C4)
-R <= dimension < R, oùR = rank(inputs[0]). - (C5)
comparatorest de type(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, oùEi = element_type(inputs[i]).
Exemples
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>T
} <: (>ten>sori64,< t>ensori64) - tensori1
"stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
dimension = 0 : i64,
< is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Sémantique
Effectue une opération de racine carrée élément par élément sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
squareRootà partir de la norme IEEE-754. - Pour les nombres complexes : racine carrée complexe.
- Pour les types quantifiés :
dequantize_op_quantize(sqrt, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Sémantique
Effectue une soustraction élément par élément de deux Tensors lhs et rhs, et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les nombres entiers : soustraction d'entiers.
- Pour les valeurs float :
subtractionà partir de la norme IEEE-754. - Pour les nombres complexes : soustraction complexe.
- Pour les types quantifiés :
dequantize_op_quantize(subtract, lhs, rhs, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
| (I2) | rhs |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type entier, à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Exemples
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]
tan
Sémantique
Effectue une opération tangente au niveau des éléments sur le Tensor operand et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
tanà partir de la norme IEEE-754. - Pour les nombres complexes : tangente complexe.
- Pour les types quantifiés :
dequantize_op_quantize(tan, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Sémantique
Effectue une opération de tangente hyperbolique élément par élément sur le Tensor operand et produit un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs float :
tanhà partir de la norme IEEE-754. - Pour les nombres complexes : tangente hyperbolique complexe.
- Pour les types quantifiés :
dequantize_op_quantize(tanh, operand, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result).
Exemples
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]
transpose
Sémantique
Permute les dimensions du Tensor operand à l'aide de permutation et produit un Tensor result. Plus formellement, result[result_index] = operand[operand_index] où result_index[d] = operand_index[permutation[d]].
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur ou tenseur quantifié | (C1-C4) |
| (I2) | permutation |
Constante Tensor de type si64 à une dimension |
(C2-C4) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur ou tenseur quantifié | (C1), (C3-C4) |
Contraintes
- (C1)
element_type(result)est donné par :element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), sauf quequantization_dimension(operand)etquantization_dimension(result)peuvent être différents.
- (C2)
permutationest une permutation derange(rank(operand)). - (C3)
shape(result) = dim(operand, permutation...). - (C4) Si
is_per_axis_quantized(result), alorsquantization_dimension(operand) = permutation(quantization_dimension(result)).
Exemples
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Sémantique
Résout des lots de systèmes d'équations linéaires avec des matrices de coefficients triangulaires inférieures ou supérieures.
Plus formellement, étant donné a et b, result[i0, ..., iR-3, :, :] est la solution de op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] lorsque left_side est true ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] lorsque left_side est false, en résolvant la variable x où op(a) est déterminé par transpose_a, qui peut être l'une des valeurs suivantes :
NO_TRANSPOSE: effectuez l'opération en utilisantatel quel.TRANSPOSE: effectue l'opération sur la transposée dea.ADJOINT: effectue l'opération sur la transposée conjuguée dea.
Les données d'entrée ne sont lues qu'à partir du triangle inférieur de a, si lower est true ou du triangle supérieur de a, dans le cas contraire. Les données de sortie sont renvoyées dans le même triangle. Les valeurs de l'autre triangle sont définies par l'implémentation.
Si unit_diagonal est défini sur "true", l'implémentation peut supposer que les éléments diagonaux de a sont égaux à 1. Sinon, le comportement est indéfini.
Pour les types quantifiés, effectue dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | a |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1-C3) |
| (I2) | b |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1-C4) |
| (I3) | left_side |
constante de type i1 |
(C3) |
| (I4) | lower |
constante de type i1 |
|
| (I5) | unit_diagonal |
constante de type i1 |
|
| (I6) | transpose_a |
Énumération de NO_TRANSPOSE, TRANSPOSE et ADJOINT |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tenseur de type à virgule flottante ou complexe, ou tenseur quantifié par tenseur | (C1) |
Contraintes
- (C1)
baseline_element_type(a) = baseline_element_type(b). - (C2)
2 <= rank(a) = rank(b) = R. - (C3) La relation entre
shape(a)etshape(b)est définie comme suit :shape(a)[:-3] = shape(b)[:-3].dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
- (C4)
baseline_type(b) = baseline_type(result).
Exemples
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Sémantique
Génère un tuple result à partir des valeurs val.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | val |
nombre variable de valeurs | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tuple | (C1) |
Contraintes
- (C1)
resultest de typetuple<E0, ..., EN-1>oùEi = type(val[i]).
Exemples
// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))
uniform_dequantize
Sémantique
Effectue une conversion élément par élément du tenseur quantifié operand en un tenseur à virgule flottante result en fonction des paramètres de quantification définis par le type operand.
Plus précisément, result = dequantize(operand).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
tenseur quantifié | (C1), (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(operand) = shape(result). - (C2)
element_type(result) = expressed_type(operand).
Exemples
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]
uniform_quantize
Sémantique
Effectue une conversion élément par élément du tenseur à virgule flottante ou du tenseur quantifié operand en un tenseur quantifié result en fonction des paramètres de quantification définis par le type result.
Plus précisément,
- Si
is_float(operand):result = quantize(operand, type(result)).
- Si
is_quantized(operand):float_result = dequantize(operand).result = quantize(float_result, type(result)).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
Tenseur de type à virgule flottante ou quantifié | (C1), (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
tenseur quantifié | (C1), (C2) |
Contraintes
- (C1)
shape(operand) = shape(result). - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).
Exemples
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]
pendant que
Sémantique
Produit la sortie de l'exécution de la fonction body zéro fois ou plus, tandis que la fonction cond génère true. Plus formellement, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit :
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Le comportement d'une boucle infinie est à déterminer (#383).
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | operand |
nombre variable de valeurs | (C1-C3) |
| (I2) | cond |
fonction | (C1) |
| (I3) | body |
fonction | (C2) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
results |
nombre variable de valeurs | (C3) |
Contraintes
- (C1)
condest de type(T0, ..., TN-1) -> tensor<i1>, oùTi = type(operand[i]). - (C2)
bodyest de type(T0, ..., TN-1) -> (T0, ..., TN-1), oùTi = type(operand[i]). - (C3)
type(results...) = type(operand...).
Exemples
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_di<rection = #stablehlocom>parison_directio<n L>T
} <: (>ten>sori64,< t>ensori64) - tensori1
stablehlo.r<et>urn %cond : tensori1
}, {
< ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
%new_sum = stablehlo.add <%ar>g1, %one : tensori64
%new_i = stablehlo.add <%ar>g0, %one : tensori64
stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10
xor
Sémantique
Effectue une opération XOR (OU exclusif) élément par élément sur deux Tensors lhs et rhs, et génère un Tensor result. En fonction du type d'élément, effectue les actions suivantes :
- Pour les valeurs booléennes : XOR logique.
- Pour les nombres entiers : XOR (OU exclusif) bit à bit.
Entrées
| Libellé | Nom | Type | Contraintes |
|---|---|---|---|
| (I1) | lhs |
Tensor de type booléen ou entier | (C1) |
| (I2) | rhs |
Tensor de type booléen ou entier | (C1) |
Sorties
| Nom | Type | Contraintes |
|---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result).
Exemples
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]
Interopérabilité des dialectes
À l'heure actuelle, les programmes StableHLO en circulation contiennent parfois des opérations qui ne sont pas définies par StableHLO.
Module, fonction, appel et retour
StableHLO utilise des opérations MLIR en amont pour ModuleOp, FuncOp, CallOp et ReturnOp. Cela a été fait pour une meilleure interopérabilité avec les mécanismes MLIR existants, car de nombreux passes utiles sont écrits en ciblant FuncOp et ModuleOp, et de nombreux pipelines de compilation s'attendent à ce que ces opérations soient présentes. Une compatibilité totale est garantie pour ces opérations. Si ces opérations changent de manière incompatible (par exemple, si elles sont supprimées), des équivalents StableHLO seront ajoutés pour préserver la compatibilité.
CHLO
L'opset CHLO contient des opérations de niveau supérieur qui se décomposent en StableHLO. Actuellement, aucune garantie de compatibilité n'est disponible pour CHLO. Pour garantir la compatibilité, le chlo-legalize-to-stablehlo pass doit être utilisé avant la sérialisation.
Opérations sur les formes
Il est courant dans la communauté d'utiliser certaines opérations des dialectes MLIR de base dans les programmes StableHLO dynamiques pour effectuer des calculs de forme.
Le plus souvent, il s'agit d'opérations shape telles que shape_of ou num_elements, d'opérations tensor telles que dim ou from_elements, et du type index intégré.
Le RFC sur le dynamisme > O2 indique que ces types sont hors champ d'application. Toutefois, une certaine prise en charge des types index est incluse à des fins d'interopérabilité. Il n'existe aucune garantie de compatibilité pour ces opérations ou types. Le pass shape-legalize-to-stablehlo peut être utilisé pour convertir ces opérations en opérations StableHLO entièrement compatibles.
Opérations obsolètes
Plusieurs opérations StableHLO héritées de MHLO sont obsolètes et seront bientôt supprimées de StableHLO. Pour en savoir plus sur ces suppressions, consultez StableHLO v1.0 Cleanup #2283. Le problème de suivi de ces abandons est le n° 2340.
Ces opérations se répartissent en plusieurs catégories :
- Catégorie "Not in HLO" des opérations StableHLO : elles faisaient initialement partie de l'opset StableHLO, mais ont ensuite été jugées inadaptées :
broadcast,create_token,cross-replica-sum,dot,einsum,torch_index_select,unary_einsum(#3). - Opérations inutilisées : ces opérations ont peut-être été utiles à un moment donné, mais elles étaient soit sous-développées, soit les pipelines qui les utilisaient ont été refactorisés pour ne plus en avoir besoin. Cela inclut les comparaisons
map,tuple(#598),get_tuple_element,rng,complex#560 et la convolutionwindow_reversal(#1181).
Certaines de ces opérations peuvent être facilement supprimées, car elles peuvent être exprimées à l'aide d'opérations existantes (broadcast, create_token, cross-replica-sum, dot, unary_einsum) et seront supprimées une fois la période de compatibilité existante (six mois) écoulée. D'autres sont encore à l'étude pour être supprimées (comparaisons einsum, get_tuple_element, map, rng, torch_index_select, tuple, complex, window_reversal). En fonction des commentaires de la communauté, ces opérations seront supprimées ou ajoutées à la spécification avec une prise en charge complète. Tant que ces futures opérations ne sont pas connues, la compatibilité n'est garantie que pendant six mois.
Exécution
Exécution séquentielle
Un programme StableHLO est exécuté en fournissant des valeurs d'entrée à la fonction main et en calculant les valeurs de sortie. Les valeurs de sortie d'une fonction sont calculées en exécutant le graphique des opérations ancré dans l'opération return correspondante.
L'ordre d'exécution est défini par l'implémentation tant qu'il est aligné sur le flux de données, c'est-à-dire si les opérations sont exécutées avant leurs utilisations. Dans StableHLO, toutes les opérations à effet secondaire consomment un jeton et en produisent un (plusieurs jetons peuvent être multiplexés en un seul jeton via after_all). L'ordre d'exécution des effets secondaires est donc également aligné sur le flux de données. Par exemple, dans le programme ci-dessous, il existe deux ordres d'exécution possibles : %0 → %1 → %2 → return et %1 → %0 → %2 → return.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Plus précisément, un processus StableHLO est une combinaison des éléments suivants :
1) un programme StableHLO, 2) des états d'opération (pas encore exécutée, déjà exécutée) et 3) des valeurs intermédiaires sur lesquelles le processus travaille.
Le processus commence par des valeurs d'entrée pour la fonction main, progresse dans le graphique des opérations en mettant à jour les états des opérations et les valeurs intermédiaires, et se termine par des valeurs de sortie. La formalisation supplémentaire est à déterminer (#484).
Exécution parallèle
Les programmes StableHLO peuvent être exécutés en parallèle, organisés dans une grille de processus 2D de num_replicas par num_partitions, qui sont tous deux de type ui32.
Dans la grille de processus StableHLO, num_replicas * num_partitions processus StableHLO s'exécutent en même temps. Chaque processus possède un process_id = (replica_id, partition_id) unique, où replica_id dans replica_ids = range(num_replicas) et partition_id dans partition_ids = range(num_partitions) ont tous deux le type ui32.
La taille de la grille de processus est connue de manière statique pour chaque programme (à l'avenir, nous prévoyons d'en faire une partie explicite des programmes StableHLO #650), et la position dans la grille de processus est connue de manière statique pour chaque processus. Chaque processus a accès à sa position dans la grille de processus via les opérations replica_id et partition_id.
Dans la grille de processus, les programmes peuvent être identiques (style "Programme unique, données multiples"), tous différents (style "Programmes multiples, données multiples") ou quelque chose entre les deux. À l'avenir, nous prévoyons d'ajouter la compatibilité avec d'autres idiomes de définition des programmes StableHLO parallèles, y compris GSPMD (#619).
Dans la grille de processus, les processus sont généralement indépendants les uns des autres. Ils ont des états d'opération distincts, des valeurs d'entrée/intermédiaires/de sortie distinctes et la plupart des opérations sont exécutées séparément entre les processus, à l'exception d'un petit nombre d'opérations collectives décrites ci-dessous.
Étant donné que l'exécution de la plupart des opérations n'utilise que des valeurs provenant du même processus, il est généralement clair de faire référence à ces valeurs par leur nom.
Toutefois, cela ne suffit pas pour décrire la sémantique des opérations collectives, ce qui donne lieu à la notation name@process_id pour faire référence à la valeur name dans un processus particulier. (De ce point de vue, un name non qualifié peut être considéré comme une abréviation de name@(replica_id(), partition_id()).)
L'ordre d'exécution des processus est défini par l'implémentation, à l'exception de la synchronisation introduite par la communication point à point et les opérations collectives, comme décrit ci-dessous.
Communication point à point
Les processus StableHLO peuvent communiquer entre eux via des canaux StableHLO. Un canal est représenté par un ID positif de type si64. Grâce à différentes opérations, il est possible d'envoyer des valeurs aux canaux et de les recevoir des canaux.
Une formalisation plus poussée (par exemple, d'où proviennent ces ID de canal, comment les programmes de processus en prennent connaissance et quel type de synchronisation est introduit par eux) est à déterminer (#484).
Communication en streaming
Chaque processus StableHLO a accès à deux interfaces de streaming :
- Infeed à partir duquel la lecture peut être effectuée.
- Outfeed dans lequel écrire.
Contrairement aux canaux, qui sont utilisés pour communiquer entre les processus et ont donc des processus à leurs deux extrémités, les flux d'entrée et de sortie ont leur autre extrémité définie par l'implémentation.
La formalisation supplémentaire, par exemple la façon dont la communication en streaming influence l'ordre d'exécution et le type de synchronisation qu'elle introduit, est à déterminer (#484).
Opérations collectives
Il existe six opérations collectives dans StableHLO : all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute et reduce_scatter. Toutes ces opérations divisent les processus dans la grille de processus StableHLO en groupes de processus StableHLO et exécutent un calcul conjoint dans chaque groupe de processus, indépendamment des autres groupes de processus.
Dans chaque groupe de processus, les opérations collectives peuvent introduire une barrière de synchronisation. Une formalisation plus poussée, par exemple en précisant le moment exact où cette synchronisation se produit, la manière exacte dont les processus atteignent cette barrière et ce qui se passe s'ils ne l'atteignent pas, est à déterminer (#484).
Si le groupe de processus implique une communication entre partitions (c'est-à-dire qu'il existe des processus dans le groupe de processus dont les ID de partition sont différents), l'exécution de l'opération collective nécessite un canal, et l'opération collective doit fournir un channel_id positif de type si64. La communication entre les répliques n'a pas besoin de canaux.
Les calculs effectués par les opérations collectives sont spécifiques aux opérations individuelles et sont décrits dans les sections d'opérations individuelles ci-dessus. Toutefois, les stratégies selon lesquelles la grille de processus est divisée en groupes de processus sont partagées entre ces opérations et sont décrites dans cette section. Plus précisément, StableHLO prend en charge les quatre stratégies suivantes.
cross_replica
Seules les communications entre les répliques ont lieu au sein de chaque groupe de processus. Cette stratégie prend replica_groups (liste de listes d'ID de répliques) et calcule un produit cartésien de replica_groups par partition_ids. replica_groups doit comporter des éléments uniques et couvrir tous les replica_ids. Plus formellement, en utilisant la syntaxe Python :
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica générera [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].
cross_partition
Seules les communications entre partitions ont lieu au sein de chaque groupe de processus. Cette stratégie prend partition_groups (une liste de listes d'ID de partition) et calcule un produit cartésien de partition_groups par replica_ids.
partition_groups doit comporter des éléments uniques et couvrir tous les partition_ids.
Plus précisément, en utilisant la syntaxe Python :
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour partition_groups = [[0, 1]] et num_replicas = 4, cross_partition générera [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].
cross_replica_and_partition
Les communications entre réplicas et entre partitions peuvent avoir lieu dans chaque groupe de processus. Cette stratégie prend replica_groups (liste de listes d'ID de répliques) et calcule les produits cartésiens de chaque replica_group par partition_ids. replica_groups doit comporter des éléments uniques et couvrir tous les replica_ids. Plus précisément, en utilisant la syntaxe Python :
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour replica_groups = [[0, 1], [2, 3]] et num_partitions = 2, cross_replica_and_partition générera [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].
flattened_ids
Cette stratégie prend flattened_id_groups (une liste de listes d'ID de processus "aplatis" sous la forme replica_id * num_partitions + partition_id) et les transforme en ID de processus. flattened_id_groups doit comporter des éléments uniques et couvrir tous les process_ids. Plus précisément, en utilisant la syntaxe Python :
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 et num_partitions = 2, flattened_ids générera [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].
Précision
Pour le moment, StableHLO ne fournit aucune garantie concernant la précision numérique, mais cela peut changer à l'avenir (#1156).
Sémantique d'exécution de l'opération quantifiée
L'interprétation des opérations StableHLO quantifiées peut varier en fonction des exigences et des capacités matérielles. Par exemple, certains matériels peuvent choisir d'interpréter les opérations quantifiées à l'aide d'une stratégie "déquantifier, effectuer une opération à virgule flottante et enfin quantifier". D'autres peuvent effectuer l'intégralité du calcul avec une arithmétique entière. Par conséquent, l'interprétation des opérations StableHLO quantifiées est exclusivement déterminée par l'implémentation spécifique. L'interprétation de la quantification hybride (#1575) doit être basée sur sa sémantique telle que prescrite dans la spécification (via 1792).
Erreurs
Les programmes StableHLO sont validés à l'aide d'un ensemble complet de contraintes pour les opérations individuelles, ce qui exclut de nombreuses classes d'erreurs avant l'exécution. Toutefois, des conditions d'erreur sont toujours possibles, par exemple en cas de dépassement de capacité d'entier, d'accès hors limites, etc. Sauf indication explicite, toutes ces erreurs entraînent un comportement défini par l'implémentation, mais cela peut changer à l'avenir (#1157).
Exceptions à virgule flottante
Par exception à cette règle, les exceptions à virgule flottante dans les programmes StableHLO ont un comportement bien défini. Les opérations qui entraînent des exceptions définies par la norme IEEE-754 (opération non valide, division par zéro, dépassement de capacité, dépassement de capacité négatif ou exceptions inexactes) produisent des résultats par défaut (tels que définis dans la norme) et poursuivent l'exécution sans déclencher l'indicateur d'état correspondant, comme la gestion des exceptions raiseNoFlag de la norme. Les exceptions pour les opérations non standards (par exemple, l'arithmétique complexe et certaines fonctions transcendantales) sont définies par l'implémentation.
Incompatibilités de forme
StableHLO est compatible avec les tenseurs de forme dynamique. Toutefois, les formes doivent être identiques au moment de l'exécution, sinon le comportement n'est pas défini. StableHLO ne fournit pas explicitement d'opération permettant d'affirmer qu'un Tensor a une forme donnée au moment de l'exécution. Il incombe au producteur de générer le code correct.
Par exemple, le programme ci-dessous est valide. Toutefois, au moment de l'exécution, les formes exactes de %arg0 et %arg1 devront être identiques, sinon le comportement du programme sera indéfini :
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Pour décrire la syntaxe, ce document utilise la variante ISO modifiée de la syntaxe EBNF (ISO/IEC 14977:1996, Wikipédia), avec deux modifications : 1) les règles sont définies à l'aide de ::= plutôt que de =,
2) La concaténation est exprimée par juxtaposition plutôt que par ,.
Pour décrire la sémantique (c'est-à-dire dans les sections "Types", "Constantes" et "Ops"), nous utilisons des formules basées sur la syntaxe Python, qui sont étendues pour permettre d'exprimer de manière concise les opérations sur les tableaux, comme décrit ci-dessous. Cela fonctionne bien pour les petits extraits de code, mais dans de rares cas où des extraits de code plus volumineux sont nécessaires, nous utilisons la syntaxe Python de base, qui est toujours introduite de manière explicite.
Formules
Voyons comment fonctionnent les formules à l'aide d'un exemple tiré de la spécification dot_general. L'une des contraintes de cette opération se présente comme suit :
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
Les noms utilisés dans cette formule proviennent de deux sources : 1) les fonctions globales, c'est-à-dire dim ; 2) les définitions de membre de l'élément de programme correspondant, c'est-à-dire les entrées lhs, lhs_batching_dimensions, rhs et rhs_batching_dimensions définies dans la section "Entrées" de dot_general.
Comme indiqué ci-dessus, la syntaxe de cette formule est basée sur Python, avec quelques extensions axées sur la concision. Pour comprendre la formule, transformons-la en syntaxe Python standard.
A) Dans ces formules, nous utilisons = pour représenter l'égalité. La première étape pour obtenir la syntaxe Python consiste donc à remplacer = par ==, comme suit : dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).
B) Ces formules sont également compatibles avec les points de suspension (...), qui transforment les expressions scalaires en expressions Tensor. En résumé, f(xs...) signifie à peu près "pour chaque scalaire x dans le Tensor xs, calculez un scalaire f(x), puis renvoyez tous ces résultats scalaires ensemble sous forme de Tensor". Dans la syntaxe Python de base, notre exemple de formule devient : [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions].
Grâce aux points de suspension, il est souvent possible d'éviter de travailler au niveau des scalaires individuels. Toutefois, dans certains cas complexes, une syntaxe semi-informelle de niveau inférieur peut être utilisée, comme dans la formule start_indices[bi0, ..., :, ..., biN] de la spécification gather. Par souci de concision, nous ne fournissons pas de formalisme exact pour traduire une telle syntaxe en Python standard, dans l'espoir qu'elle reste intuitivement compréhensible au cas par cas.
N'hésitez pas à nous signaler les formules spécifiques qui vous semblent obscures. Nous essaierons de les améliorer.
Vous remarquerez également que les formules utilisent des points de suspension pour développer toutes sortes de listes, y compris les Tensors, les listes de Tensors (qui peuvent, par exemple, provenir d'un nombre variable de Tensors), etc. Il s'agit d'un autre domaine dans lequel nous ne fournissons pas de formalisme exact (par exemple, les listes ne font même pas partie du système de type StableHLO) et nous nous appuyons plutôt sur une compréhension intuitive.
C) Le dernier véhicule de notation notable que nous utilisons est la diffusion implicite. Bien que l'opset StableHLO ne soit pas compatible avec le broadcasting implicite, les formules le sont, également dans un souci de concision. En résumé, si un scalaire est utilisé dans un contexte où un Tensor est attendu, le scalaire est diffusé à la forme attendue.
Pour poursuivre l'exemple dot_general, voici une autre contrainte :
0 <= lhs_batching_dimensions < rank(lhs). Comme défini dans la spécification dot_general, lhs_batching_dimensions est un Tensor, mais 0 et rank(lhs) sont des scalaires. Après l'application de la diffusion implicite, la formule devient [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].
Lorsqu'elle est appliquée à une opération dot_general spécifique, cette formule est évaluée en tant que Tensor de booléens. Lorsque des formules sont utilisées comme contraintes, la contrainte est respectée si la formule correspond à true ou à un Tensor qui ne comporte que des éléments true.
Noms
Dans les formules, la portée lexicale inclut : 1) les fonctions globales, 2) les définitions de membres,
3) les définitions locales. La liste des fonctions globales est fournie ci-dessous. La liste des définitions d'éléments dépend de l'élément de programme auquel la notation est appliquée :
- Pour les opérations, les définitions des membres incluent les noms introduits dans les sections "Entrées" et "Sorties".
- Pour tout le reste, les définitions des membres incluent les parties structurelles de l'élément de programme, nommées d'après les non-terminaux EBNF correspondants. La plupart du temps, les noms de ces parties structurelles sont obtenus en convertissant les noms des non-terminaux en snake case (par exemple,
IntegerLiteral=>integer_literal), mais parfois les noms sont abrégés au cours du processus (par exemple,QuantizationStorageType=>storage_type). Dans ce cas, les noms sont introduits explicitement de la même manière que les sections "Entrées" / "Sorties" dans les spécifications des opérations. - De plus, les définitions de membre incluent toujours
selfpour faire référence à l'élément de programme correspondant.
Valeurs
Lorsqu'elles sont évaluées, les formules fonctionnent avec les types de valeurs suivants :
1) Value (valeurs réelles, par exemple dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> ; leurs types sont toujours connus),
2) Placeholder (valeurs futures, par exemple lhs, rhs ou result ; leurs valeurs réelles ne sont pas encore connues, seuls leurs types le sont),
3) Type (types tels que définis dans la section "Types"),
4) Function (fonctions globales telles que définies dans la section "Fonctions").
Selon le contexte, les noms peuvent faire référence à différentes valeurs. Plus précisément, la section "Sémantique" pour les opérations (et les équivalents pour les autres éléments du programme) définit la logique d'exécution. Tous les inputs sont donc disponibles en tant que Value.
En revanche, la section "Constraints" (Contraintes) pour les opérations (et les équivalents) définit la logique de "compilation", c'est-à-dire quelque chose qui est généralement exécuté avant l'exécution. Par conséquent, seules les entrées constantes sont disponibles en tant que Value, et les autres entrées ne sont disponibles qu'en tant que Placeholder.
| Noms | Dans "Sémantique" | Dans "Contraintes" |
|---|---|---|
| Fonctions globales | Function |
Function |
| Entrées constantes | Value |
Value |
| Entrées non constantes | Value |
Placeholder |
| Sorties | Value |
Placeholder |
| Définitions locales | Selon la définition | Selon la définition |
Prenons l'exemple d'une opération transpose :
%result = "stablehlo.transpose"(%operand) {
permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
Pour cette opération, permutation est une constante. Elle est donc disponible en tant que Value dans la sémantique et les contraintes. En revanche, operand et result sont disponibles en tant que Value dans la sémantique, mais uniquement en tant que Placeholder dans les contraintes.
Fonctions
Construction des types
Il n'existe aucune fonction permettant de construire des types. Nous utilisons plutôt directement la syntaxe de type, car elle est généralement plus concise. Par exemple, (tensor<E>, tensor<E>) -> (tensor<E>) plutôt que function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).
Fonctions sur les types
element_typeest défini sur les types de tenseurs et les types de tenseurs quantifiés, et renvoie respectivement la partieTensorElementTypeouQuantizedTensorElementTypeduTensorTypeou duQuantizedTensorTypecorrespondant.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueest un raccourci pouris_quantized(x) and quantization_dimension(x) is not None.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueest un raccourci pouris_quantized(x) and quantization_dimension(x) is None.is_promotable(x: Type, y: Type) -> boolvérifie si le typexpeut être promu au typey. Lorsquexetysont desQuantizedTensorElementType, la promotion ne s'applique qu'àstorage_type. Cette version spécifique de la promotion est actuellement utilisée dans le contexte du calcul des réductions (pour en savoir plus, consultez la RFC).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Valueest un raccourci pouris_quantized_tensor_element_type(x).is_type_name(x: Value | Placeholder | Type) -> Value. Disponible pour tous les types. Par exemple,is_float(x)renvoietruesixest unFloatType. Sixest une valeur ou un espace réservé, cette fonction est un raccourci pouris_type_name(type(x)).max_value(x: Type) -> Valuerenvoie la valeur maximale d'unTensorElementType. Sixn'est pas unTensorElementType, renvoieNone.min_value(x: Type) -> Valuerenvoie la valeur minimale possible d'unTensorElementType. Sixn'est pas unTensorElementType, renvoieNone.member_name(x: Value | Placeholder | Type) -> Any. Disponible pour toutes les définitions de membresmember_namede tous les types. Par exemple,tensor_element_type(x)renvoie la partieTensorElementTyped'unTensorTypecorrespondant. Sixest une valeur ou un espace réservé, cette fonction est un raccourci pourmember_name(type(x)). Sixn'est pas un type qui possède un membre approprié, ou une valeur ou un espace réservé de ce type, renvoieNone.is_empty_algorithm(*args: Type)vérifie si tous les champs de l'algorithme de points sont définis surNone. Cela est nécessaire, car les algorithmes de points ont des comportements par défaut définis par l'implémentation. Il serait donc incorrect de spécifier une valeur par défaut.
Construction des valeurs
operation_name(*xs: Value | Type) -> Value. Disponible pour toutes les opérations. Par exemple,add(lhs, rhs)prend deux valeurs de Tensorlhsetrhset renvoie le résultat de l'évaluation de l'opérationaddavec ces entrées. Pour certaines opérations (par exemple,broadcast_in_dim), les types de leurs sorties sont "porteurs", c'est-à-dire nécessaires pour évaluer une opération. Dans ce cas, la fonction prend ces types comme arguments.
Fonctions sur les valeurs
Tous les opérateurs et fonctions Python sont disponibles. Par exemple, les notations subscription et slicing de Python sont disponibles pour indexer les Tensors, les Tensors quantifiés et les tuples.
to_destination_type(x: Value, destination_type: Type) -> Valueest défini sur les Tensors et renvoie la valeur convertie dexen fonction detype(x)etdestination_typecomme suit :
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Une discussion préliminaire est en cours concernant la fusion des opérations convert, uniform_quantize et uniform_dequantize (#1576).
Après la fusion, nous n'avons plus besoin de la fonction ci-dessus et nous pouvons utiliser le nom de l'opération pour convert.
is_nan(x: Value) -> Valueest défini sur les Tensors et renvoietruesi tous les éléments dexsontNaNoufalsedans le cas contraire. Sixn'est pas un Tensor, renvoieNone.is_sorted(x: Value) -> Valueest défini sur les Tensors et renvoietruesi les éléments dexsont triés par ordre croissant par rapport à l'ordre lexicographique croissant de leurs indices, oufalsedans le cas contraire. Sixn'est pas un Tensor, renvoieNone.is_unique(x: Value) -> Valueest défini sur les Tensors et renvoietruesixne comporte pas d'éléments en double, oufalsedans le cas contraire. Sixn'est pas un Tensor, renvoieNone.member_name(x: Value) -> Anyest défini pour toutes les définitions de membresmember_namede toutes les valeurs. Par exemple,real_part(x)renvoie la partieRealPartd'unComplexConstantcorrespondant. Sixn'est pas une valeur qui possède un membre approprié, renvoieNone.same(x: Value) -> Valueest défini sur les Tensors et renvoietruesi les éléments dexsont tous égaux les uns aux autres, oufalsedans le cas contraire. Si le Tensor ne comporte aucun élément, cela est considéré comme "tous égaux les uns aux autres ", c'est-à-dire que la fonction renvoietrue. Sixn'est pas un Tensor, renvoieNone.split(x: Value, num_results: Value, axis: Value) -> Valueest défini sur les Tensors et renvoie des tranchesnum_resultsdexle long de l'axeaxis. Sixn'est pas un Tensor oudim(x, axis) % num_results != 0, renvoieNone.is_defined_in_parent_scope(x: Value) -> Valueest défini sur les chaînes et renvoietruesixest le nom d'une fonction définie dans le même champ d'application que la fonction parente de l'opération concernée.is_namespaced_op_name(x: Value) -> Valueest défini sur les chaînes et renvoietruesixest un nom d'opération valide, c'est-à-dire qu'il respecte l'expression régulière suivante :[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+.
Calculs de forme
axes(x: Value | Placeholder | Type) -> Valueest un raccourci pourrange(rank(x)).dim(x: Value | Placeholder | Type, axis: Value) -> Valueest un raccourci pourshape(x)[axis].dims(x: Value | Placeholder | Type, axes: List) -> Listest un raccourci pourlist(map(lambda axis: dim(x, axis), axes)).index_space(x: Value | Placeholder | Type) -> Valueest défini sur les Tensors et renvoie les indicessize(x)pour leTensorTypecorrespondant trié par ordre lexicographique croissant, c'est-à-dire[0, ..., 0],[0, ..., 1], ...,shape(x) - 1. Sixn'est pas un type Tensor, un type Tensor quantifié, une valeur ou un espace réservé de l'un de ces types, renvoieNone.rank(x: Value | Placeholder | Type) -> Valueest un raccourci poursize(shape(x)).shape(x: Value | Placeholder | Type) -> Valueest défini dans la section "Fonctions sur les types" viamember_name.size(x: Value | Placeholder | Type) -> Valueest un raccourci pourreduce(lambda x, y: x * y, shape(x)).
Calculs de quantification
def baseline_element_type(x: Value | Placeholder | Type) -> Typeest un raccourci pourelement_type(baseline_type(x)).baseline_typeest défini sur les types de Tensor et les types de Tensor quantifiés, et les transforme en "baseline", c'est-à-dire un type avec la même forme, mais avec les paramètres de quantification du type d'élément réinitialisés sur les valeurs par défaut. Il s'agit d'une astuce pratique pour comparer uniformément les types de Tensor et de Tensor quantifiés, ce qui est souvent nécessaire. Pour les types quantifiés, cela permet de comparer les types en ignorant les paramètres de quantification, c'est-à-dire queshape,storage_type,expressed_type,storage_min,storage_maxetquantization_dimension(pour le type quantifié par axe) doivent tous correspondre, maisscalesetzero pointspeuvent être différents.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantizeest défini sur les types de tenseurs quantifiés et les transforme en types de tenseurs à virgule flottante. Cela se produit en convertissant les éléments quantifiés qui représentent les valeurs entières du type de stockage en valeurs à virgule flottante correspondantes du type exprimé à l'aide du point zéro et de l'échelle associés au type d'élément quantifié.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantizeest défini sur les types de tenseurs à virgule flottante et les transforme en types de tenseurs quantifiés. Pour ce faire, les valeurs à virgule flottante du type exprimé sont converties en valeurs entières correspondantes du type de stockage à l'aide du point zéro et de l'échelle associés au type d'élément quantifié.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantizepermet de spécifier des calculs élément par élément sur les Tensors quantifiés. Il déquantifie, c'est-à-dire qu'il transforme les éléments quantifiés en leurs types exprimés, puis effectue une opération, puis quantifie, c'est-à-dire qu'il transforme les résultats en leurs types de stockage. Pour le moment, cette fonction ne fonctionne que pour la quantification par Tensor. La quantification par axe est en cours de développement (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_opest utilisé pour spécifier la quantification du poids uniquement pour l'opération hybride qui accepte lhs en virgule flottante et rhs dans les types quantifiés. Il déquantifie les entrées quantifiées dans leurs types exprimés et effectue des calculs en float. Le type d'élément du Tensor lhs float et le type exprimé du Tensor rhs quantifié doivent être identiques.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Calculs de grille
cross_partition(replica_groups: Value) -> Value. Consultez la section "cross_replica" ci-dessus.cross_replica(replica_groups: Value) -> Value. Consultez la section "cross_replica" ci-dessus.cross_replica_and_partition(replica_groups: Value) -> Value. Consultez la section "cross_replica_and_partition" ci-dessus.flattened_ids(replica_groups: Value) -> Value. Consultez la section "flattened_ids" ci-dessus.
Dynamisme
Les valeurs StableHLO peuvent avoir des tailles de dimension dynamiques, par exemple tensor<?xi64>.
Toutefois, les valeurs StableHLO ne peuvent pas avoir un nombre dynamique de dimensions (dynamisme non classé, par exemple tensor<*xi64>). Les opérandes et les résultats sont autorisés à utiliser des tailles de dimension dynamiques, même s'il existe des contraintes sur les tailles. Les contraintes seront validées de manière statique si possible. Sinon, elles seront différées jusqu'à l'exécution et les incohérences entraîneront un comportement indéfini. Vous trouverez des exemples ci-dessous.
Incompatibilités de forme pour les opérations unaires élément par élément
Prenons l'exemple de programme suivant :
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Un tel programme est inhabituel, car il est rare de connaître la forme du résultat sans connaître celle de l'entrée. Néanmoins, il s'agit d'un programme StableHLO valide. Il n'est pas possible de valider statiquement l'opération abs dans ce programme, car la forme exacte de l'opérande est inconnue. Toutefois, les formes sont certainement compatibles, et cela peut être vérifié de manière statique : ? pourrait s'avérer être 2 au moment de l'exécution, et il n'y aurait aucun problème. Toutefois, ? peut également s'avérer être un autre entier, auquel cas le comportement n'est pas défini.
Notez que si la taille d'une dimension est dynamique dans le résultat, il ne peut pas y avoir de comportement indéfini. En effet, il n'y a pas de taille "attendue", il ne peut donc pas y avoir de décalage.
Incompatibilité de forme pour les opérations binaires élément par élément
Prenons l'exemple de programme suivant :
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
En ce qui concerne les opérations binaires élément par élément, les formes des entrées et du résultat doivent correspondre au moment de l'exécution. Au moment de la compilation, les dimensions statiques doivent être égales. Sinon, elles doivent simplement être compatibles. Si une dimension est dynamique dans les entrées, un comportement indéfini peut se produire au moment de l'exécution, car la taille dynamique peut ne pas correspondre à la taille correspondante dans l'autre opérande (qu'elle soit statique ou dynamique). Si toutes les entrées sont statiques, la nature dynamique ou non du résultat n'a pas d'importance : les dimensions connues de manière statique seront vérifiées de manière statique, et les dimensions dynamiques n'imposent aucune contrainte.
Incompatibilités de forme pour les opérations qui prennent leur forme de sortie comme opérande
Prenons l'exemple de programme suivant :
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Les valeurs de l'opérande de forme lors de l'exécution doivent correspondre à la forme du résultat, sinon le comportement n'est pas défini. Autrement dit, au moment de l'exécution, %arg0 doit avoir une valeur de dense<[3, 4]> : tensor<2xi32>. Si l'opérande de forme est constant, cela peut être vérifié de manière statique. Si la forme du résultat est entièrement dynamique, il ne peut pas y avoir d'incohérence.