StableHLO est un ensemble d'opérations destiné aux opérations de haut niveau (HLO) dans les modèles de machine learning (ML). StableHLO fonctionne comme une couche de portabilité entre différents frameworks de ML et compilateurs de ML: les frameworks de ML qui génèrent des programmes StableHLO sont compatibles avec les compilateurs de ML qui utilisent des programmes StableHLO.
Notre objectif est de simplifier et d'accélérer le développement du ML en favorisant l'interopérabilité entre différents frameworks de ML (tels que TensorFlow, JAX et PyTorch) et des compilateurs de ML (tels que XLA et IREE). À cette fin, ce document fournit une spécification pour le langage de programmation StableHLO.
Cette spécification contient trois sections principales. Tout d'abord, la section Programmes décrit la structure des programmes StableHLO, qui sont constitués de fonctions StableHLO, qui sont elles-mêmes constituées d'opérations StableHLO. Au sein de cette structure, la section Ops spécifie la sémantique de chaque opération. La section Exécution fournit la sémantique de toutes ces opérations exécutées ensemble au sein d'un programme. Enfin, la section Notation décrit la notation utilisée tout au long de la spécification.
Programmes
Program ::= {Func}
Les programmes StableHLO sont constitués d'un nombre arbitraire de fonctions StableHLO.
Vous trouverez ci-dessous un exemple de programme avec une fonction @main
qui comporte trois entrées (%image
, %weights
et %bias
) et une sortie. Le corps de la fonction comporte six opérations.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Fonctions
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Les fonctions StableHLO (également appelées fonctions nommées) ont un identifiant, des entrées/sorties et un corps. Nous prévoyons d'ajouter des métadonnées supplémentaires pour les fonctions afin d'améliorer la compatibilité avec HLO (#425, #626, #740, 744).
Identifiants
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Les identifiants StableHLO sont semblables aux identifiants dans de nombreux langages de programmation, avec deux particularités: 1) tous les identifiants ont des sigles qui distinguent différents types d'identifiants, 2) les identifiants de valeur peuvent être entièrement numériques pour simplifier la génération de programmes StableHLO.
Types
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Les types StableHLO sont classés en types de valeurs (également appelés types de première classe), qui représentent des valeurs StableHLO et des types sans valeur, qui décrivent d'autres éléments du programme. Les types StableHLO sont semblables aux types dans de nombreux langages de programmation, la principale particularité étant la nature spécifique au domaine de StableHLO, qui entraîne des résultats inhabituels (par exemple, les types scalaires ne sont pas des types de valeur).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}
Les types de Tensors représentent des Tensors, c'est-à-dire des tableaux multidimensionnels. Elles ont une forme et un type d'élément, où elles représentent des tailles de dimensions non négatives dans l'ordre croissant des dimensions correspondantes (également appelées axes) numérotées de 0
à R-1
. Le nombre de
dimensions R
est appelé rang. Par exemple, tensor<2x3xf32>
est un type de Tensor avec la forme 2x3
et le type d'élément f32
. Elle comporte deux dimensions (ou, en d'autres termes, deux axes) : la 0e dimension et la 1re dimension, dont les tailles sont 2 et 3. Son classement est de 2.
Cela permet de prendre en charge les formes statiques dans lesquelles les tailles de dimension sont connues de manière statique. À l'avenir, nous prévoyons d'ajouter également la compatibilité avec les formes dynamiques pour lesquelles les tailles de dimension sont partiellement ou entièrement inconnues (#8). De plus, nous prévoyons d'explorer l'extension des types de Tensor au-delà des tailles de dimension et des types d'éléments, par exemple pour inclure les mises en page (#629) et la parcimonie (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nom | Type | Contraintes |
---|---|---|
storage_type |
Type entier | (C1-C4), (C9) |
storage_min |
constante entière | (C2), (C4), (C8) |
storage_max |
constante entière | (C3), (C4), (C8) |
expressed_type |
type à virgule flottante | (C1), (C5) |
quantization_dimension |
constante entière facultative | (C11-C13) |
scales |
nombre varié de constantes à virgule flottante | (C5-C7), (C10), (C11), (C13) |
zero_points |
nombre varié de constantes entières | (C8-C10) |
Les types d'éléments quantifiés représentent des valeurs entières d'un type de stockage comprises entre storage_min
et storage_max
(inclus) qui correspondent à des valeurs à virgule flottante d'un type exprimé. Pour une valeur entière i
donnée, la valeur à virgule flottante correspondante f
peut être calculée comme f = (i - zero_point) * scale
, où scale
et zero_point
sont appelés paramètres de quantification. Les éléments storage_min
et storage_max
sont facultatifs dans la grammaire, mais leurs valeurs par défaut sont respectivement min_value(storage_type)
et max_value(storage_type)
. Les types d'éléments quantifiés présentent les contraintes suivantes:
- (C1)
num_bits(storage_type) < num_bits(expressed_type)
. - (C2)
type(storage_min) = storage_type
. - (C3)
type(storage_max) = storage_type
. - (C4)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C5)
type(scales...) = expressed_type
. - (C6)
0 < scales
. - (C7)
is_finite(scales...)
. - (C8)
storage_min <= zero_points <= storage_max
. - (C9)
type(zero_points...) = storage_type
. - (C10)
size(scales) = size(zero_points)
. - (C11) Si la valeur est
is_empty(quantization_dimension)
, alorssize(scales) = 1
. - (C12)
0 <= quantization_dimension
.
Pour le moment, QuantizationScale
est une constante à virgule flottante, mais il existe un intérêt important pour les échelles basées sur des entiers, représentées par des multiplicateurs et des décalages. Nous prévoyons d'explorer cette option prochainement (#1404).
Une discussion en cours sur la sémantique de QuantizationZeroPoint
, y compris le type et les valeurs, et s'il peut y avoir un seul ou potentiellement plusieurs points zéros dans un type de Tensor quantifié Sur la base des résultats de cette discussion, la spécification autour de zéro point est susceptible d'être modifiée à l'avenir (#1405).
Une autre discussion en cours concerne la sémantique de QuantizationStorageMin
et QuantizationStorageMax
pour déterminer si des contraintes doivent être imposées à ces valeurs et aux valeurs des Tensors quantifiés (#1406).
Enfin, nous prévoyons d'explorer la représentation d'échelles inconnues et de zéros points, de la même manière que nous prévoyons d'explorer la représentation de tailles de dimension inconnues (#1407).
Les types de Tensor quantifiés représentent des Tensors avec des éléments quantifiés. Ces Tensors sont exactement identiques aux Tensors standards, si ce n'est que leurs éléments ont des types d'éléments quantifiés, et non des types d'éléments standards.
Dans les Tensors quantifiés, la quantification peut être par Tensor, c'est-à-dire avoir un scale
et un zero_point
pour l'ensemble du Tensor, ou par axe, c'est-à-dire avoir plusieurs scales
et zero_points
, une paire par tranche d'un quantization_dimension
particulier de dimension. Plus formellement, dans un Tensor t
avec quantification par axe, il existe des tranches dim(t, quantization_dimension)
de quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
, etc. Tous les éléments de la i
e tranche utilisent scales[i]
et zero_points[i]
comme paramètres de quantification. Les types de Tensors quantifiés présentent les contraintes suivantes:
- Pour la quantification par Tensor :
- Aucune contrainte supplémentaire.
- Pour la quantification par axe :
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Les types de jetons représentent des jetons, c'est-à-dire des valeurs opaques produites et consommées par certaines opérations. Les jetons sont utilisés pour imposer l'ordre d'exécution des opérations, comme décrit dans la section Exécution.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Les types de uple représentent des tuples, c'est-à-dire des listes hétérogènes. Tuples est une ancienne fonctionnalité qui n'existe que pour assurer la compatibilité avec HLO. Dans HLO, les tuples sont utilisés pour représenter les entrées et les sorties variables. Dans StableHLO, les entrées et les sorties variables sont prises en charge de manière native. La seule utilisation de tuples dans StableHLO consiste à représenter de manière exhaustive l'ABI HLO, où T
, tuple<T>
et tuple<tuple<T>>
peuvent, par exemple, être sensiblement différents en fonction d'une implémentation particulière. À l'avenir, nous prévoyons de modifier l'ABI HLO pour pouvoir supprimer les types de tuples de StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Les types d'éléments représentent des éléments de types de Tensor. Contrairement à de nombreux langages de programmation, ces types ne constituent pas la première classe de StableHLO. Cela signifie que les programmes StableHLO ne peuvent pas représenter directement des valeurs de ces types (par conséquent, il est idiomatique de représenter des valeurs scalaires de type T
avec des valeurs de Tensor à 0 dimension de type tensor<T>
).
- Le type booléen représente les valeurs booléennes
true
etfalse
. - Les types d'entiers peuvent être signés (
si
) ou non signés (ui
) et avoir l'une des largeurs de bits compatibles (4
,8
,16
,32
ou64
). Les typessiN
signés représentent des valeurs entières comprises entre-2^(N-1)
et2^(N-1)-1
, et les typesuiN
non signés représentent des valeurs entières comprises entre0
et2^N-1
inclus. - Les types à virgule flottante peuvent être :
- Types
f8E4M3FN
etf8E5M2
correspondant respectivement aux encodagesE4M3
etE5M2
du format FP8 décrit dans la section Formats FP8 pour le deep learning. - Types
f8E4M3FNUZ
etf8E5M2FNUZ
correspondant aux encodagesE4M3
etE5M2
des formats FP8 décrits dans la section Formats numériques 8 bits pour les réseaux de neurones profonds. - Type
f8E4M3B11FNUZ
correspondant à l'encodageE4M3
des formats FP8 décrits dans la section Entraînement et inférence d'entraînement et d'inférence à virgule flottante hybride 8 bits (HFP8) pour les réseaux de neurones profonds. - Type
bf16
correspondant au formatbfloat16
décrit dans BFloat16: Le secret pour des performances élevées sur Cloud TPU. - Types
f16
,f32
etf64
correspondant respectivement aux formatsbinary16
("demi-précision"),binary32
("simple précision") etbinary64
("double précision") décrits dans la norme IEEE 754.
- Types
- Les types complexes représentent des valeurs complexes qui ont une partie réelle et une partie imaginaire du même type d'élément. Les types complexes compatibles sont
complex<f32>
(les deux parties sont de typef32
) etcomplex<f64>
(les deux parties sont de typef64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Les types de fonctions représentent à la fois des fonctions nommées et anonymes. Ils ont des types d'entrée (la liste des types se trouvant à gauche dans ->
) et des types de sortie (liste des types à droite de ->
). Dans de nombreux langages de programmation, les types de fonctions sont la première classe, mais pas dans StableHLO.
StringType ::= 'string'
Le type de chaîne représente des séquences d'octets. Contrairement à de nombreux langages de programmation, le type de chaîne n'est pas la première classe de StableHLO. Il n'est utilisé que pour spécifier des métadonnées statiques pour les éléments de programme.
Opérations
Les opérations StableHLO (également appelées opérations) représentent un ensemble fermé d'opérations de haut niveau dans les modèles de machine learning. Comme indiqué ci-dessus, la syntaxe StableHLO s'inspire fortement de MLIR, qui n'est pas nécessairement l'alternative la plus ergonomique, mais elle est sans doute la mieux adaptée à l'objectif de StableHLO, qui est de créer une plus grande interopérabilité entre les frameworks de ML et les compilateurs de ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Les opérations StableHLO (également appelées opérations) ont un nom, des entrées/sorties et une signature. Le nom se compose du préfixe stablehlo.
et d'un mnémonique qui identifie de manière unique l'une des opérations compatibles. Vous trouverez ci-dessous la liste complète de toutes les opérations compatibles.
À l'heure actuelle, les programmes StableHLO à l'état actuel contiennent parfois des opérations qui ne sont pas décrites dans ce document. À l'avenir, nous prévoyons d'absorber ces opérations dans l'opération StableHLO ou de les interdire d'apparaître dans les programmes StableHLO. En attendant, voici la liste de ces opérations:
builtin.module
,func.func
,func.call
etfunc.return
(#425).- Opérations
chlo
(#602). - Catégorie "Pas dans l'ordre HLO" des opérations StableHLO : elles faisaient initialement partie de l'opération StableHLO, mais ont par la suite été considérées comme ne convenant pas bien :
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(n° 3). - Catégorie "Dynamism" des opérations StableHLO : elles ont été lancées à partir du MHLO, mais nous ne les avons pas encore spécifiées :
compute_reshape_shape
,cstr_reshapable
,dynamic_broadcast_in_dim
,dynamic_conv
,dynamic_gather
,dynamic_iota
,dynamic_pad
,dynamic_reshape
,real_dynamic_slice
,set_dimension_size
(n° 8). - Calculs de forme, y compris les opérations
arith
,shape
ettensor
(#8).
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Les opérations consomment des entrées et produisent des sorties. Les entrées sont classées en valeurs d'entrée (calculées lors de l'exécution), fonctions d'entrée (fournies de manière statique, car les fonctions StableHLO ne sont pas des valeurs de première classe) et attributs d'entrée (également fournis de manière statique). Le type d'entrées et de sorties consommées et produites par une opération dépend de son mnémonique. Par exemple, l'opération add
utilise deux valeurs d'entrée et génère une valeur de sortie. En comparaison, l'opération select_and_scatter
utilise 3 valeurs d'entrée, 2 fonctions d'entrée et 3 attributs d'entrée.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Les fonctions d'entrée (également appelées fonctions anonymes) sont très semblables aux fonctions nommées, sauf que: 1) elles n'ont pas d'identifiant (d'où leur nom "anonyme"), 2) elles ne déclarent pas de types de sortie (les types de sortie sont déduits de l'opération return
dans la fonction).
La syntaxe des fonctions d'entrée inclut une partie actuellement inutilisée (voir la production de Unused
ci-dessus), qui assure la compatibilité avec MLIR. Dans MLIR, il existe un concept plus général de "régions" qui peut comporter plusieurs "blocs" d'opérations connectés entre eux via des opérations de saut. Ces blocs ont des identifiants qui correspondent à la production Unused
, de sorte qu'ils peuvent être distingués les uns des autres.
StableHLO n'a pas d'opérations de saut. Par conséquent, la partie correspondante de la syntaxe MLIR n'est pas utilisée (mais est toujours présente).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Les attributs d'entrée ont un nom et une valeur qui est l'une des constantes acceptées. Il s'agit du moyen principal de spécifier des métadonnées statiques pour les éléments d'un programme. Par exemple, l'opération concatenate
utilise l'attribut dimension
pour spécifier la dimension avec laquelle ses valeurs d'entrée sont concaténées. De même, l'opération slice
utilise plusieurs attributs tels que start_indices
et limit_indices
pour spécifier les limites utilisées pour diviser la valeur d'entrée.
À l'heure actuelle, les programmes StableHLO dans la nature contiennent parfois des attributs qui ne sont pas décrits dans ce document. À l'avenir, nous prévoyons d'intégrer ces attributs dans l'opération StableHLO ou d'en interdire l'affichage dans les programmes StableHLO. En attendant, voici la liste de ces attributs:
layout
(#629)mhlo.frontend_attributes
(#628)mhlo.sharding
(#619)output_operand_aliases
(#740)- Métadonnées de localisation (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La signature Op comprend les types de toutes les valeurs d'entrée (la liste des types à gauche de ->
) et les types de toutes les valeurs de sortie (la liste des types à droite de ->
). À proprement parler, les types d'entrée sont redondants et les types de sortie sont presque toujours redondants également (car, pour la plupart des opérations StableHLO, les types de sortie peuvent être déduits à partir des entrées). Néanmoins, la signature d'opération est délibérément intégrée à la syntaxe StableHLO pour assurer la compatibilité avec MLIR.
Vous trouverez ci-dessous un exemple d'opération dont le mnémonique est select_and_scatter
. Elle utilise 3 valeurs d'entrée (%operand
, %source
et %init_value
), 2 fonctions d'entrée et 3 attributs d'entrée (window_dimensions
, window_strides
et padding
). Notez que la signature de l'opération n'inclut que les types de ses valeurs d'entrée (mais pas les types de fonctions d'entrée et d'attributs fournis de façon intégrée).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Les constantes StableHLO ont un littéral et un type, qui représentent ensemble une valeur StableHLO. En règle générale, le type fait partie de la syntaxe constante, sauf lorsqu'il est sans ambiguïté (par exemple, une constante booléenne est sans ambiguïté de type i1
, alors qu'une constante entière peut avoir plusieurs types possibles).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Les constantes booléennes représentent les valeurs booléennes true
et false
. Les constantes booléennes sont de type i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Les constantes entières représentent des valeurs entières via des chaînes utilisant la notation décimale ou hexadécimale. Les autres bases (binaires ou octales, par exemple) ne sont pas compatibles. Les constantes entières présentent les contraintes suivantes:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Les constantes à virgule flottante représentent les valeurs à virgule flottante via des chaînes utilisant la notation décimale ou scientifique. En outre, la notation hexadécimale peut être utilisée pour spécifier directement les bits sous-jacents au format à virgule flottante du type correspondant. Les constantes à virgule flottante sont soumises aux contraintes suivantes:
- (C1) Si vous utilisez une notation non hexadécimale,
is_wellformed(float_literal, float_type)
. - (C2) Si vous utilisez la notation hexadécimale,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Les constantes complexes représentent des valeurs complexes à l'aide de listes composées d'une partie réelle (située en premier) et d'une partie imaginaire (située en deuxième). Par exemple, (1.0, 0.0) : complex<f32>
représente 1.0 + 0.0i
, et (0.0, 1.0) : complex<f32>
représente 0.0 + 1.0i
. L'ordre dans lequel ces parties sont ensuite stockées en mémoire est défini par l'implémentation. Les constantes complexes présentent les contraintes suivantes:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Les constantes Tensor représentent les valeurs de Tensor à l'aide de listes imbriquées spécifiées via la notation NumPy. Par exemple, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
représente une valeur de Tensor avec les mappages suivants des index aux éléments : {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
, {1, 2} => 6
. L'ordre dans lequel ces éléments sont ensuite stockés en mémoire est défini par l'implémentation. Les constantes de Tensor présentent les contraintes suivantes:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, où :has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, où :has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- Sinon,
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Les constantes de Tensor quantifiées représentent les valeurs de Tensor quantifiées en utilisant la même notation que les constantes de Tensor, avec des éléments spécifiés en tant que constantes de leur type de stockage. Les constantes de Tensor quantifiées présentent les contraintes suivantes:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Les littéraux de chaîne sont constitués d'octets spécifiés à l'aide de caractères ASCII et de séquences d'échappement. Ils sont indépendants de l'encodage. L'interprétation de ces octets est donc définie par l'implémentation. Les littéraux de chaîne sont de type string
.
Opérations
abs
Sémantique
Effectue une opération Absolu par élément sur le Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les entiers signés: module d'entiers.
- Pour les flottants:
abs
à partir de la norme IEEE-754. - Pour les nombres complexes: module complexe.
- Pour les types quantifiés:
dequantize_op_quantize(abs, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor d'un entier signé, d'un type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1-C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier signé, à virgule flottante ou d'un Tensor quantifié par Tensor | (C1-C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
est défini comme suit :complex_element_type(element_type(operand))
siis_complex(operand)
.baseline_element_type(operand)
dans les autres cas.
Exemples
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Sémantique
Effectue une addition par élément de deux Tensors lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: opérateur logique OR.
- Pour les nombres entiers: addition de nombres entiers.
- Pour les flottants:
addition
à partir de la norme IEEE-754. - Pour les nombres complexes: addition complexe.
- Pour les types quantifiés:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Sémantique
Permet de s'assurer que les opérations qui génèrent le inputs
sont exécutées avant toute opération qui dépend de result
. L'exécution de cette opération n'a aucun effet. Elle sert uniquement à établir des dépendances de données entre result
et inputs
.
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre variable de token |
Sorties
Nom | Type |
---|---|
result |
token |
Exemples
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, concatène les valeurs du Tensor operand
de chaque processus avec all_gather_dim
et génère un Tensor result
.
L'opération divise la grille de processus StableHLO en process_groups
, défini comme suit:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
operands@receiver = [operand@sender for sender in process_group]
pour tous lesreceiver
deprocess_group
.result@process = concatenate(operands@process, all_gather_dim)
pour tous lesprocess
deprocess_group
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C6) |
(I2) | all_gather_dim |
constante de type si64 |
(C1), (C6) |
(I3). | replica_groups |
Constante de Tensor bidimensionnelle de type si64 |
(C2-C4) |
(I4). | channel_id |
constante de type si64 |
(C5) |
(I5). | use_global_device_ids |
constante de type i1 |
(C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C6) |
Contraintes
- (C1)
0 <= all_gather_dim < rank(operand)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Si la valeur est
use_global_device_ids = true
, alorschannel_id > 0
. - (C6)
type(result) = type(operand)
, sauf :dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1)
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
all_reduce
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, applique une fonction de réduction computation
aux valeurs du Tensor operand
de chaque processus et génère un Tensor result
.
L'opération divise la grille de processus StableHLO en process_groups
, défini comme suit:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
result@process[result_index] = exec(schedule)
pour une arborescence binaireschedule
où :exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
est une arborescence binaire définie par l'implémentation dont le balayage dans l'ordre estto_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0]))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C5), (C6) |
(I2) | replica_groups |
nombre varié de constantes de Tensor unidimensionnelles de type si64 |
(C1-C3) |
(I3). | channel_id |
constante de type si64 |
(C4) |
(I4). | use_global_device_ids |
constante de type i1 |
(C4) |
(I5). | computation |
function | (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C6-C7) |
Contraintes
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Si la valeur est
use_global_device_ids = true
, alorschannel_id > 0
. - (C5)
computation
est de type(tensor<E>, tensor<E>) -> (tensor<E>)
, oùis_promotable(element_type(operand), E)
. - (C6)
shape(result) = shape(operand)
. - (C7)
element_type(result) = E
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]
all_to_all
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, il divise les valeurs du Tensor operand
le long de split_dimension
en plusieurs parties, disperse les parties divisées entre les processus, concatène les parties dispersées le long de concat_dimension
et génère un Tensor result
.
L'opération divise la grille de processus StableHLO en process_groups
, défini comme suit:
cross_replica(replica_groups)
sichannel_id <= 0
.cross_partition(replica_groups)
sichannel_id > 0
.
Ensuite, dans chaque process_group
:
split_parts@sender = split(operand@sender, split_count, split_dimension)
pour tous lessender
dansprocess_group
.scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group]
oùreceiver_index = process_group.index(receiver)
.result@process = concatenate(scattered_parts@process, concat_dimension)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C3), (C9) |
(I2) | split_dimension |
constante de type si64 |
(C1), (C2), (C9) |
(I3). | concat_dimension |
constante de type si64 |
(C3), (C9) |
(I4). | split_count |
constante de type si64 |
(C2), (C4), (C8), (C9) |
(I5). | replica_groups |
Constante de Tensor bidimensionnelle de type si64 |
(C5-C8) |
(I6). | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C9) |
Contraintes
- (C1)
0 <= split_dimension < rank(operand)
. - (C2)
dim(operand, split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operand)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(result) = type(operand)
, sauf :dim(result, split_dimension) = dim(operand, split_dimension) / split_count
.dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
// [5, 6],
// [9, 10],
// [13, 14]]
// %result@(1, 0): [[3, 4],
// [7, 8],
// [11, 12],
// [15, 16]]
et
Sémantique
Effectue l'opérateur ET par élément des deux Tensors lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: AND.
- Pour les entiers: AND au niveau du bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type booléen ou entier | (C1) |
(I2) | rhs |
Tensor de type booléen ou entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Sémantique
Effectue une opération atan2 par élément sur le Tensor lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
atan2
à partir de la norme IEEE-754. - Pour les nombres complexes: atan2 complexe.
- Pour les types quantifiés:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Sémantique
Calcule les gradients de plusieurs entrées de rétropropagation batch_norm_training
à partir de grad_output
, et génère les Tensors grad_operand
, grad_scale
et grad_offset
. Plus formellement, cette opération peut être exprimée sous forme de décomposition d'opérations StableHLO existantes à l'aide de la syntaxe Python comme suit:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Pour les types quantifiés, exécute dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1-C3), (C5) |
(I2) | scale |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C2), (C4), (C5) |
(I3). | mean |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C2), (C4) |
(I4). | variance |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C2), (C4) |
(I5). | grad_output |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C2), (C3) |
(I6). | epsilon |
constante de type f32 |
|
(I7) | feature_index |
constante de type si64 |
(C1), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
grad_operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C2), (C3) |
grad_scale |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C2), (C4) |
grad_offset |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C2), (C4) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
etgrad_offset
ont le mêmebaseline_element_type
. - (C3)
operand
,grad_output
etgrad_operand
ont la même forme. - (C4)
scale
,mean
,variance
,grad_scale
etgrad_offset
ont la même forme. - (C5)
size(scale) = dim(operand, feature_index)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Sémantique
Normalise le Tensor operand
pour toutes les dimensions, à l'exception de la dimension feature_index
, et génère un Tensor result
. Plus formellement, cette opération peut être exprimée sous la forme d'une décomposition d'opérations StableHLO existantes à l'aide de la syntaxe Python comme suit:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Pour les types quantifiés, exécute dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1-C7) |
(I2) | scale |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C2), (C3) |
(I3). | offset |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C2), (C4) |
(I4). | mean |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C5) |
(I5). | variance |
Tensor unidimensionnel de type quantifié à virgule flottante ou par Tensor | (C2), (C6) |
(I6). | epsilon |
constante de type f32 |
|
(I7) | feature_index |
constante de type si64 |
(C1), (C3-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C2), (C7) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
etresult
ont le mêmebaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Sémantique
Calcule la moyenne et la variance pour toutes les dimensions, à l'exception de la dimension feature_index
, et normalise le Tensor operand
qui produit les Tensors output
, batch_mean
et batch_var
. Plus formellement, cette opération peut être exprimée sous la forme d'une décomposition d'opérations StableHLO existantes à l'aide de la syntaxe Python comme suit:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Pour les types quantifiés, exécute dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
(I2) | scale |
Tensor unidimensionnel à virgule flottante ou quantifié par Tensor | (C2), (C3) |
(I3). | offset |
Tensor unidimensionnel à virgule flottante ou quantifié par Tensor | (C2), (C4) |
(I4). | epsilon |
constante de type f32 |
(C1), (C3-C6) |
(I5). | feature_index |
constante de type si64 |
(C1), (C3-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C7) |
batch_mean |
Tensor unidimensionnel à virgule flottante ou quantifié par Tensor | (C2), (C5) |
batch_var |
Tensor unidimensionnel à virgule flottante ou quantifié par Tensor | (C2), (C6) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
etoutput
ont le mêmebaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Sémantique
Effectue une opération de bitcast sur le Tensor operand
et génère un Tensor result
, dans lequel les bits de l'ensemble du Tensor operand
sont réinterprétés à l'aide du type du Tensor result
.
Plus formellement, avec E = element_type(operand)
, E' = element_type(result)
et R = rank(operand)
:
- Si la valeur est
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Si la valeur est
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Si la valeur est
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
renvoie la représentation en mémoire d'une valeur donnée, et son comportement est défini par l'implémentation, car la représentation exacte des Tensors est définie par l'implémentation et la représentation exacte des types d'éléments est également définie par l'implémentation.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1-C2) |
Contraintes
- (C1) Avec
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
etR = rank(operand)
:- Si la valeur est
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Si la valeur est
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
pour tous les0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Si la valeur est
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
pour tous les0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Si la valeur est
- (C2) Si la valeur est
is_complex(operand) or is_complex(result)
, alorsis_complex(operand) and is_complex(result)
.
Exemples
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Sémantique
Développe les dimensions et/ou le rang d'un Tensor d'entrée en dupliquant les données dans le Tensor operand
et génère un Tensor result
. Plus formellement, result[result_index] = operand[operand_index]
où pour tous les d
dans axes(operand)
:
operand_index[d] = 0
sidim(operand, d) = 1
.operand_index[d] = result_index[broadcast_dimensions[d]]
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1), (C3), (C5-C6) |
Contraintes
- (C1)
element_type(result)
est donné par :element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, si ce n'est quequantization_dimension(operand)
,scales(operand)
etzero_points(operand)
peuvent être différents dequantization_dimension(result)
,scales(result)
etzero_points(result)
, sinon.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Pour tous les
d
deaxes(operand)
:dim(operand, d) = 1
oudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si la valeur est
dim(operand, quantization_dimension(operand)) = 1
, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Exemples
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
demande
Sémantique
Génère le résultat de l'exécution d'une seule fonction à partir de branches
en fonction de la valeur de index
. Plus formellement, result = selected_branch()
où:
selected_branch = branches[index]
si0 <= index < size(branches)
.selected_branch = branches[-1]
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | index |
Tensor à 0 dimension de type si32 |
|
(I2) | branches |
nombre variable de fonctions | (C1-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors, de Tensors quantifiés ou de jetons | (C4) |
Contraintes
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Exemples
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
RTL
Sémantique
Effectue une opération de racine cubique par élément sur le Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
rootn(x, 3)
à partir de la norme IEEE-754. - Pour les nombres complexes: racine cubique complexe.
- Pour les types quantifiés:
dequantize_op_quantize(cbrt, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Sémantique
Effectue le ceil par élément du Tensor operand
et génère un Tensor result
.
Met en œuvre l'opération roundToIntegralTowardPositive
à partir de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(ceil, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Sémantique
Calcule la décomposition de cholésique d'un lot de matrices.
Plus formellement, pour tous les i
de index_space(result)
, result[i0, ..., iR-3, :, :]
est une décomposition cholématique de a[i0, ..., iR-3, :, :]
, sous la forme d'une matrice triangulaire inférieure (si lower
correspond à true
) ou d'une matrice triangulaire supérieure (si lower
est false
).
Les valeurs de sortie dans le triangle opposé, c'est-à-dire le triangle supérieur strict ou le triangle inférieur strict, sont définies par l'implémentation.
S'il existe une valeur i
dans laquelle la matrice d'entrée n'est pas une matrice positive positive de l'hermite, le comportement n'est pas défini.
Pour les types quantifiés, exécute dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1-C3) |
(I2) | lower |
Constante de Tensor à 0 dimension de type i1 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Exemples
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
limiter
Sémantique
Applique chaque élément du Tensor operand
entre une valeur minimale et une valeur maximale, et génère un Tensor result
. Plus formellement, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
, où min_element = rank(min) = 0 ? min[] : min[result_index]
et max_element = rank(max) = 0 ? max[] : max[result_index]
. Pour les types quantifiés, effectue dequantize_op_quantize(clamp, min, operand, max, type(result))
.
L'application d'un ordre à des nombres complexes implique une sémantique surprenante. Par conséquent, nous prévoyons de supprimer la prise en charge des nombres complexes pour cette opération (#560).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | min |
Tensor ou Tensor quantifié par Tensor | (C1), (C3) |
(I2) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C4) |
(I3). | max |
Tensor ou Tensor quantifié par Tensor | (C2), (C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C4) |
Contraintes
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, envoyez la valeur du Tensor operand
du processus source aux processus cibles et générez un Tensor result
.
L'opération divise la grille de processus StableHLO en process_groups
, défini comme suit:
cross_replica(replica_groups)
sichannel_id <= 0
.cross_partition(replica_groups)
sichannel_id > 0
.
Ensuite, result@process
est donné par:
operand@process_groups[i, 0]
s'il existe un élémenti
tel que le processus se trouve dansprocess_groups[i]
.broadcast_in_dim(constant(0, element_type(result)), [], type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor | (C3) |
(I2) | replica_groups |
nombre varié de constantes de Tensor unidimensionnelles de type si64 |
(C1), (C2) |
(I3). | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor | (C3) |
Contraintes
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, oùN
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C3)
type(result) = type(operand)
.
Exemples
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, envoie la valeur du Tensor operand
du processus source au processus cible et génère un Tensor result
.
L'opération divise la grille de processus StableHLO en process_groups
, défini comme suit:
cross_replica(source_target_pairs)
sichannel_id <= 0
.cross_partition(source_target_pairs)
sichannel_id > 0
.
Ensuite, result@process
est donné par:
operand@process_groups[i, 0]
, s'il existe uni
tel queprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C5) |
(I2) | source_target_pairs |
Constante de Tensor bidimensionnelle de type si64 |
(C1-C4) |
(I3). | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, oùN
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C5)
type(result) = type(operand)
.
Exemples
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Sémantique
Effectue une comparaison élément par élément des Tensors lhs
et rhs
selon comparison_direction
et compare_type
, et génère un Tensor result
.
Les valeurs de comparison_direction
et compare_type
ont la sémantique suivante:
Pour les types d'éléments booléens et entiers:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Pour les types d'éléments à virgule flottante avec compare_type = FLOAT
, l'opération implémente les opérations IEEE-754 suivantes:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Pour les types d'éléments à virgule flottante avec compare_type = TOTALORDER
, l'opération utilise la combinaison des opérations totalOrder
et compareQuietEqual
de la norme IEEE-754. Cette fonctionnalité semble inutilisée. Nous prévoyons donc de la supprimer à l'avenir (#584).
Pour les types d'éléments complexes, la comparaison lexicographique des paires (real, imag)
est effectuée à l'aide des éléments comparison_direction
et compare_type
fournis.
L'application d'un ordre aux nombres complexes implique une sémantique surprenante. Par conséquent, nous prévoyons de supprimer la prise en charge des nombres complexes lorsque comparison_direction
correspond à GE
, GT
, LE
ou LT
(#560).
Pour les types quantifiés. Exécute dequantize_compare(lhs, rhs,
comparison_direction)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1-C3) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1-C2) |
(I3). | comparison_direction |
énumération de EQ , NE , GE , GT , LE et LT |
|
(I4). | compare_type |
énumération de FLOAT , TOTALORDER , SIGNED et UNSIGNED |
(C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen | (C2) |
Contraintes
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
est défini comme suit :SIGNED
siis_signed_integer(element_type(lhs))
.UNSIGNED
siis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
ouTOTALORDER
siis_float(element_type(lhs))
.FLOAT
siis_complex(element_type(lhs))
.
Exemples
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complexe
Sémantique
Effectue une conversion par élément en valeur complexe à partir d'une paire de valeurs réelles et imaginaires, lhs
et rhs
, et génère un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type f32 ou f64 |
(C1-C3) |
(I2) | rhs |
Tensor de type f32 ou f64 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type complexe | (C2), (C3) |
Contraintes
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
est de typecomplex<E>
, oùE = element_type(lhs)
.
Exemples
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
concatenate
Sémantique
Concatène inputs
selon la dimension dimension
dans le même ordre que les arguments donnés et génère un Tensor result
. Plus formellement, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, où:
id = d0 + ... + dk-1 + kd
.d
est égal àdimension
, etd0
, ... est lad
e taille de dimension deinputs
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C1-C6) |
(I2) | dimension |
constante de type si64 |
(C2), (C4), (C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C5-C6) |
Contraintes
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
, sauf pourdim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
, sauf :dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Exemples
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Sémantique
Génère un Tensor output
à partir d'une constante value
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | value |
constante | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor ou Tensor quantifié | (C1) |
Contraintes
- (C1)
type(value) = type(output)
.
Exemples
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
d'effectuer une conversion
Sémantique
Effectue une conversion par élément d'un type d'élément à un autre sur le Tensor operand
et génère un Tensor result
.
Pour les conversions boolean-to-any-supported-type, la valeur false
est convertie en zéro, et la valeur true
est convertie en un. Pour les conversions de type any-supported-type-to-boolean, une valeur nulle est convertie en false
, et les valeurs non nulles sont converties en true
. Reportez-vous à la section ci-dessous pour en savoir plus sur le fonctionnement des types complexes.
Pour les conversions impliquant un nombre entier à entier, un nombre entier à virgule flottante ou un point flottant à virgule flottante, si la valeur source peut être exactement représentée dans le type de destination, la valeur obtenue sera cette représentation exacte. Sinon, le comportement est à déterminer (#180).
Pour les conversions impliquant une floating-point-to-integer, la partie fractionnaire est tronquée. Si la valeur tronquée ne peut pas être représentée dans le type de destination, le comportement est à déterminer (#180).
Les conversions impliquant complexe à complexe suivent le même comportement que les conversions de point flottant à virgule flottante pour la conversion de pièces réelles et imaginaires.
Pour les conversions de type complex-to-any-other-type et complex-to-any-other-type, la valeur imaginaire source est ignorée ou la valeur imaginaire de la destination est remise à zéro, respectivement. La conversion de la partie réelle suit les conversions à virgule flottante.
En principe, cette opération pourrait exprimer la déquantification (conversion de Tensors quantifiés en Tensors réguliers), la quantification (conversion de Tensors standards en Tensors quantifiés) et la requantification (conversion entre des Tensors quantifiés). Toutefois, pour le moment, nous disposons d'opérations dédiées (uniform_dequantize
pour le premier cas d'utilisation et uniform_quantize
pour le deuxième et le troisième). À l'avenir, ces deux opérations pourront être fusionnées dans convert
(#1576).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor | (C1) |
Contraintes
- (C1)
shape(operand) = shape(result)
.
Exemples
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolution
Sémantique
Calcule les produits scalaires entre les fenêtres de lhs
et les tranches de rhs
, et génère la valeur result
. Le schéma suivant montre comment les éléments de result
sont calculés à partir de lhs
et rhs
à l'aide d'un exemple concret.
Plus formellement, considérons le recadrage suivant des entrées en termes de lhs
afin de pouvoir exprimer des fenêtres de lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Ce recadrage utilise les fonctions d'assistance suivantes:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
, oùj[d] = i[permutation[d]]
.
Si feature_group_count = 1
et batch_group_count = 1
, alors pour tous les output_spatial_index
dans index_space(dim(result, output_spatial_dimensions...))
, result[result_shape(:, output_spatial_index, :)] = dot_product
où:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Cette fonctionnalité semble inutilisée. Nous prévoyons donc de la supprimer à l'avenir (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Si la valeur est feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Si la valeur est batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Pour les types quantifiés, exécute dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1), (C10-C11), (C14) (C25), (C27-C30) |
(I2) | rhs |
Tensor ou Tensor quantifié | (C1), (C14-C16), (C25), (C27-C32) |
(I3). | window_strides |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C3), (C25) |
(I4). | padding |
Constante de Tensor bidimensionnelle de type si64 |
(C4), (C25) |
(I5). | lhs_dilation |
Constante de Tensor unidimensionnelle de type si64 |
(C5-C6), (C25) |
(I6). | rhs_dilation |
Constante de Tensor unidimensionnelle de type si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Constante de Tensor unidimensionnelle de type i1 |
(C9) |
(I8). | input_batch_dimension |
constante de type si64 |
(C10), (C13), (C25) |
(I9). | input_feature_dimension |
constante de type si64 |
(C11), (C13-C14) |
(I10). | input_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C12), (C13), (C25) |
(I11). | kernel_input_feature_dimension |
constante de type si64 |
(C14), (C18) |
(I12). | kernel_output_feature_dimension |
constante de type si64 |
(C15-C16), (C18), (C25), (C32) |
(I13). | kernel_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C17-C18), (C25) |
(I14). | output_batch_dimension |
constante de type si64 |
(C20), (C25) |
(I15). | output_feature_dimension |
constante de type si64 |
(C20), (C25), (C33) |
(I16). | output_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C19-C20), (C25) |
(I17). | feature_group_count |
constante de type si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18). | batch_group_count |
constante de type si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
nombre variable d'énumérations DEFAULT , HIGH et HIGHEST |
(C24) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C25-C28), (C30-C31), (C33) |
Contraintes
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Compte
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
donné :is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Compte
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
donné :is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Compte
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
donné :is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
est défini comme suit :dim(lhs, input_batch_dimension) / batch_group_count
siresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
siresult_dim = output_feature_dimension
.num_windows
dans les autres cas, où:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si l'opération utilise des Tensors non quantifiés :
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si l'opération utilise des Tensors quantifiés :
- (C28)
is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result)
. - (C29)
storage_type(lhs) = storage_type(rhs)
. - (C30)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C31) Si la valeur est
is_per_tensor_quantized(rhs)
, alorsis_per_tensor_quantized(result)
. - (C32) Si la valeur est
is_per_axis_quantized(rhs)
, alorsquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C33) Si la valeur est
is_per_axis_quantized(result)
, alorsquantization_dimension(result) = output_feature_dimension
.
- (C28)
Exemples
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs : [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = dense<4> : tensor<2xi64>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = dense<2> : tensor<2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_reversal = dense<false> : tensor<2xi1>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
cosinus
Sémantique
Effectue une opération de cosinus par élément sur le Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
cos
à partir de la norme IEEE-754. - Pour les nombres complexes: cosinus complexe.
- Pour les types quantifiés:
dequantize_op_quantize(cosine, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Sémantique
Effectue un comptage par élément du nombre de bits zéros au début dans le Tensor operand
et génère un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Sémantique
Encapsule une opération définie par l'implémentation call_target_name
qui reçoit inputs
et called_computations
, et génère results
. Vous pouvez utiliser has_side_effect
, backend_config
et api_version
pour fournir des métadonnées supplémentaires définies par l'implémentation.
À l'heure actuelle, cette opération contient une collection de métadonnées relativement désorganisée qui reflète l'évolution organique de son équivalent dans le compilateur XLA. Nous prévoyons d'unifier ces métadonnées à l'avenir (#741).
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre varié de valeurs |
(I2) | call_target_name |
constante de type string |
(I3). | has_side_effect |
constante de type i1 |
(I4). | backend_config |
constante de type string |
(I5). | api_version |
constante de type si32 |
(I6). | called_computations |
nombre varié de constantes de type string |
Sorties
Nom | Type |
---|---|
results |
nombre varié de valeurs |
Exemples
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
diviser
Sémantique
Effectue une division par élément des Tensors de dividende lhs
et des Tensors rhs
du diviseur, et produit un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les entiers: division d'entiers qui produit le quotient algébrique avec toute partie fractionnaire rejetée.
- Pour les flottants:
division
à partir de la norme IEEE-754. - Pour les nombres complexes: division complexe.
- Pour les types quantifiés :
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Sémantique
Calcule les produits scalaires entre les tranches de lhs
et les tranches de rhs
, et génère un Tensor result
.
Plus formellement, result[result_index] = dot_product
, où:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
, oùsize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
etsize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Pour les types quantifiés, exécute dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Cette valeur spécifie uniquement la sémantique pour la quantification par Tensor. La quantification par axe est en cours (#1574). Nous pourrions également envisager d'ajouter la compatibilité avec la quantification hybride à l'avenir (#1575).
precision_config
contrôle le compromis entre vitesse et précision des calculs sur les backends d'accélérateur. Il peut s'agir de l'un des éléments suivants (à l'heure actuelle, la sémantique de ces valeurs d'énumération est sous-spécifiée, mais nous prévoyons de résoudre ce problème à la n° 755):
DEFAULT
: calcul le plus rapide, mais approximation la moins précise du nombre d'origine.HIGH
: calcul plus lent, mais approximation plus précise du nombre d'origine.HIGHEST
: calcul le plus lent, mais approximation la plus précise par rapport au nombre d'origine.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C5-C6), (C9-C10), (C12-C16) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C7-C10), (C12) |
(I3). | lhs_batching_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4). | rhs_batching_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4), (C7), (C9) |
(I5). | lhs_contracting_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3), (C6), (C10) |
(I6). | rhs_contracting_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C8), (C10) |
(I7) | precision_config |
nombre variable d'énumérations DEFAULT , HIGH et HIGHEST |
(C11) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C12), (C14), (C16) |
Contraintes
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Si l'opération utilise des Tensors non quantifiés :
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Si l'opération utilise des Tensors quantifiés :
- (C14)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C15)
storage_type(lhs) = storage_type(rhs)
. - (C16)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C17)
zero_points(rhs) = 0
.
- (C14)
Exemples
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_slice
Sémantique
Extrait une tranche de operand
à l'aide d'indices de départ calculés de manière dynamique et génère un Tensor result
. start_indices
contient les index de départ de la tranche pour chaque dimension susceptible d'être ajustée, et slice_sizes
contient les tailles de la tranche pour chaque dimension. Plus formellement, result[result_index] = operand[operand_index]
où:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C2), (C4) |
(I2) | start_indices |
nombre varié de Tensors à 0 dimension de type entier | (C2), (C3) |
(I3). | slice_sizes |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1), (C5) |
Contraintes
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Exemples
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Sémantique
Génère un Tensor result
égal au Tensor operand
, à la différence que la tranche commençant à start_indices
est mise à jour avec les valeurs de update
.
Plus formellement, result[result_index]
est défini comme suit:
update[update_index]
si0 <= update_index < shape(update)
, où :adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
operand[result_index]
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C4), (C6) |
(I2) | update |
Tensor ou Tensor quantifié par Tensor | (C2), (C3), (C6) |
(I3). | start_indices |
nombre varié de Tensors à 0 dimension de type entier | (C4), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Exemples
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponentiel
Sémantique
Effectue une opération exponentielle par élément sur le Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
exp
à partir de la norme IEEE-754. - Pour les nombres complexes: exponentielle complexe.
- Pour les types quantifiés :
dequantize_op_quantize(exponential, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Sémantique
Effectue une opération exponentielle moins 1 par élément sur le Tensor operand
et produit un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
expm1
à partir de la norme IEEE-754. - Pour les nombres complexes: puissance exponentielle complexe moins un.
- Pour les types quantifiés :
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
FFT
Sémantique
Effectue les transformations de Fourier avant et inverse pour des entrées/sorties réelles et complexes.
fft_type
est l'un des éléments suivants :
FFT
: transfère FFT complexe à complexe.IFFT
: FFT inverse complexe à complexe.RFFT
: transfère FFT réel vers complexe.IRFFT
: FFT inverse de vrai à complexe (l'opération prend des valeurs complexes et renvoie des valeurs réelles).
Plus formellement, avec la fonction fft
, qui prend en entrée des Tensors unidimensionnels de types complexes, elle produit des Tensors unidimensionnels des mêmes types qu'en sortie et calcule la transformation de Fourier discrète:
Pour fft_type = FFT
, result
est défini comme le résultat final d'une série de calculs L, où L = size(fft_length)
. Par exemple, pour L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
De plus, compte tenu de la fonction ifft
, qui a le même type de signature et calcule l'inverse de fft
:
Pour fft_type = IFFT
, result
est défini comme l'inverse des calculs de fft_type = FFT
. Par exemple, pour L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
En outre, grâce à la fonction rfft
, qui prend les Tensors unidimensionnels de types à virgule flottante, elle produit des Tensors unidimensionnels de types complexes de la même sémantique à virgule flottante, et fonctionne comme suit:
rfft(real_operand) = truncated_result
oùcomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Lorsque la transformation de Fourier discrète est calculée pour des opérandes réels, les premiers éléments N/2 + 1
du résultat définissent sans ambiguïté le reste du résultat. Le résultat de rfft
est donc tronqué pour éviter de calculer des éléments redondants.)
Pour fft_type = RFFT
, result
est défini comme le résultat final d'une série de calculs L, où L = size(fft_length)
. Par exemple, pour L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Enfin, avec la fonction irfft
, qui a le même type de signature et calcule l'inverse de rfft
:
Pour fft_type = IRFFT
, result
est défini comme l'inverse des calculs de fft_type = RFFT
. Par exemple, pour L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
énumération de FFT , IFFT , RFFT et IRFFT |
(C2), (C5) |
(I3). | fft_length |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C3), (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe | (C2), (C4), (C5) |
Contraintes
- (C1)
size(fft_length) <= rank(operand)
. - (C2) La relation entre les types d'éléments
operand
etresult
varie :- Si
fft_type = FFT
,element_type(operand)
etelement_type(result)
ont le même type complexe. - Si
fft_type = IFFT
,element_type(operand)
etelement_type(result)
ont le même type complexe. - Si la valeur est
fft_type = RFFT
,element_type(operand)
est un type à virgule flottante etelement_type(result)
est un type complexe de la même sémantique à virgule flottante. - Si la valeur est
fft_type = IRFFT
,element_type(operand)
est un type complexe etelement_type(result)
est un type à virgule flottante de la même sémantique à virgule flottante.
- Si
- (C3)
1 <= size(fft_length) <= 3
. - (C4) S'il existe entre
operand
etresult
, il existe un Tensorreal
de type à virgule flottante, alorsshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
, sauf :- Si la valeur est
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Si la valeur est
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Si la valeur est
Exemples
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
floor
Sémantique
Effectue le plancher élément par élément du Tensor operand
et génère un Tensor result
.
Met en œuvre l'opération roundToIntegralTowardNegative
à partir de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(floor, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
rassembler
Sémantique
Recueille les tranches du Tensor operand
à partir des décalages spécifiés dans start_indices
et génère un Tensor result
.
Le schéma suivant montre comment les éléments de result
sont mappés sur des éléments de operand
à l'aide d'un exemple concret. Le schéma choisit quelques exemples d'indices result
et explique en détail à quels index operand
ils correspondent.
Plus formellement, result[result_index] = operand[operand_index]
où:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
est défini comme suit :start_indices[bi0, ..., :, ..., biN]
, oùbi
sont des éléments individuels dansbatch_index
, et:
est inséré au niveau de l'indexindex_vector_dim
, siindex_vector_dim
<rank(start_indices)
.[start_indices[batch_index]]
dans les autres cas.
- Pour
d_operand
dansaxes(operand)
:full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
sid_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
dans les autres cas.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, oùoi
sont des éléments individuels dansoffset_index
, et0
est inséré aux index decollapsed_slice_dims
.operand_index = full_start_index + full_offset_index
.
Si indices_are_sorted
est défini sur true
, l'implémentation peut supposer que les start_indices
sont triés par rapport à start_index_map
. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2
de indices(result)
: full_start_index(i1) <= full_start_index(i2)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
Tensor de type entier | (C2), (C3), (C13) |
(I3). | offset_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4-C5), (C13) |
(I4). | collapsed_slice_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C6-C8), (C13) |
(I5). | start_index_map |
Constante de Tensor unidimensionnelle de type si64 |
(C3), (C9), (C10) |
(I6). | index_vector_dim |
constante de type si64 |
(C2), (C3), (C13) |
(I7) | slice_sizes |
Constante de Tensor unidimensionnelle de type si64 |
(C8), (C11-C13) |
(I8). | indices_are_sorted |
constante de type i1 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C5), (C13-C14) |
Contraintes
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
où :batch_dim_sizes = shape(start_indices)
, si ce n'est que la taille de la dimension destart_indices
correspondant àindex_vector_dim
n'est pas incluse.offset_dim_sizes = shape(slice_sizes)
, si ce n'est que les dimensions des dimensions dansslice_sizes
correspondant àcollapsed_slice_dims
ne sont pas incluses.combine
placebatch_dim_sizes
sur les axes correspondant àbatch_dims
etoffset_dim_sizes
aux axes correspondant àoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Exemples
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
get_dimension_size
Sémantique
Génère la taille de l'élément dimension
donné pour operand
. Plus formellement, result = dim(operand, dimension)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor | (C1) |
(I2) | dimension |
constante de type si64 |
(C1) |
Sorties
Nom | Type |
---|---|
result |
Tensor à 0 dimension de type si32 |
Contraintes
- (C1)
0 <= dimension < rank(operand)
.
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Sémantique
Extrait l'élément à la position index
du tuple operand
et génère une result
. Plus formellement : result = operand[index]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tuple | (C1), (C2) |
(I2) | index |
constante de type si32 |
(C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tous les types compatibles | (C2) |
Contraintes
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Exemples
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
if
Sémantique
Génère le résultat de l'exécution d'une seule fonction à partir de true_branch
ou false_branch
en fonction de la valeur de pred
. Plus formellement : result =
pred ? true_branch() : false_branch()
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | pred |
Tensor à 0 dimension de type i1 |
|
(I2) | true_branch |
function | (C1-C3) |
(I3). | false_branch |
function | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors, de Tensors quantifiés ou de jetons | (C3) |
Contraintes
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Exemples
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
Imageg
Sémantique
Extrait la partie imaginaire, par élément, de la operand
et génère un Tensor result
. Plus formellement, pour chaque élément x
: imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
est défini comme suit :complex_element_type(element_type(operand))
siis_complex(operand)
.element_type(operand)
dans les autres cas.
Exemples
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
flux d'entrée
Sémantique
Lit les données du flux d'entrée et génère results
.
La sémantique de infeed_config
est définie par l'implémentation.
results
sont constitués de valeurs de charge utile qui apparaissent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour améliorer la clarté (#670).
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
constante de type string |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors, de Tensors quantifiés ou de jetons | (C1-C3) |
Contraintes
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ouis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Exemples
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Ito
Sémantique
Remplit un Tensor output
avec des valeurs dans l'ordre croissant à partir de zéro le long de la dimension iota_dimension
. Plus formellement,
output[result_index] = constant(is_quantized(output) ?
quantize(result_index[iota_dimension], element_type(output)) :
result_index[iota_dimension], element_type(output))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
0 <= iota_dimension < rank(output)
.
Exemples
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Sémantique
Vérifie au niveau des éléments si la valeur de x
est finie (c'est-à-dire si elle n'est ni +Inf, -Inf, ni NaN) et génère un Tensor y
. Met en œuvre l'opération isFinite
à partir de la spécification IEEE-754. Pour les types quantifiés, le résultat est toujours true
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | x |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
y |
Tensor de type booléen | (C1) |
Contraintes
- (C1)
shape(x) = shape(y)
.
Exemples
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Sémantique
Effectue une opération de logarithme par élément sur le Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
log
à partir de la norme IEEE-754. - Pour les nombres complexes: logarithme complexe.
- Pour les types quantifiés:
dequantize_op_quantize(log, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Sémantique
Effectue un logarithme par élément plus une opération sur le Tensor operand
et produit un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
logp1
à partir de la norme IEEE-754. - Pour les nombres complexes: logarithme complexe plus un.
- Pour les types quantifiés :
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistique
Sémantique
Effectue une opération logistique au niveau des éléments sur le Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
division(1, addition(1, exp(-x)))
à partir de la norme IEEE-754. - Pour les nombres complexes: logistique complexe.
- Pour les types quantifiés :
dequantize_op_quantize(logistic, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
carte
Sémantique
Applique une fonction de mappage computation
à inputs
le long de dimensions
et génère un Tensor result
.
Plus formellement : result[result_index] = computation(inputs...[result_index])
.
Notez que les dimensions
ne sont pas utilisés actuellement et seront probablement supprimés à l'avenir (#487).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C1-C4) |
(I2) | dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C3) |
(I3). | computation |
function | (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1), (C4) |
Contraintes
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
est de type(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, oùEi = element_type(inputs[i])
etE' = element_type(result)
.
Exemples
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maximum
Sémantique
Effectue une opération maximale par élément sur les Tensors lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: opérateur logique OR.
- Pour les nombres entiers: nombre entier maximum.
- Pour les flottants:
maximum
à partir de la norme IEEE-754. - Pour les nombres complexes: valeur lexicographique maximale pour la paire
(real, imaginary)
. L'application d'un ordre à des nombres complexes implique une sémantique surprenante. Par conséquent, nous prévoyons de supprimer la prise en charge des nombres complexes pour cette opération (#560). - Pour les types quantifiés :
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Sémantique
Effectue une opération minimale par élément sur les Tensors lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: AND.
- Pour les nombres entiers: nombre entier minimal.
- Pour les flottants:
minimum
à partir de la norme IEEE-754. - Pour les nombres complexes: valeur lexicographique minimale pour la paire
(real, imaginary)
. L'application d'un ordre à des nombres complexes implique une sémantique surprenante. Par conséquent, nous prévoyons de supprimer la prise en charge des nombres complexes pour cette opération (#560). - Pour les types quantifiés :
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
multiplier
Sémantique
Effectue un produit par élément de deux Tensors lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: AND.
- Pour les nombres entiers: multiplication de nombres entiers.
- Pour les flottants:
multiplication
à partir de la norme IEEE-754. - Pour les nombres complexes: multiplication complexe.
- Pour les types quantifiés :
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
Sémantique
Effectue une négation élément par élément du Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les entiers signés: négation des entiers.
- Pour les entiers non signés: bitcast en entier signé, négation entière, retransmission de bit en entier non signé.
- Pour les flottants:
negate
à partir de la norme IEEE-754. - Pour les nombres complexes: négation complexe.
- Pour les types quantifiés :
dequantize_op_quantize(negate, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
(ne vivant pas
Sémantique
Effectue l'opérateur NOT au niveau des éléments du Tensor operand
et génère un Tensor result
.
Selon le type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: logique NOT.
- Pour les nombres entiers: NOT au niveau du bit.
Arguments
Nom | Type | Contraintes |
---|---|---|
operand |
Tensor de type booléen ou entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Sémantique
S'assure que les opérations qui génèrent la operand
sont exécutées avant toute opération qui dépend de la result
et empêche les transformations du compilateur de déplacer les opérations au-delà de la barrière. En dehors de cela, l'opération est une identité, par exemple result = operand
.
Arguments
Nom | Type | Contraintes |
---|---|---|
operand |
nombre varié de Tensors, Tensors quantifiés par Tensor ou jetons | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
nombre varié de Tensors, Tensors quantifiés par Tensor ou jetons | (C1) |
Contraintes
- (C1)
type(operand...) = type(result...)
.
Exemples
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
ou
Sémantique
Effectue l'opérateur OR par élément des deux Tensors lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: opérateur logique OR.
- Pour les entiers: OR au niveau du bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier ou booléen | (C1) |
(I2) | rhs |
Tensor de type entier ou booléen | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier ou booléen | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
flux de sortie
Sémantique
Écrit inputs
dans le flux de sortie et génère un jeton result
.
La sémantique de outfeed_config
est définie par l'implémentation.
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre varié de Tensors ou de Tensors quantifiés |
(I2) | token |
token |
(I3). | outfeed_config |
constante de type string |
Sorties
Nom | Type |
---|---|
result |
token |
Exemples
%result = "stablehlo.outfeed"(%inputs0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
pad
Sémantique
Développe operand
par une marge intérieure autour du Tensor, ainsi qu'entre les éléments du Tensor avec la padding_value
donnée.
edge_padding_low
et edge_padding_high
spécifient la quantité de marge intérieure ajoutée dans la limite inférieure (à côté de l'index 0) et dans la limite supérieure (à côté de l'indice le plus élevé) de chaque dimension. La valeur de la marge intérieure négative peut être négative, la valeur absolue indiquant le nombre d'éléments à supprimer de la dimension spécifiée.
interior_padding
spécifie la quantité de marge intérieure ajoutée entre deux éléments de chaque dimension, qui ne peut pas être négative. La marge intérieure intérieure se produit avant le remplissage du bord, de sorte que la marge intérieure négative supprime les éléments de l'opérande de remplissage intérieur.
Plus formellement, result[result_index]
est défini comme suit:
operand[operand_index]
siresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.padding_value
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor à 0 dimension ou Tensor quantifié par Tensor | (C1) |
(I3). | edge_padding_low |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4) |
(I4). | edge_padding_high |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4) |
(I5). | interior_padding |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C3-C6) |
Contraintes
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Exemples
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Sémantique
Génère partition_id
du processus actuel.
Sorties
Nom | Type |
---|---|
result |
Tensor à 0 dimension de type ui32 |
Exemples
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Sémantique
Effectue un comptage par élément du nombre de bits définis dans le Tensor operand
et génère un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
puissance
Sémantique
Effectue une exponentielle par élément du Tensor lhs
par le Tensor rhs
et produit un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les entiers: exponentielle entière.
- Pour les flottants:
pow
à partir de la norme IEEE-754. - Pour les nombres complexes: exponentielle complexe.
- Pour les types quantifiés:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Sémantique
Extrait la partie réelle, élément par élément, de operand
et génère un Tensor result
. Plus formellement, pour chaque élément x
: real(x) = is_complex(x) ? real_part(x) : x
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
est défini comme suit :complex_element_type(element_type(operand))
siis_complex(operand)
.element_type(operand)
dans les autres cas.
Exemples
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Sémantique
Reçoit les données d'un canal avec channel_id
et génère results
.
Si is_host_transfer
est défini sur true
, l'opération transfère les données de l'hôte. Sinon, les données seront transférées depuis un autre appareil. Cela signifie que l'implémentation
est définie. Cet indicateur duplique les informations fournies dans channel_type
. Nous prévoyons donc de n'en conserver qu'un seul à l'avenir (#666).
results
sont constitués de valeurs de charge utile qui apparaissent en premier et d'un jeton qui vient en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux sorties distinctes pour améliorer la clarté (#670).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
constante de type si64 |
|
(I3). | channel_type |
énumération de DEVICE_TO_DEVICE et HOST_TO_DEVICE |
(C1) |
(I4). | is_host_transfer |
constante de type i1 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors, de Tensors quantifiés ou de jetons | (C2-C4) |
Contraintes
- (C1)
channel_type
est défini comme suit :HOST_TO_DEVICE
siis_host_transfer = true
,DEVICE_TO_DEVICE
dans les autres cas.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ouis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Exemples
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Sémantique
Applique une fonction de réduction body
à inputs
et init_values
le long de dimensions
et génère des Tensors results
.
L'ordre des réductions est défini par l'implémentation, ce qui signifie que body
et init_values
doivent former un monoid pour garantir que l'opération produit les mêmes résultats pour toutes les entrées de toutes les implémentations. Cependant, cette condition ne s'applique pas à de nombreuses réductions courantes. Par exemple, l'addition à virgule flottante pour body
et le zéro pour init_values
ne forment pas réellement un monoid, car l'addition à virgule flottante n'est pas associative.
Plus formellement, results...[j0, ..., jR-1] = reduce(input_slices_converted)
où:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, où les:
sont insérés au niveau dedimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
pour une arborescence binaireschedule
où :exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
est une arborescence binaire complète définie par l'implémentation dont le balayage dans l'ordre comprend les éléments suivants :- Valeurs
input_slices_converted...[index]
, pour tous lesindex
deindex_space(input_slices_converted)
dans l'ordre lexicographique croissant deindex
. - Il est intercalé avec un nombre défini de
init_values_converted
à des positions définies par l'implémentation.
- Valeurs
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C1-C4), (C6), (C7) |
(I2) | init_values |
nombre varié de Tensors à 0 dimension ou de Tensors quantifiés par Tensor | (C2), (C3) |
(I3). | dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C4), (C5), (C7) |
(I4). | body |
function | (C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C3), (C7), (C8) |
Contraintes
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, oùis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, si ce n'est que les tailles de dimensions deinputs...
correspondant àdimensions
ne sont pas incluses. - (C8)
element_type(results[i]) = Ei
pour tous lesi
dans[0,N)
.
Exemples
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Sémantique
Effectue la conversion de operand
au niveau de l'élément en un autre type à virgule flottante qui utilise exponent_bits
et mantissa_bits
, puis revient au type à virgule flottante d'origine et génère un Tensor output
.
Plus formellement:
- Les bits de mantisse de la valeur d'origine sont mis à jour pour arrondir la valeur d'origine à la valeur représentable la plus proche avec
mantissa_bits
à l'aide de la sémantiqueroundToIntegralTiesToEven
. - Ensuite, si la valeur de
mantissa_bits
est inférieure au nombre de bits de mantisse de la valeur d'origine, les bits de mantisse sont tronqués àmantissa_bits
. - Ensuite, si les bits d'exposant du résultat intermédiaire ne correspondent pas à la plage fournie par
exponent_bits
, le résultat intermédiaire déborde à l'infini avec le signe d'origine ou est insuffisant à zéro à l'aide du signe d'origine. - Pour les types quantifiés, exécute
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
(I2) | exponent_bits |
constante de type si32 |
(C2) |
(I3). | mantissa_bits |
constante de type si32 |
(C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Exemples
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, effectue une réduction sur les valeurs du Tensor operand
de chaque processus à l'aide de computations
, divise le résultat de la réduction en plusieurs parties avec scatter_dimension
, puis disperse les parties divisées entre les processus pour produire result
.
L'opération divise la grille de processus StableHLO en process_groups
, défini comme suit:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
pour tous lessender
dansprocess_group
, oùreceiver_index = process_group.index(receiver)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
constante de type si64 |
(C1), (C2), (C8) |
(I3). | replica_groups |
Constante de Tensor bidimensionnelle de type si64 |
(C3-C5) |
(I4). | channel_id |
constante de type si64 |
(C6) |
(I5). | use_global_device_ids |
constante de type i1 |
(C6) |
(I6). | computation |
function | (C7) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C8-C9) |
Contraintes
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
est défini comme suit :num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Si la valeur est
use_global_device_ids = true
, alorschannel_id > 0
. - (C7)
computation
est de type(tensor<E>, tensor<E>) -> (tensor<E>)
, oùis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
, sauf :dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Sémantique
Applique une fonction de réduction body
aux fenêtres de inputs
et init_values
, et génère results
.
Le schéma suivant montre comment les éléments de results...
sont calculés à partir de inputs...
à l'aide d'un exemple concret.
Plus formellement, results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(voir réduire) où:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
nombre varié de Tensors à 0 dimension ou de Tensors quantifiés par Tensor | (C1), (C13) |
(I3). | window_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C4), (C5), (C15) |
(I4). | window_strides |
Constante de Tensor unidimensionnelle de type si64 |
(C6), (C7), (C15) |
(I5). | base_dilations |
Constante de Tensor unidimensionnelle de type si64 |
(C8), (C9), (C15) |
(I6). | window_dilations |
Constante de Tensor unidimensionnelle de type si64 |
(C10), (C11), (C15) |
(I7) | padding |
Constante de Tensor bidimensionnelle de type si64 |
(C12), (C15) |
(I8). | body |
function | (C13) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C1), (C14-C16) |
Contraintes
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, oùis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
où :dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
pour tous lesi
dans[0,N)
.
Exemples
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<[4, 1]> : tensor<2xi64>,
base_dilations = dense<[2, 1]> : tensor<2xi64>,
window_dilations = dense<[3, 1]> : tensor<2xi64>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
reste
Sémantique
Effectue le reste par élément des Tensors de dividende lhs
et rhs
du diviseur, et produit un Tensor result
.
Plus formellement, le signe du résultat est tiré du dividende, et la valeur absolue du résultat est toujours inférieure à la valeur absolue du diviseur.
Le reste est calculé comme suit : lhs - d * rhs
, où d
est calculé comme suit :
- Pour les entiers:
stablehlo.divide(lhs, rhs)
. - Pour les valeurs à virgule flottante:
division(lhs, rhs)
à partir de la norme IEEE-754 avec l'attribut d'arrondiroundTowardZero
. - Pour les nombres complexes: à déterminer (#997).
- Pour les types quantifiés :
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Pour les types d'éléments à virgule flottante, cette opération diffère de l'opération remainder
de la spécification IEEE-754, où d
est une valeur intégrale la plus proche de la valeur exacte de lhs/rhs
avec un lien égal.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor d'entiers, de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Sémantique
Génère replica_id
du processus actuel.
Sorties
Nom | Type |
---|---|
result |
Tensor à 0 dimension de type ui32 |
Exemples
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
remodeler
Sémantique
Effectue le remodelage du Tensor operand
en un Tensor result
. Conceptuellement, cela revient à conserver la même représentation canonique, mais en modifiant éventuellement la forme, par exemple de tensor<2x3xf32>
à tensor<3x2xf32>
ou tensor<6xf32>
.
Plus formellement, result[result_index] = operand[operand_index]
, où result_index
et operand_index
ont la même position dans l'ordre lexicographique de index_space(result)
et index_space(operand)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1-C3) |
Contraintes
- (C1)
element_type(result)
est donné par :element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, si ce n'est quequantization_dimension(operand)
etquantization_dimension(result)
peuvent différer.
- (C2)
size(operand) = size(result)
. - (C3) Si
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
reverse
Sémantique
Inverse l'ordre des éléments dans operand
selon le dimensions
spécifié et génère un Tensor result
. Plus formellement, result[result_index] = operand[operand_index]
où:
operand_index[d] = dim(result, d) - result_index[d] - 1
sid
dansdimensions
.operand_index[d] = result_index[d]
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C3) |
(I2) | dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1), (C3) |
Contraintes
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Exemples
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Sémantique
Génère des nombres aléatoires à l'aide de l'algorithme rng_distribution
et génère un Tensor result
d'une forme donnée shape
.
Si la valeur est rng_distribution = UNIFORM
, les nombres aléatoires sont générés en suivant la distribution uniforme sur l'intervalle [a, b)
. Si la valeur est a >= b
, le comportement n'est pas défini.
Si la valeur est rng_distribution = NORMAL
, les nombres aléatoires sont générés en suivant la distribution normale avec une moyenne de a
et un écart-type de b
.
Si la valeur est b < 0
, le comportement n'est pas défini.
La manière exacte dont les nombres aléatoires sont générés est définie par l'implémentation. Par exemple, elles peuvent être déterministes ou non, et utiliser ou non un état caché.
Lors de conversations avec de nombreuses personnes concernées, cette opération est devenue obsolète. Nous prévoyons donc d'envisager de la supprimer à l'avenir (#597).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor à 0 dimension de type entier, booléen ou à virgule flottante | (C1), (C2) |
(I2) | b |
Tensor à 0 dimension de type entier, booléen ou à virgule flottante | (C1), (C2) |
(I3). | shape |
Constante de Tensor unidimensionnelle de type si64 |
(C3) |
(I4). | rng_distribution |
énumération de UNIFORM et NORMAL |
(C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, booléen ou à virgule flottante | (C1-C3) |
Contraintes
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Si la valeur est
rng_distribution = NORMAL
, alorsis_float(a)
. - (C3)
shape(result) = shape
.
Exemples
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Sémantique
Renvoie une output
remplie de bits aléatoires uniformes et un état de sortie mis à jour output_state
à l'aide de l'algorithme de générateur de nombres pseudo-aléatoires rng_algorithm
en fonction d'un état initial initial_state
. Le résultat sera une fonction déterministe de initial_state
, mais il n'est pas garanti qu'il soit déterministe entre les implémentations.
rng_algorithm
est l'un des éléments suivants :
DEFAULT
: algorithme défini par l'implémentation.THREE_FRY
: variante définie par l'implémentation de l'algorithme Threefry*.PHILOX
: variante de l'algorithme Philox définie par l'implémentation*.
* Voir: Salmon et al. SC 2011. Les nombres aléatoires parallèles: 1, 2, 3...
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | rng_algorithm |
énumération de DEFAULT , THREE_FRY et PHILOX |
(C2) |
(I2) | initial_state |
Tensor unidimensionnel de type ui64 |
(C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output_state |
Tensor unidimensionnel de type ui64 |
(C1) |
output |
Tensor de type entier ou à virgule flottante |
Contraintes
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
est défini comme suit :- L'implémentation est définie si
rng_algorithm = DEFAULT
. 2
sirng_algorithm = THREE_FRY
.2
ou3
sirng_algorithm = PHILOX
.
- L'implémentation est définie si
Exemples
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Sémantique
Effectue un arrondi par élément par élément vers l'entier le plus proche, en dissociant les liaisons de zéro sur le Tensor operand
, et en générant un Tensor result
. Met en œuvre l'opération roundToIntegralTiesToAway
à partir de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Sémantique
Effectue un arrondi par élément par élément vers l'entier le plus proche, en rompant les liens avec l'entier pair sur le Tensor operand
et génère un Tensor result
. Met en œuvre l'opération roundToIntegralTiesToEven
à partir de la spécification IEEE-754. Pour les types quantifiés, exécute dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rigueur
Sémantique
Effectue une opération de racine carrée réciproque par élément sur le Tensor operand
et produit un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
rSqrt
à partir de la norme IEEE-754. - Pour les nombres complexes: racine carrée réciproque complexe.
- Pour les types quantifiés:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter
Sémantique
Génère des Tensors results
, qui sont égaux aux Tensors inputs
, à la différence que plusieurs tranches spécifiées par scatter_indices
sont mises à jour avec les valeurs updates
à l'aide de update_computation
.
Le schéma suivant montre comment les éléments de updates...
sont mappés sur des éléments de results...
à l'aide d'un exemple concret. Le schéma choisit quelques exemples d'indices updates...
et explique en détail à quels indices results...
ils correspondent.
Plus formellement, pour tous les update_index
de index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
est défini comme suit :scatter_indices[si0, ..., :, ..., siN]
, où lessi
sont des éléments individuels dansupdate_scatter_index
, et:
est inséré au niveau de l'indexindex_vector_dim
, siindex_vector_dim
<rank(scatter_indices)
.[scatter_indices[update_scatter_index]]
dans les autres cas.
- Pour
d_input
dansaxes(inputs[0])
:full_start_index[d_input] = start_index[d_start]
sid_input = scatter_dims_to_operand_dims[d_start]
.full_start_index[d_input] = 0
dans les autres cas.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, oùwi
sont des éléments individuels dansupdate_window_index
, et0
est inséré aux index deinserted_window_dims
.result_index = full_start_index + full_window_index
.
Compte tenu de cela, results = exec(schedule, inputs)
, où:
schedule
est une permutation deindex_space(updates[0])
définie par l'implémentation.exec([update_index, ...], results) = exec([...], updated_results)
où :- Si
result_index
est compris dans les limites deshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
est une copie deresults
avecresults...[result_index]
défini surupdated_values...
.- Sinon, procédez comme suit :
updated_results = results
.
- Si
exec([], results) = results
.
Si indices_are_sorted
est défini sur true
, l'implémentation peut supposer que les scatter_indices
sont triés par rapport à scatter_dims_to_operand_dims
. Sinon, le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2
de indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Si unique_indices
est défini sur true
, l'implémentation peut supposer que tous les indices de dispersion de result_index
sont uniques. Si unique_indices
est défini sur true
, mais que les indices de dispersion ne sont pas uniques, le comportement n'est pas défini.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C1), (C2), (C4-C6), (C10), (C13), (C15-C16) |
(I2) | scatter_indices |
Tensor de type entier | (C4), (C11), (C14) |
(I3). | updates |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C3-C6), (C8) |
(I4). | update_window_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C7), (C8) |
(I5). | inserted_window_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C9), (C10) |
(I6). | scatter_dims_to_operand_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C11-C13) |
(I7) | index_vector_dim |
constante de type si64 |
(C4), (C11), (C14) |
(I8). | indices_are_sorted |
constante de type i1 |
|
(I9). | unique_indices |
constante de type i1 |
|
(I10). | update_computation |
function | (C15) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C15-C17) |
Contraintes
- (C1)
same(shape(inputs...))
. - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
. - (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
où :update_scatter_dim_sizes = shape(scatter_indices)
, si ce n'est que la taille de la dimension descatter_indices
correspondant àindex_vector_dim
n'est pas incluse.update_window_dim_sizes <= shape(inputs[0])
, si ce n'est que les dimensions des dimensions dansinputs[0]
correspondant àinserted_window_dims
ne sont pas incluses.combine
placeupdate_scatter_dim_sizes
sur les axes correspondant àupdate_scatter_dims
etupdate_window_dim_sizes
aux axes correspondant àupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(inserted_window_dims) and is_sorted(update_window_dims)
. - (C10)
0 <= inserted_window_dims < rank(inputs[0])
. - (C11)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C12)
is_unique(scatter_dims_to_operand_dims)
. - (C13)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C14)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C15)
update_computation
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, oùis_promotable(element_type(inputs[i]), Ei)
. - (C16)
shape(inputs...) = shape(results...)
. - (C17)
element_type(results[i]) = Ei
pour tous lesi
dans[0,N)
.
Exemples
// %input: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
// [[1, 2], [5, 6], [7, 8], [7, 8]],
// [[10, 11], [12, 13], [14, 15], [16, 17]],
// [[18, 19], [20, 21], [21, 22], [23, 24]]
// ]
select
Sémantique
Génère un Tensor result
, où chaque élément est sélectionné à partir du Tensor on_true
ou on_false
en fonction de la valeur de l'élément correspondant de pred
.
Plus formellement, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, où pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Pour les types quantifiés, exécute dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | pred |
Tensor de type i1 |
(C1) |
(I2) | on_true |
Tensor ou Tensor quantifié par Tensor | (C1-C2) |
(I3). | on_false |
Tensor ou Tensor quantifié par Tensor | (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C2) |
Contraintes
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Exemples
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Sémantique
Disperse les valeurs du Tensor source
à l'aide de scatter
en fonction du résultat du reduce_window
du Tensor input
à l'aide de select
et génère un Tensor result
.
Le schéma suivant montre comment les éléments de result
sont calculés à partir de operand
et source
à l'aide d'un exemple concret.
Plus formellement:
selected_values = reduce_window_without_init(...)
avec les entrées suivantes:- `inputs = [opérande].
window_dimensions
,window_strides
etpadding
, qui sont utilisés tels quels.base_dilations = windows_dilations = 1
.body
est défini comme suit:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
où
E = element_type(operand)
etreduce_window_without_init
fonctionnent exactement commereduce_window
, sauf que leschedule
dureduce
sous-jacent (voir Réduire) n'inclut pas de valeurs d'initialisation. Actuellement, ce qui se passe si la fenêtre correspondante ne possède pas de valeur n'est pas spécifiée (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
où:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
siselected_values[source_index]
contient l'élémentoperand
deoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C4), (C6), (C8-C11) |
(I2) | source |
Tensor ou Tensor quantifié par Tensor | (C1), (C2) |
(I3). | init_value |
Tensor à 0 dimension ou Tensor quantifié par Tensor | (C3) |
(I4). | window_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C5) |
(I5). | window_strides |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C6), (C7) |
(I6). | padding |
Constante de Tensor bidimensionnelle de type si64 |
(C2), (C8) |
(I7) | select |
function | (C9) |
(I8). | scatter |
function | (C10) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C11-C12) |
Contraintes
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
où :padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
est de type(tensor<E>, tensor<E>) -> tensor<i1>
, oùE = element_type(operand)
. - (C10)
scatter
est de type(tensor<E>, tensor<E>) -> tensor<E>
, oùis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Exemples
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
envoyer
Sémantique
Envoie inputs
à un canal channel_id
et génère un jeton result
.
Si is_host_transfer
est défini sur true
, l'opération transfère les données à l'hôte. Sinon, les données sont transférées vers un autre appareil. Cela signifie que l'implémentation
est définie. Cet indicateur duplique les informations fournies dans channel_type
. Nous prévoyons donc de n'en conserver qu'un seul à l'avenir (#666).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre varié de Tensors ou de Tensors quantifiés | |
(I2) | token |
token |
|
(I3). | channel_id |
constante de type si64 |
|
(I4). | channel_type |
énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST |
(C1) |
(I5). | is_host_transfer |
constante de type i1 |
(C1) |
Sorties
Nom | Type |
---|---|
result |
token |
Contraintes
- (C1)
channel_type
est défini comme suit :DEVICE_TO_HOST
siis_host_transfer = true
,DEVICE_TO_DEVICE
dans les autres cas.
Exemples
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Sémantique
Effectue un décalage à gauche par élément sur le Tensor lhs
en fonction d'un nombre de bits rhs
et génère un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier | (C1) |
(I2) | rhs |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Sémantique
Effectue un décalage arithmétique à droite par élément sur le Tensor lhs
en fonction d'un nombre de bits de rhs
, et génère un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier | (C1) |
(I2) | rhs |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Sémantique
Effectue un décalage logique élémentaire vers la droite sur le Tensor lhs
en fonction d'un nombre de bits rhs
et génère un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier | (C1) |
(I2) | rhs |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
signe "=".
Sémantique
Renvoie le signe de operand
par élément et génère un Tensor result
.
Plus formellement, pour chaque élément x
, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Pour les types quantifiés, exécute dequantize_op_quantize(sign, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor d'un entier signé, d'un type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor d'un entier signé, d'un type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Sémantique
Effectue une opération sinus par élément sur le Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
sin
à partir de la norme IEEE-754. - Pour les nombres complexes: sinus complexe.
- Pour les types quantifiés:
dequantize_op_quantize(sine, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
tranche
Sémantique
Extrait une tranche de operand
à l'aide d'indices de départ calculés de manière statique et génère un Tensor result
. start_indices
contient les index de départ de la tranche pour chaque dimension, limit_indices
contient les index de fin (exclusifs) de la tranche pour chaque dimension, et strides
contient les progrès de chaque dimension.
Plus formellement, result[result_index] = operand[operand_index]
, où operand_index = start_indices + result_index * strides
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C3), (C5) |
(I2) | start_indices |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3), (C5) |
(I3). | limit_indices |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3), (C5) |
(I4). | strides |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1), (C5) |
Contraintes
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Exemples
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = dense<[1, 2]> : tensor<2xi64>,
limit_indices = dense<[3, 4]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
sort
Sémantique
Trie les tranches unidimensionnelles de inputs
le long de la dimension dimension
en fonction d'une comparator
et génère results
.
Contrairement aux entrées similaires dans d'autres opérations, dimension
autorise les valeurs négatives, avec la sémantique décrite ci-dessous. À l'avenir, cette action peut être interdite pour des raisons de cohérence (#1377).
Si is_stable
est "true", le tri est stable, c'est-à-dire que l'ordre relatif des éléments considérés comme égaux par le comparateur est conservé. Dans le cas où il y a une seule entrée, les deux éléments e1
et e2
sont considérés comme égaux par le comparateur si et seulement si comparator(e1, e2) = comparator(e2, e1) = false
. Reportez-vous à la formalisation ci-dessous pour découvrir comment cela se généralise à plusieurs entrées.
Plus formellement, pour tous les result_index
de index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, oùriN
sont des éléments individuels dansresult_index
, et:
est inséré àadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- où
sort
trie une tranche unidimensionnelle dans l'ordre non décroissant en s'attendant à ce quecomparator_together
renvoietrue
si l'argument de gauche est inférieur au deuxième argument de droite. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C1-C5) |
(I2) | dimension |
constante de type si64 |
(C4) |
(I3). | is_stable |
constante de type i1 |
|
(I4). | comparator |
function | (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors ou Tensors quantifiés par Tensor | (C2), (C3) |
Contraintes
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, oùR = rank(inputs[0])
. - (C5)
comparator
est de type(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, oùEi = element_type(inputs[i])
.
Exemples
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Sémantique
Effectue une opération de racine carrée élément par élément sur le Tensor operand
et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
squareRoot
à partir de la norme IEEE-754. - Pour les nombres complexes: racine carrée complexe.
- Pour les types quantifiés:
dequantize_op_quantize(sqrt, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Sémantique
Effectue une soustraction par élément de deux Tensors lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les nombres entiers: soustraction d'entiers.
- Pour les flottants:
subtraction
à partir de la norme IEEE-754. - Pour les nombres complexes: soustraction complexe.
- Pour les types quantifiés :
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
Tanh
Sémantique
Effectue une opération de tangente hyperbolique par élément sur le Tensor operand
et produit un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les flottants:
tanh
à partir de la norme IEEE-754. - Pour les nombres complexes: tangente hyperbolique complexe.
- Pour les types quantifiés :
dequantize_op_quantize(tanh, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transposer
Sémantique
Permute les dimensions du Tensor operand
à l'aide de permutation
et génère un Tensor result
. Plus formellement, result[result_index] = operand[operand_index]
, où result_index[d] = operand_index[permutation[d]]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C4) |
(I2) | permutation |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1), (C3-C4) |
Contraintes
- (C1)
element_type(result)
est donné par :element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, si ce n'est quequantization_dimension(operand)
etquantization_dimension(result)
peuvent différer.
- (C2)
permutation
est une permutation derange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Si la valeur est
is_per_axis_quantized(result)
, alorsquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Exemples
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Sémantique
Résoudre des lots de systèmes d'équations linéaires avec des matrices triangulaires inférieures ou supérieures.
Plus formellement, avec a
et b
, result[i0, ..., iR-3, :, :]
est la solution pour op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
lorsque left_side
est true
ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
lorsque left_side
est false
, et résout la variable x
où op(a)
est déterminé par transpose_a
, qui peut être l'un des éléments suivants:
NO_TRANSPOSE
: effectuer l'opération en utilisanta
tel quel.TRANSPOSE
: effectuer une opération sur la transposition dea
.ADJOINT
: effectuer une opération sur la transposition conjuguée dea
.
Les données d'entrée sont lues uniquement à partir du triangle inférieur de a
. Sinon, lower
correspond à true
ou au triangle supérieur de a
. Les données de sortie sont renvoyées dans le même triangle. Les valeurs de l'autre triangle sont définies par l'implémentation.
Si unit_diagonal
est "true", l'implémentation peut supposer que les éléments diagonaux de a
sont égaux à 1. Sinon, le comportement n'est pas défini.
Pour les types quantifiés, exécute dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1-C3) |
(I2) | b |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1-C4) |
(I3). | left_side |
constante de type i1 |
(C3) |
(I4). | lower |
constante de type i1 |
|
(I5). | unit_diagonal |
constante de type i1 |
|
(I6). | transpose_a |
énumération de NO_TRANSPOSE , TRANSPOSE et ADJOINT |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou d'un Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) La relation entre
shape(a)
etshape(b)
est définie comme suit :shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Exemples
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Sémantique
Génère un tuple result
à partir des valeurs val
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | val |
nombre varié de valeurs | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tuple | (C1) |
Contraintes
- (C1)
result
est de typetuple<E0, ..., EN-1>
, oùEi = type(val[i])
.
Exemples
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Sémantique
Effectue la conversion par élément du Tensor quantifié operand
en Tensor à virgule flottante result
en fonction des paramètres de quantification définis par le type operand
.
Plus formellement : result = dequantize(operand)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor quantifié | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Exemples
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Sémantique
Effectue la conversion par élément d'un Tensor à virgule flottante ou d'un Tensor quantifié operand
en un Tensor quantifié result
en fonction des paramètres de quantification définis par le type result
.
Plus formellement,
- Si
is_float(operand)
:result = quantize(operand, type(result))
.
- Si
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou quantifié | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor quantifié | (C1), (C2) |
Contraintes
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Exemples
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
alors que
Sémantique
Génère le résultat de l'exécution de la fonction body
zéro fois ou plus, tandis que la fonction cond
génère true
. Plus formellement, la sémantique peut être exprimée à l'aide de la syntaxe Python comme suit:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Le comportement d'une boucle infinie est à déterminer (#383).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
nombre varié de Tensors, de Tensors quantifiés ou de jetons | (C1-C3) |
(I2) | cond |
function | (C1) |
(I3). | body |
function | (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre varié de Tensors, de Tensors quantifiés ou de jetons | (C3) |
Contraintes
- (C1)
cond
est de type(T0, ..., TN-1) -> tensor<i1>
, oùTi = type(operand[i])
. - (C2)
body
est de type(T0, ..., TN-1) -> (T0, ..., TN-1)
, oùTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Exemples
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
Xor
Sémantique
Effectue l'opération XOR par élément des deux Tensors lhs
et rhs
, et génère un Tensor result
. Selon le type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: logique XOR.
- Pour les entiers: opération XOR au niveau du bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type booléen ou entier | (C1) |
(I2) | rhs |
Tensor de type booléen ou entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Exécution
Exécution séquentielle
Pour exécuter un programme StableHLO, il faut fournir les valeurs d'entrée à la fonction main
et calculer les valeurs de sortie. Les valeurs de sortie d'une fonction sont calculées en exécutant le graphe des opérations en mode root dans l'opération return
correspondante.
L'ordre d'exécution est défini par l'implémentation, à condition qu'il soit aligné sur Dataflow, c'est-à-dire si les opérations sont exécutées avant leur utilisation. Dans StableHLO, toutes les opérations ayant des effets secondaires consomment un seul jeton (plusieurs jetons peuvent être multiplexés en un seul jeton via after_all
). L'ordre d'exécution des effets secondaires est donc également aligné sur Dataflow. Les ordres d'exécution possibles de l'exemple de programme ci-dessus sont %0
→ %1
→ %2
→ %3
→ %4
→ return
ou %3
→ %0
→
%1
→ %2
→ %4
→ return
.
Plus formellement, un processus StableHLO est une combinaison des éléments suivants : 1) un programme StableHLO, 2) des états d'opération (pas encore exécutés, déjà exécutés) et 3) des valeurs intermédiaires sur lesquelles le processus travaille.
Le processus commence par les valeurs d'entrée dans la fonction main
, passe par le graphique des opérations mettant à jour les états des opérations et les valeurs intermédiaires, et se termine par les valeurs de sortie. D'autres formalisations sont à déterminer (#484).
Exécution parallèle
Les programmes StableHLO peuvent être exécutés en parallèle et sont organisés dans une grille de processus 2D de num_replicas
par num_partitions
, de type ui32
.
Dans la grille de processus StableHLO, num_replicas * num_partitions
des processus StableHLO s'exécutent en même temps. Chaque processus possède un process_id = (replica_id, partition_id)
unique, où replica_id
dans replica_ids = range(num_replicas)
et partition_id
dans partition_ids = range(num_partitions)
, qui sont tous deux de type ui32
.
La taille de la grille de processus est connue de manière statique pour chaque programme (nous prévoyons d'en faire une partie explicite à l'avenir des programmes StableHLO n° 650), et la position dans la grille de processus est connue de manière statique pour chaque processus. Chaque processus a accès à sa position dans la grille de processus via les opérations replica_id
et partition_id
.
Dans la grille de processus, les programmes peuvent tous être identiques (dans le style "Programme unique, données multiples") ou différents (dans le style "Programme multiple, données multiples") ou intermédiaire. À l'avenir, nous prévoyons de prendre en charge d'autres idiomes permettant de définir des programmes StableHLO parallèles, y compris GSPMD (#619).
Au sein de la grille de processus, les processus sont pour la plupart indépendants les uns des autres. Ils ont des états d'opération distincts, des valeurs d'entrée/intermédiaire/sortie distinctes et la plupart des opérations sont exécutées séparément entre les processus, à l'exception d'un petit nombre d'opérations collectives décrites ci-dessous.
Étant donné que l'exécution de la plupart des opérations n'utilise que des valeurs du même processus, il est généralement clair de faire référence à ces valeurs par leur nom.
Toutefois, lorsque vous décrivez la sémantique des opérations collectives, cela est insuffisant, et cela génère la notation name@process_id
pour faire référence à la valeur name
dans un processus particulier. (De ce point de vue, un name
non qualifié peut être considéré comme un raccourci pour name@(replica_id(), partition_id())
.)
L'ordre d'exécution des processus est défini par l'implémentation, à l'exception de la synchronisation introduite par la communication point à point et les opérations collectives, comme décrit ci-dessous.
Communication point à point
Les processus StableHLO peuvent communiquer entre eux via des canaux StableHLO. Un canal est représenté par un ID positif de type si64
. Via différentes opérations, il est possible d'envoyer des valeurs aux canaux et de les recevoir de ces canaux.
Une formalisation supplémentaire (par exemple, la provenance de ces ID de canaux, la manière dont les processus les programmes prennent en compte et le type de synchronisation qu'ils introduit) est à déterminer (#484).
Communication en flux continu
Chaque processus StableHLO a accès à deux interfaces de streaming:
- Infeed (Flux d'entrée) pouvant être lu.
- OutFeed sur lequel une écriture est possible.
Contrairement aux canaux, qui sont utilisés pour communiquer entre les processus et qui ont donc des processus à leurs deux extrémités, l'implémentation des autres extrémités des flux d'entrée et de sortie est définie.
Une formalisation supplémentaire, par exemple la manière dont la communication par flux influence l'ordre d'exécution et le type de synchronisation qu'elle introduit, reste à déterminer (#484).
Opérations collectives
Il existe six opérations collectives dans StableHLO: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
et reduce_scatter
. Toutes ces opérations divisent les processus de la grille de processus StableHLO en groupes de processus StableHLO et exécutent un calcul conjoint dans chaque groupe de processus, indépendamment des autres groupes de processus.
Au sein de chaque groupe de processus, les opérations collectives peuvent introduire une barrière de synchronisation. D'autres formalisations, telles que le moment exact de cette synchronisation, la manière exacte dont les processus arrivent à cet obstacle et ce qui se passe s'ils ne la rencontrent pas, sont à déterminer (#484).
Si le groupe de processus implique une communication entre partitions, c'est-à-dire s'il existe des processus dans le groupe de processus dont les ID de partition sont différents, l'exécution de l'opération collective a besoin d'un canal, et l'opération collective doit fournir une valeur channel_id
positive de type si64
. La communication entre les instances répliquées n'a pas besoin de canaux.
Les calculs effectués par les opérations collectives sont spécifiques à des opérations individuelles et sont décrits dans les sections correspondantes ci-dessus. Toutefois, les stratégies par lesquelles la grille de processus est divisée en groupes de processus sont partagées entre ces opérations et sont décrites dans cette section. Plus formellement, StableHLO prend en charge les quatre stratégies suivantes.
cross_replica
Seules les communications entre instances répliquées ont lieu au sein de chaque groupe de processus. Cette stratégie utilise replica_groups
(une liste de listes d'ID d'instances répliquées) et calcule un produit cartésien de replica_groups
par partition_ids
. replica_groups
doit comporter des éléments uniques et couvrir tous les replica_ids
. Plus formellement, en utilisant
la syntaxe Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour replica_groups = [[0, 1], [2, 3]]
et num_partitions = 2
, cross_replica
produira [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Seules les communications entre partitions ont lieu dans chaque groupe de processus. Cette stratégie utilise partition_groups
(une liste de listes d'ID de partition) et calcule un produit cartésien de partition_groups
par replica_ids
.
partition_groups
doit comporter des éléments uniques et couvrir l'intégralité des partition_ids
.
Plus formellement, en utilisant la syntaxe Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour partition_groups = [[0, 1]]
et num_replicas = 4
, cross_partition
produira [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Des communications entre instances répliquées et entre partitions peuvent avoir lieu au sein de chaque groupe de processus. Cette stratégie utilise replica_groups
(une liste d'ID d'instances répliquées) et calcule les produits cartésiens de chaque replica_group
par partition_ids
. replica_groups
doit comporter des éléments uniques et couvrir tous les replica_ids
. Plus formellement, en utilisant la syntaxe Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour replica_groups = [[0, 1], [2, 3]]
et num_partitions = 2
, cross_replica_and_partition
produira [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Cette stratégie utilise flattened_id_groups
(une liste de listes d'ID de processus "aplatis" sous la forme replica_id * num_partitions + partition_id
) et les transforme en ID de processus. flattened_id_groups
doit comporter des éléments uniques et couvrir tous les process_ids
. Plus formellement, en utilisant la syntaxe Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
et num_partitions = 2
, flattened_ids
produira [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Justesse
Pour le moment, StableHLO n'offre aucune garantie de précision numérique, mais cela est susceptible de changer à l'avenir (#1156).
Erreurs
Les programmes StableHLO sont validés par un ensemble complet de contraintes pour les opérations individuelles, ce qui permet d'exclure de nombreuses classes d'erreurs avant l'exécution. Toutefois, des conditions d'erreur restent possibles, par exemple via des dépassements d'entiers, des accès hors limites, etc. Sauf indication contraire, toutes ces erreurs entraînent un comportement défini par l'implémentation, mais cela est susceptible de changer à l'avenir (#1157).
À l'exception de cette règle, les exceptions à virgule flottante dans les programmes StableHLO ont un comportement bien défini. Les opérations qui entraînent des exceptions définies par la norme IEEE-754 (opération non valide, division par zéro, dépassement, dépassement de capacité négatif ou exceptions inexactes) génèrent des résultats par défaut (tels que définis dans la norme) et poursuivent l'exécution sans générer l'indicateur d'état correspondant (comme pour la gestion des exceptions raiseNoFlag
de la norme). Les exceptions pour les opérations non standards (par exemple, les calculs arithmétiques complexes et certaines fonctions transcendantes) sont définies par l'implémentation.
Notation
Pour décrire la syntaxe, ce document utilise le type de syntaxe ISO modifié de la syntaxe EBNF (ISO/IEC 14977:1996, Wikipédia). Deux modifications sont apportées: 1) les règles sont définies à l'aide de ::=
au lieu de =
,
2) La concaténation est exprimée à l'aide de la juxtaposition plutôt que de ,
.
Pour décrire la sémantique (c'est-à-dire dans les sections "Types", "Constantes" et "Opérations"), nous utilisons des formules basées sur la syntaxe Python étendue avec prise en charge de l'expression concise des opérations de tableau, comme décrit ci-dessous. Cela fonctionne bien pour les petits extraits de code, mais dans les rares cas où de plus grands extraits de code sont nécessaires, nous utilisons la syntaxe Python vanille, qui est toujours introduite explicitement.
Formules
Découvrons le fonctionnement des formules en nous basant sur un exemple tiré de la spécification dot_general
. L'une des contraintes pour cette opération est la suivante : dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Les noms utilisés dans cette formule proviennent de deux sources: 1) les fonctions globales (par exemple, dim
) ; 2) les définitions des membres de l'élément de programme correspondant (c'est-à-dire les entrées lhs
, lhs_batching_dimensions
, rhs
et rhs_batching_dimensions
) définies dans la section "Entrées" de dot_general
.
Comme indiqué ci-dessus, la syntaxe de cette formule est basée sur Python avec certaines extensions orientées vers la concision. Pour donner un sens à la formule, transformons-la en syntaxe Python vanille.
A) Dans ces formules, nous utilisons =
pour représenter l'égalité. La première étape pour obtenir la syntaxe Python consiste donc à remplacer =
par ==
, comme suit : dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) De plus, ces formules acceptent les points de suspension (...
), qui transforment les expressions scalaires en expressions de Tensor. En bref, f(xs...)
signifie à peu près "pour chaque x
scalaire du Tensor xs
, calculer un f(x)
scalaire, puis renvoyer tous ces résultats scalaires sous forme de résultat de Tensor". En syntaxe Python vanille, notre exemple de formule se transforme en : [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Grâce aux ellipses, il est souvent possible d'éviter de travailler au niveau des scalaires individuels. Toutefois, dans certains cas délicats, une syntaxe semi-informelle de niveau inférieur peut être utilisée comme dans la formule start_indices[bi0, ..., :, ..., biN]
de la spécification gather
. Par souci de concision, nous ne fournissons pas de formalisme exact pour traduire cette syntaxe en Python vanille, en espérant qu'elle reste intuitivement compréhensible au cas par cas.
Veuillez nous indiquer si certaines formules spécifiques semblent opaques et nous essaierons de les améliorer.
Vous remarquerez également que les formules utilisent des points de suspension pour développer toutes sortes de listes, y compris les Tensors, les listes de Tensors (par exemple, qui peuvent provenir d'un nombre variable de Tensors), etc. Il s'agit d'un autre domaine dans lequel nous ne fournissons pas de formalité exacte (par exemple, les listes ne font même pas partie du système de types StableHLO) et s'appuient plutôt sur la compréhension intuitive.
C) Le dernier moyen de notation notable que nous utilisons est la diffusion implicite. Bien que l'opération StableHLO ne soit pas compatible avec la diffusion implicite, les formules le sont également, à des fins de concision. En résumé, si un scalaire est utilisé dans un contexte où un Tensor est attendu, il est diffusé vers la forme attendue.
Pour continuer l'exemple dot_general
, voici une autre contrainte : 0 <= lhs_batching_dimensions < rank(lhs)
. Comme défini dans la spécification dot_general
, lhs_batching_dimensions
est un Tensor, mais 0
et rank(lhs)
sont tous deux scalaires. Une fois la diffusion implicite appliquée, la formule devient [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Lorsqu'elle est appliquée à une opération dot_general
particulière, cette formule sera évaluée par rapport à un Tensor de valeurs booléennes. Lorsque des formules sont utilisées comme contraintes, la contrainte est valable si la formule est évaluée sur true
ou sur un Tensor qui ne contient que des éléments true
.
Noms
Dans les formules, la portée lexicale comprend: 1) les fonctions globales, 2) les définitions de membre,
3) définitions locales. La liste des fonctions globales est fournie ci-dessous. La liste des définitions d'éléments dépend de l'élément du programme auquel la notation est appliquée:
- Pour les opérations, les définitions de membre incluent les noms introduits dans les sections "Entrées" et "Sorties".
- Pour tout le reste, les définitions de membre incluent les parties structurelles de l'élément de programme, nommées d'après les non-terminaux EBNF correspondants. La plupart du temps, les noms de ces parties structurelles sont obtenus en convertissant les noms des non-terminaux en snake case (par exemple,
IntegerLiteral
=>integer_literal
). Toutefois, les noms sont parfois abrégés dans le processus (par exemple,QuantizationStorageType
=>storage_type
). Dans ce cas, les noms sont introduits explicitement de la même manière que dans les sections "Entrées" et "Sorties" dans les opérations. - En outre, les définitions de membre incluent toujours
self
pour faire référence à l'élément de programme correspondant.
Valeurs
Lorsque les formules sont évaluées, elles fonctionnent avec les types de valeurs suivants :
1) Value
(valeurs réelles, par exemple dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
) ;
2) Placeholder
(valeurs futures, par exemple lhs
, rhs
ou result
; leurs valeurs réelles ne sont pas encore connues, seuls leurs types sont connus),
3) Type
(types tels que définis dans la section "Types") ;
4) Function
(fonctions globales telles que définies dans la section "Fonctions globales").
Selon le contexte, les noms peuvent faire référence à des valeurs différentes. Plus précisément, la section "Sémantique" des opérations (et des équivalents pour d'autres éléments du programme) définit la logique d'exécution, de sorte que toutes les entrées sont disponibles en tant que Value
.
En revanche, la section "Contraintes" des opérations (et des équivalents) définit une logique de "compilation", c'est-à-dire un élément qui est généralement exécuté avant l'exécution. Par conséquent, seules les entrées constantes sont disponibles en tant que Value
et les autres entrées ne sont disponibles qu'en tant que Placeholder
.
Noms | Dans "Sémantique" | Dans "Contraintes" |
---|---|---|
Fonctions globales | Function |
Function |
Entrées constantes | Value |
Value |
Entrées non constantes | Value |
Placeholder |
Sorties | Value |
Placeholder |
Définitions locales | Dépend de la définition | Dépend de la définition |
Prenons un exemple d'opération transpose
:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Pour cette opération, permutation
est une constante. Il est donc disponible en tant que Value
à la fois dans la sémantique et dans les contraintes. En revanche, operand
et result
sont disponibles en tant que Value
dans la sémantique, mais uniquement en tant que Placeholder
dans les contraintes.
Fonctions
Construction de types
Aucune fonction ne peut être utilisée pour construire des types. Nous utilisons directement la syntaxe de type, car elle est généralement plus concise. Par exemple, (tensor<E>, tensor<E>) -> (tensor<E>)
au lieu de function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Fonctions sur les types
element_type
est défini sur les types de Tensors et quantifiés, et renvoie respectivement la partieTensorElementType
ouQuantizedTensorElementType
duTensorType
ou duQuantizedTensorType
correspondant.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
vérifie si le typex
peut être promu au typey
. Lorsquex
ety
correspondent à desQuantizedTensorElementType
, la promotion n'est appliquée qu'austorage_type
. Cette version spécifique de la promotion est actuellement utilisée dans le contexte du calcul de la réduction (consultez le document RFC pour plus d'informations).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Disponible pour tous les types. Par exemple,is_float(x)
renvoietrue
six
est unFloatType
. Six
est une valeur ou un espace réservé, cette fonction est un raccourci pouris_type_name(type(x))
.max_value(x: Type) -> Value
renvoie la valeur maximale d'unTensorElementType
. Six
n'est pas uneTensorElementType
, renvoieNone
.min_value(x: Type) -> Value
renvoie la valeur minimale possible d'unTensorElementType
. Six
n'est pas uneTensorElementType
, renvoieNone
.member_name(x: Value | Placeholder | Type) -> Any
. Disponible pour toutes les définitions de membremember_name
de tous types. Par exemple,tensor_element_type(x)
renvoie la partieTensorElementType
d'un objetTensorType
correspondant. Six
est une valeur ou un espace réservé, cette fonction est un raccourci pourmember_name(type(x))
. Six
n'est pas un type possédant un membre approprié, ou une valeur ou un espace réservé de ce type, renvoieNone
.
Construction de valeurs
operation_name(*xs: Value | Type) -> Value
. Disponible pour toutes les opérations. Par exemple,add(lhs, rhs)
prend deux valeurs de Tensorlhs
etrhs
, et renvoie la sortie de l'évaluation de l'opérationadd
avec ces entrées. Pour certaines opérations commebroadcast_in_dim
, les types de sorties sont "portants", c'est-à-dire nécessaires pour évaluer une opération. Dans ce cas, la fonction utilise ces types comme arguments.
Fonction sur les valeurs
Tous les opérateurs et fonctions Python sont disponibles. Par exemple, les notations d'abonnement et de tranchement de Python peuvent être indexées dans des Tensors, des Tensors quantifiés et des tuples.
to_destination_type(x: Value, destination_type: Type) -> Value
est défini sur des Tensors et renvoie la valeur convertie dex
en fonction destype(x)
etdestination_type
, comme suit:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Une discussion préliminaire a été abordée sur la fusion des opérations convert
, uniform_quantize
et uniform_dequantize
(#1576).
Après la fusion, nous n'avons plus besoin de la fonction ci-dessus. Nous pouvons utiliser le nom de l'opération pour convert
à la place.
is_nan(x: Value) -> Value
est défini sur les Tensors et renvoietrue
si tous les éléments dex
sontNaN
oufalse
dans les autres cas. Six
n'est pas un Tensor, elle renvoieNone
.is_sorted(x: Value) -> Value
est défini sur les Tensors et renvoietrue
si les éléments dex
sont triés par ordre croissant par rapport à l'ordre lexicographique croissant de leurs indices, oufalse
dans le cas contraire. Six
n'est pas un Tensor, la fonction renvoieNone
.is_unique(x: Value) -> Value
est défini sur des Tensors et renvoietrue
six
n'a pas d'éléments en double, oufalse
dans le cas contraire. Six
n'est pas un Tensor, elle renvoieNone
.member_name(x: Value) -> Any
est défini pour toutes les définitions de membremember_name
de l'ensemble des valeurs. Par exemple,real_part(x)
renvoie la partieRealPart
d'un objetComplexConstant
correspondant. Six
n'est pas une valeur associée à un membre approprié, elle renvoieNone
.same(x: Value) -> Value
est défini sur les Tensors et renvoietrue
si les éléments dex
sont tous égaux les uns aux autres, oufalse
dans le cas contraire. Si le Tensor ne comporte aucun élément, il est comptabilisé comme "tous égaux les uns par rapport aux autres ", c'est-à-dire que la fonction renvoietrue
. Six
n'est pas un Tensor, la fonction renvoieNone
.split(x: Value, num_results: Value, axis: Value) -> Value
est défini sur des Tensors et renvoie des tranchesnum_results
dex
le long de l'axeaxis
. Six
n'est pas un Tensor oudim(x, axis) % num_results != 0
, renvoieNone
.
Calculs de formes
axes(x: Value | Placeholder | Type) -> Value
est un raccourci pourrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
est un raccourci pourshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
est un raccourci pourlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
est défini sur des Tensors et renvoie les indicessize(x)
duTensorType
correspondant, trié par ordre lexicographique croissant, par exemple[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Six
n'est pas un type de Tensor, un type de Tensor quantifié, une valeur ou un espace réservé de l'un de ces types, renvoieNone
.rank(x: Value | Placeholder | Type) -> Value
est un raccourci poursize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
est défini dans la section "Fonctions sur les types" viamember_name
.size(x: Value | Placeholder | Type) -> Value
est un raccourci pourreduce(lambda x, y: x * y, shape(x))
.
Calculs de quantification
def baseline_element_type(x: Value | Placeholder | Type) -> Type
est un raccourci pourelement_type(baseline_type(x))
.baseline_type
est défini sur les types de Tensor et les types de Tensor quantifiés, et les transforme en "référence", c'est-à-dire un type ayant la même forme, mais dont les paramètres de quantification du type d'élément sont réinitialisés aux valeurs par défaut. Cette astuce pratique permet de comparer de manière uniforme les types de Tensors et de Tensors quantifiés, ce qui est assez souvent nécessaire. Pour les types quantifiés, cela permet de comparer les types sans tenir compte des paramètres de quantification. Autrement dit,shape
,storage_type
,expressed_type
,storage_min
,storage_max
etquantization_dimension
(pour le type quantifié par axe) doivent tous correspondre, maisscales
etzero points
peuvent être différents.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
est défini sur des types de Tensor quantifiés et les convertit en types de Tensor à virgule flottante. Pour ce faire, les éléments quantifiés qui représentent des valeurs entières du type de stockage sont convertis en valeurs à virgule flottante correspondantes du type exprimé, à l'aide du point zéro et de l'échelle associées au type d'élément quantifié.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
est défini sur les types de Tensors à virgule flottante et les convertit en types de Tensor quantifiés. Pour ce faire, les valeurs à virgule flottante du type exprimé sont converties en valeurs entières correspondantes du type de stockage à l'aide du point zéro et de l'échelle associées au type d'élément quantifié.
def quantize(x: Value, type: Type) -> Value:
assert is_float(x) and is_quantized(type)
x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
return bitcast_convert(x_storage, type)
dequantize_op_quantize
permet de spécifier des calculs par élément sur des Tensors quantifiés. Elle déquantifie, c'est-à-dire transforme les éléments quantifiés en types exprimés, effectue une opération, puis quantifie (en d'autres termes, reconvertit les résultats en types de stockage). Pour le moment, cette fonction ne fonctionne que pour la quantification par Tensor. La quantification par axe est en cours (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
Calculs en mode grille
cross_partition(replica_groups: Value) -> Value
. Consultez la section "cross_réplique" ci-dessus.cross_replica(replica_groups: Value) -> Value
. Consultez la section "cross_réplique" ci-dessus.cross_replica_and_partition(replica_groups: Value) -> Value
. Consultez la section "cross_réplique_and_partition" ci-dessus.flattened_ids(replica_groups: Value) -> Value
: consultez la section "flattened_ids" ci-dessus.