StableHLO est un ensemble d'opérations pour les opérations de haut niveau (HLO) dans les modèles de ML (ML). StableHLO fonctionne comme une couche de portabilité entre différents Frameworks et compilateurs de ML: frameworks de ML qui produisent des programmes StableHLO sont compatibles avec les compilateurs de ML qui consomment des programmes StableHLO.
Notre objectif est de simplifier et d'accélérer le développement du ML en créant l'interopérabilité entre différents frameworks de ML (tels que TensorFlow, JAX et PyTorch) et les compilateurs de ML (tels que XLA et IREE). Dans cette optique, fournit une spécification pour le langage de programmation StableHLO.
Cette spécification contient trois sections principales. Tout d'abord, le La section Programmes décrit la structure des programmes StableHLO. qui se composent de fonctions StableHLO qui sont elles-mêmes des opérations StableHLO. Au sein de cette structure, la section Ops spécifie la sémantique opérations individuelles. La section Exécution fournit une sémantique pour toutes ces opérations s'exécutent ensemble dans un programme. Enfin, la fonction La section Notation présente la notation utilisée tout au long de la spécifique.
Pour afficher la spécification d'une version précédente de StableHLO, ouvrez le dépôt à l'adresse version taguée qui vous intéresse. (par exemple, la spécification StableHLO v0.19.0). Pour afficher les modifications survenues lors de chaque chargement de version mineure de StableHLO, consultez le journal des versions dans VhloDialect.td.
Programmes
Program ::= {Func}
Les programmes StableHLO consistent en un nombre arbitraire de fonctions StableHLO.
Vous trouverez ci-dessous un exemple de programme avec une fonction @main
comportant trois entrées
(%image
, %weights
et %bias
) et un résultat. Corps de la fonction
comporte 6 opérations.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Fonctions
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Les fonctions StableHLO (également appelées fonctions nommées) ont un identifiant, des entrées/sorties et un corps. À l'avenir, nous prévoyons de Intégrer des métadonnées supplémentaires pour les fonctions afin d'améliorer la compatibilité avec HLO (#425, #626, n° 740, n° 744).
Identifiants
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Les identifiants StableHLO sont semblables aux identifiants dans de nombreux programmes avec deux particularités: 1) tous les identifiants ont des sigils distinguer différents types d'identifiants, 2) les identifiants de valeur peuvent être entièrement numérique pour simplifier la génération de programmes StableHLO.
Types
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Les types StableHLO sont classés en types de valeurs (également appelés types de première classe), qui représentent des valeurs StableHLO et des types autres que des valeurs qui décrivent d'autres éléments du programme. Les types StableHLO sont similaires aux types dans de nombreux langages de programmation, la principale particularité étant domaine spécifique qui donne des résultats inhabituels (par exemple, des types scalaires ne sont pas des types de valeur).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Les types de Tensor représentent des Tensors, c'est-à-dire des tableaux multidimensionnels. Ils ont un
forme et un type d'élément, où une forme représente des valeurs non négatives ou
tailles de dimensions inconnues dans l'ordre croissant des valeurs
dimensions (également appelées axes) numérotées de 0
à R-1
. La
Le nombre de dimensions R
est appelé classement. Par exemple, tensor<2x3xf32>
est
un type de Tensor avec la forme 2x3
et le type d'élément f32
. Il a deux dimensions
(ou, en d'autres termes, deux axes) - 0e dimension et 1re dimension - dont les tailles
sont 2 et 3. Son classement est de 2.
Les formes peuvent être partiellement ou totalement inconnues (dynamiques). Ex. : tensor<?x2xf64>
est en partie inconnue, tandis que tensor<?x?xf64>
l'est complètement. Dynamique
les dimensions sont représentées à l'aide d'un ?
. Les formes ne peuvent pas être désclassées.
À l'avenir, nous prévoyons d'étendre les types de Tensor au-delà les tailles de dimension et les types d'éléments, par exemple, pour inclure des mises en page (#629) et la parcimonie (#1078)
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nom | Type | Contraintes |
---|---|---|
storage_type |
type d'entier | (C1-C3), (C8) |
storage_min |
constante d'entier | (C1), (C3), (C7) |
storage_max |
constante d'entier | (C2), (C3), (C7) |
expressed_type |
type à virgule flottante | (C4) |
quantization_dimension |
constante facultative entière | (C10-C12) |
scales |
nombre variadique de constantes à virgule flottante | (C4-C6), (C9), (C10), (C13) |
zero_points |
nombre variadique de constantes entières | (C7-C9) |
Les types d'éléments quantifiés représentent des valeurs entières d'un type de stockage dans
la plage comprise entre storage_min
et storage_max
(inclus) correspondant à
valeurs à virgule flottante d'un type exprimé. Pour une valeur entière donnée i
,
la valeur à virgule flottante correspondante f
peut être calculée comme suit :
f = (i - zero_point) * scale
, où scale
et zero_point
sont appelés
paramètres de quantification. Les champs storage_min
et storage_max
sont facultatifs
dans la grammaire, mais leur valeur par défaut est min_value(storage_type)
et
max_value(storage_type)
respectivement. Les types d'éléments quantifiés ont la
les contraintes suivantes:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Si
is_empty(quantization_dimension)
, alorssize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Pour le moment, QuantizationScale
est une constante à virgule flottante, mais il existe
un fort intérêt pour les échelles basées sur des entiers, représentées par des multiplicateurs et
changements de direction. Nous prévoyons de l'examiner prochainement.
(#1404)
Une discussion est en cours sur la sémantique de QuantizationZeroPoint
,
y compris le type, les valeurs et s'il ne peut y avoir qu'un seul ou
potentiellement plusieurs points zéro
dans un type de Tensor quantifié. D'après les
résultats de cette discussion, la spécification autour de zéro point peut changer
ultérieurement (#1405).
Une autre discussion en cours concerne la sémantique de QuantizationStorageMin
.
et QuantizationStorageMax
pour déterminer si des contraintes doivent être
sur ces valeurs et sur les valeurs des Tensors quantifiés
(#1406)
Enfin, nous prévoyons de représenter les échelles inconnues et les valeurs zéro de la même manière que nous prévoyons d'explorer la représentation des (#1407)
Les types de Tensors quantifiés représentent des Tensors avec des éléments quantifiés. Ces Les Tensors sont exactement les mêmes que les Tensors standards, si ce n'est que leurs éléments utilisent des types d'éléments quantifiés au lieu de types d'éléments standards.
Dans les tenseurs quantifiés, la quantification peut être par Tensor, ce qui signifie que
une valeur scale
et une zero_point
pour l'intégralité du Tensor ou par axe,
c'est-à-dire qu'avoir plusieurs scales
et zero_points
, une paire par tranche de
une dimension particulière quantization_dimension
. Plus formellement, dans un Tensor t
avec la quantification par axe, il existe dim(t, quantization_dimension)
tranches
pour quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
,
etc. Tous les éléments de la i
e tranche utilisent scales[i]
et zero_points[i]
comme
leurs paramètres de quantification. Les types de Tensor quantifiés ont les caractéristiques suivantes :
contraintes:
- Pour la quantification par Tensor:
<ph type="x-smartling-placeholder">
- </ph>
- Aucune contrainte supplémentaire.
- Pour la quantification par axe:
<ph type="x-smartling-placeholder">
- </ph>
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Les types de jetons représentent les jetons, c'est-à-dire les valeurs opaques produites et consommées par certaines opérations. Les jetons sont utilisés pour imposer un ordre d'exécution aux opérations comme décrit dans la section Exécution.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Les types de tuples représentent des tuples, c'est-à-dire des listes hétérogènes. Les Tuples sont un héritage
fonctionnalité qui n'existe que
pour la compatibilité avec HLO. Dans HLO, les tuples sont
utilisée pour représenter des entrées et sorties variables. Dans StableHLO, les entrées variables et
les sorties sont prises en charge de manière native, et la seule utilisation de tuples dans StableHLO consiste à
représenter de manière exhaustive l'ABI HLO T
, tuple<T>
et
tuple<tuple<T>>
peut être sensiblement différente en fonction d'une
la mise en œuvre. Nous prévoyons d'apporter des modifications à l'ABI HLO à l'avenir.
ce qui peut nous permettre de supprimer les types de tuples de StableHLO
(#598)
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Les types d'éléments représentent les éléments des types de Tensor. Contrairement à de nombreuses fonctions
langues, ces types ne sont pas de première classe dans StableHLO. Cela signifie que
Les programmes StableHLO ne peuvent pas représenter directement des valeurs de ces types (par conséquent,
représenter des valeurs scalaires de type T
avec un Tensor à 0 dimensions est idiomatique.
de type tensor<T>
).
- Le type booléen représente les valeurs booléennes
true
etfalse
. - Les types d'entiers peuvent être signés (
si
) ou non signés (ui
), et ont l'une des largeurs de bits acceptées (2
,4
,8
,16
,32
ou64
) ; Les typessiN
signés représentent des valeurs entières comprises entre-2^(N-1)
et2^(N-1)-1
inclus, et les typesuiN
non signés représentent des valeurs entières comprises entre0
et2^N-1
inclus. - Il existe plusieurs types à virgule flottante:
<ph type="x-smartling-placeholder">
- </ph>
- les types
f8E4M3FN
etf8E5M2
correspondant respectivement aux Les encodagesE4M3
etE5M2
du format FP8 décrits à la section Formats FP8 pour le deep learning. - Types
f8E4M3FNUZ
etf8E5M2FNUZ
correspondant àE4M3
etE5M2
les encodages des formats FP8 décrits dans Formats numériques 8 bits pour les réseaux de neurones profonds - Type de
f8E4M3B11FNUZ
correspondant à l'encodageE4M3
des formats FP8 décrits dans Entraînement et inférence hybrides à virgule flottante 8 bits (HFP8) pour les réseaux de neurones profonds - Type
bf16
correspondant au formatbfloat16
décrit dans BFloat16: Le secret des hautes performances sur les Cloud TPU - les types
f16
,f32
etf64
correspondant respectivement àbinary16
("demi-précision"),binary32
("précision simple") et les formatsbinary64
("double précision") décrits dans les la norme IEEE 754. - Le type de
tf32
correspond au format TensorFloat32. et offre une compatibilité limitée avec StableHLO.
- les types
- Les types complexes représentent des valeurs complexes ayant une partie réelle.
et une partie imaginaire du même type d'élément. Complexe compatible
les types sont
complex<f32>
(les deux parties sont de typef32
) etcomplex<f64>
(les deux parties sont de typef64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Les types de fonctions représentent les fonctions nommées et anonymes. Ils ont
les types d'entrée (liste des types à gauche de ->
) et les types de sortie
(la liste des types sur la droite de ->
). Dans de nombreux langages de programmation,
langages, les types de fonctions sont de première classe, mais pas dans StableHLO.
StringType ::= 'string'
Le type de chaîne représente des séquences d'octets. Contrairement à de nombreuses fonctions langues, le type de chaîne n'est pas la première classe dans StableHLO et n'est utilisé spécifier des métadonnées statiques pour les éléments du programme.
Opérations
Les opérations StableHLO (également appelées opérations) représentent un ensemble fermé. d'opérations de haut niveau dans des modèles de machine learning. Comme indiqué ci-dessus, La syntaxe StableHLO s'inspire fortement de MLIR, qui n'est pas nécessairement la méthode la plus alternative ergonomique, mais c'est sans doute la meilleure solution pour permettre à StableHLO de ce qui renforce l'interopérabilité entre les frameworks de ML et les compilateurs de ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Les opérations StableHLO (également appelées opérations) ont un nom,
des entrées/sorties et une signature. Ce nom se compose du préfixe stablehlo.
et
Un mnémonique qui identifie de manière unique l'une des opérations prises en charge. Voir ci-dessous pour
une liste complète de toutes les opérations prises en charge.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Les opérations utilisent les entrées et produisent des sorties. Les entrées sont classées en
les valeurs d'entrée (calculées lors de l'exécution), les fonctions d'entrée (fournies
de manière statique, car dans StableHLO, les fonctions ne sont pas des valeurs de première classe) et
d'entrée (également fournis de manière statique). Le type d'entrées et de sorties
consommée et produite par une opération dépend de son mode mnémotechnique. Par exemple, add
"op" consomme deux valeurs d'entrée et génère une valeur de sortie. En comparaison, les
L'opération select_and_scatter
consomme 3 valeurs d'entrée, 2 fonctions d'entrée et
3 attributs d'entrée.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Les fonctions d'entrée (également appelées fonctions anonymes) sont très
semblables aux fonctions nommées, sauf que: 1) elles n'ont pas d'identifiant (donc
le nom "anonyme"), 2) ils ne déclarent pas de types de sortie (les types de sortie sont
inférée à partir de l'opération return
dans la fonction).
La syntaxe des fonctions d'entrée inclut une partie actuellement inutilisée (voir les
Unused
ci-dessus), qui assure la compatibilité avec MLIR. Dans MLIR,
il existe un concept plus général
de « régions » qui peut comporter plusieurs "blocs"
reliés entre eux par des jump ops. Ces blocs ont des ID qui correspondent
à l'environnement de production Unused
, afin de pouvoir les distinguer les uns des autres.
StableHLO n'a pas de jump ops. La partie correspondante
de la syntaxe MLIR est donc
unused (mais est toujours là).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Les attributs d'entrée ont un nom et une valeur faisant partie des
constantes. Elles constituent le principal moyen de spécifier des métadonnées statiques pour un programme
éléments. Par exemple, l'opération concatenate
utilise l'attribut dimension
pour
spécifier la dimension selon laquelle ses valeurs d'entrée sont concaténées. De même,
L'opération slice
utilise plusieurs attributs tels que start_indices
et limit_indices
.
pour spécifier les limites utilisées pour segmenter la valeur d'entrée.
Actuellement, les programmes StableHLO contiennent parfois des attributs qui ne sont pas décrites dans ce document. À l'avenir, nous prévoyons de absorber ces attributs dans l'opset StableHLO ou les interdit dans les programmes StableHLO. En attendant, voici la liste de ces Attributs:
layout
(#629)mhlo.frontend_attributes
(#628)mhlo.sharding
(#619)output_operand_aliases
(#740)- Métadonnées de localisation (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La signature d'opération comprend les types de toutes les valeurs d'entrée (la liste des types sur
à gauche de ->
) et les types de toutes les valeurs de sortie (liste des
à droite de ->
). À proprement parler, les types
d'entrée sont
redondants, et les types de sortie le sont presque toujours également (parce que pour
la plupart des opérations StableHLO, les types de sortie peuvent être déduits des entrées). Néanmoins, l'opération
la signature fait délibérément partie de la syntaxe StableHLO pour assurer la compatibilité avec MLIR.
Vous trouverez ci-dessous un exemple d'opération dont l'expression mnémotechnique est select_and_scatter
. Il en consomme 3
valeurs d'entrée (%operand
, %source
et %init_value
), deux fonctions d'entrée
et trois attributs d'entrée (window_dimensions
, window_strides
et padding
).
Notez que la signature de l'opération n'inclut que les types de ses valeurs d'entrée.
(mais pas les types de fonctions d'entrée et d'attributs qui sont fournis de façon intégrée).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Les constantes StableHLO ont un littéral et un type qui représentent ensemble
une valeur StableHLO. En général, le type fait partie de la syntaxe constante, sauf
lorsqu'elle ne présente aucune ambiguïté (par exemple, une constante booléenne a sans ambiguïté le type i1
,
tandis qu'une constante entière peut avoir plusieurs types possibles).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Les constantes booléennes représentent les valeurs booléennes true
et false
. Booléen
sont de type i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Les constantes entières représentent des valeurs entières via des chaînes utilisant des nombres décimaux ou en notation hexadécimale. Autres bases (par exemple, binaire ou octale, ne sont pas prises en charge. Les constantes entières présentent les contraintes suivantes:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Les constantes à virgule flottante représentent des valeurs à virgule flottante via des chaînes qui utiliser la notation décimale ou scientifique. De plus, la notation hexadécimale peut être utilisé pour spécifier directement les bits sous-jacents au format à virgule flottante de le type correspondant. Les constantes à virgule flottante présentent les contraintes suivantes:
- (C1) Si vous utilisez la notation non hexadécimale,
is_wellformed(float_literal, float_type)
- (C2) Si la notation hexadécimale
est utilisée,
size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Les constantes complexes représentent des valeurs complexes à l'aide de listes d'une partie réelle
(en premier) et une partie imaginaire (en deuxième). Par exemple :
(1.0, 0.0) : complex<f32>
représente 1.0 + 0.0i
, et
(0.0, 1.0) : complex<f32>
représente 0.0 + 1.0i
. L'ordre dans lequel ces
sont ensuite stockées en mémoire et définies par l'implémentation. Constantes complexes
présentent les contraintes suivantes:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Les constantes Tensor représentent les valeurs des Tensors au moyen de listes imbriquées spécifiées via
Numpy. Exemple : dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
représente une valeur de Tensor avec le mappage suivant entre les index et les éléments:
{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
{1, 2} => 6
L'ordre dans lequel ces éléments sont ensuite
stockés en mémoire est
définies par l'implémentation. Les constantes Tensor présentent les contraintes suivantes:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, où: <ph type="x-smartling-placeholder">- </ph>
has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, où: <ph type="x-smartling-placeholder">- </ph>
has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- sinon
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Les constantes de Tensors quantifiées représentent les valeurs de Tensor quantifiées à l'aide des mêmes en tant que constantes de Tensor, les éléments étant spécifiés comme des constantes de leur type de stockage. Les constantes de Tensor quantifiées présentent les contraintes suivantes:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Les littéraux de chaîne sont composés d'octets spécifiés à l'aide de caractères ASCII et
et d'échappement. Comme ils sont indépendants de l'encodage, leur interprétation
octets est défini par l'implémentation. Les littéraux de chaîne sont de type string
.
Opérations
abs
Sémantique
Effectue une opération des abscisses par élément sur le Tensor operand
et génère une result
.
Tensor. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les entiers signés: module entier.
- Pour les nombres à virgule flottante:
abs
d'IEEE-754. - Pour les nombres complexes: module complexe.
- Pour les types quantifiés:
dequantize_op_quantize(abs, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1-C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier signé ou à virgule flottante, ou Tensor quantifié par Tensor | (C1-C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
complex_element_type(element_type(operand))
siis_complex(operand)
.baseline_element_type(operand)
dans les autres cas.
Exemples
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Sémantique
Effectue l'addition par élément de deux Tensors lhs
et rhs
, et génère une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: opérateur logique OU.
- Pour les entiers: addition d'entiers.
- Pour les nombres à virgule flottante:
addition
d'IEEE-754. - Pour les nombres complexes: addition complexe.
- Pour les types quantifiés:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié | (C1-C6) |
(I2) | rhs |
Tensor ou Tensor quantifié | (C1-C5), (C7) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1-C7) |
Contraintes
- Si l'opération utilise des Tensors non quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Si l'opération utilise des Tensors quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Si
is_per_axis_quantized(lhs)
, alorsquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Si
is_per_axis_quantized(rhs)
, alorsquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Sémantique
Assurez-vous que les opérations à l'origine de l'inputs
sont exécutées avant toute
opérations qui dépendent de result
. L'exécution de cette opération n'a aucun effet,
il n'existe que pour établir les dépendances de données de result
à inputs
.
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre variadique de token |
Sorties
Nom | Type |
---|---|
result |
token |
Exemples
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, concatène les valeurs
des Tensors operands
de chaque processus le long de all_gather_dim
et produit
results
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est
définis comme suit:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
operands...@receiver = [operand@sender for sender in process_group]
pour tousreceiver
dansprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
pour tousprocess
dansprocess_group
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operands |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1), (C6) |
(I2) | all_gather_dim |
constante de type si64 |
(C1), (C6) |
(I3) | replica_groups |
Constante de Tensor bidimensionnelle de type si64 |
(C2-C4) |
(I4) | channel_id |
constante de type si64 |
(C5) |
(I5) | use_global_device_ids |
constante de type i1 |
(C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C6) |
Contraintes
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Si
use_global_device_ids = true
, alorschannel_id > 0
. - (C6)
type(results...) = type(operands...)
, sauf: <ph type="x-smartling-placeholder">- </ph>
dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, applique une réduction
fonction computation
sur les valeurs des Tensors operands
de chaque processus
et produit des Tensors results
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est
définis comme suit:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
results...@process[result_index] = exec(schedule)
pour une arborescence binaireschedule
où: <ph type="x-smartling-placeholder">- </ph>
exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
est une arborescence binaire définie par l'implémentation, dont l'ordre le balayage estto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operands |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C5), (C6) |
(I2) | replica_groups |
nombre variadique de constantes de Tensor unidimensionnelles de type si64 |
(C1-C3) |
(I3) | channel_id |
constante de type si64 |
(C4) |
(I4) | use_global_device_ids |
constante de type i1 |
(C4) |
(I5) | computation |
fonction | (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C6-C7) |
Contraintes
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Si
use_global_device_ids = true
, alorschannel_id > 0
. - (C5)
computation
est de type(tensor<E>, tensor<E>) -> (tensor<E>)
, oùis_promotable(element_type(operand), E)
- (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, divise les valeurs de
les Tensors operands
le long de split_dimension
en plusieurs parties, dispersent la division
entre les processus, concatène les parties dispersées
concat_dimension
et génère des Tensors results
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est
définis comme suit:
cross_replica(replica_groups)
sichannel_id <= 0
.cross_partition(replica_groups)
sichannel_id > 0
.
Ensuite, dans chaque process_group
:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
pour lessender
deprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
oùreceiver_index = process_group.index(receiver)
results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operands |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1-C3), (C9) |
(I2) | split_dimension |
constante de type si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
constante de type si64 |
(C3), (C9) |
(I4) | split_count |
constante de type si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
Constante de Tensor bidimensionnelle de type si64 |
(C5-C8) |
(I6) | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C9) |
Contraintes
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
sauf, sisplit_dimension != concat_dimension
: <ph type="x-smartling-placeholder">- </ph>
dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
et
Sémantique
Effectue l'opérateur AND par élément de deux Tensors lhs
et rhs
et génère une result
Tensor. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: l'opérateur logique AND.
- Pour les entiers: AND (ET) bit à bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type booléen ou entier | (C1) |
(I2) | rhs |
Tensor de type booléen ou entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Sémantique
Effectue des opérations atan2 au niveau des éléments sur les Tensors lhs
et rhs
, et génère une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
atan2
d'IEEE-754. - Pour les nombres complexes: complexe atan2.
- Pour les types quantifiés:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Sémantique
Calcule les gradients de plusieurs entrées de la rétropropagation de batch_norm_training
.
à partir de grad_output
, et génère grad_operand
, grad_scale
et grad_offset
Tensors. Plus formellement, cette opération peut être exprimée sous la forme d'une décomposition
Opérations StableHLO existantes à l'aide de la syntaxe Python, comme suit:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Pour les types quantifiés, exécute
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1-C3), (C5) |
(I2) | scale |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C2), (C4), (C5) |
(I3) | mean |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C2), (C4) |
(I4) | variance |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C2), (C4) |
(I5) | grad_output |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C2), (C3) |
(I6) | epsilon |
constante de type f32 |
|
(I7) | feature_index |
constante de type si64 |
(C1), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
grad_operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C2), (C3) |
grad_scale |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C2), (C4) |
grad_offset |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C2), (C4) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
grad_scale
etgrad_offset
ont le mêmebaseline_element_type
. - (C3)
operand
,grad_output
etgrad_operand
ont la même forme. - (C4)
scale
,mean
,variance
,grad_scale
etgrad_offset
ont même forme. - (C5)
size(scale) = dim(operand, feature_index)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Sémantique
Normalise le Tensor operand
pour toutes les dimensions, à l'exception du
feature_index
et génère un Tensor result
. Plus formellement, cela
peut être exprimée sous la forme d'une décomposition en opérations StableHLO existantes
à l'aide de la syntaxe Python suivante:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Pour les types quantifiés, exécute
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1-C7) |
(I2) | scale |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C2), (C3) |
(I3) | offset |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C2), (C4) |
(I4) | mean |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C5) |
(I5) | variance |
Tensor unidimensionnel de type à virgule flottante ou quantifié par Tensor | (C2), (C6) |
(I6) | epsilon |
constante de type f32 |
|
(I7) | feature_index |
constante de type si64 |
(C1), (C3-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C2), (C7) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
etresult
ont mêmebaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Sémantique
Calcule la moyenne et la variance pour toutes les dimensions, à l'exception de feature_index
et normalise le Tensor operand
générant output
, batch_mean
et batch_var
. Plus formellement, cette opération peut être exprimée sous la forme
décomposition en opérations StableHLO existantes à l'aide de la syntaxe Python en tant que
ce qui suit:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Pour les types quantifiés, exécute
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
(I2) | scale |
Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor | (C2), (C3) |
(I3) | offset |
Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor | (C2), (C4) |
(I4) | epsilon |
constante de type f32 |
(C1), (C3-C6) |
(I5) | feature_index |
constante de type si64 |
(C1), (C3-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C7) |
batch_mean |
Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor | (C2), (C5) |
batch_var |
Tensor unidimensionnel de valeurs à virgule flottante ou quantifié par Tensor | (C2), (C6) |
Contraintes
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
etoutput
ont le mêmebaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Exemples
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Sémantique
Effectue une opération de diffusion de bits sur le Tensor operand
et génère un Tensor result
.
où les bits de l'ensemble du Tensor operand
sont réinterprétés à l'aide du
Type de Tensor result
.
Plus formellement, étant donné que E = element_type(operand)
, E' = element_type(result)
,
et R = rank(operand)
:
- Si la valeur est
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
- Si la valeur est
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
- Si la valeur est
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
bits
renvoie la représentation en mémoire d'une valeur donnée et de son comportement.
est définie par l'implémentation, car la représentation exacte des Tensors est
définie par l'implémentation, et la représentation exacte des types d'éléments
définies par l'implémentation.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1-C2) |
Contraintes
- (C1) Avec
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
etR = rank(operand)
: <ph type="x-smartling-placeholder">- </ph>
- Si
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Si la valeur est
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
pour tous les0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Si la valeur est
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
pour tous les0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Si
- (C2) Si
is_complex(operand) or is_complex(result)
, alorsis_complex(operand) and is_complex(result)
Exemples
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Sémantique
Développe les dimensions et/ou le rang d'un Tensor d'entrée en dupliquant les données
dans le Tensor operand
et produit un Tensor result
. Plus formellement,
result[result_index] = operand[operand_index]
où pour les d
de
axes(operand)
:
operand_index[d] = 0
sidim(operand, d) = 1
.operand_index[d] = result_index[broadcast_dimensions[d]]
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1), (C3), (C5-C6) |
Contraintes
- (C1) La valeur
element_type(result)
est donnée par: <ph type="x-smartling-placeholder">- </ph>
element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, à l'exception dequantization_dimension(operand)
,scales(operand)
etzero_points(operand)
peuvent être différents dequantization_dimension(result)
,scales(result)
etzero_points(result)
resp., sinon.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Pour tous les
d
deaxes(operand)
: <ph type="x-smartling-placeholder">- </ph>
dim(operand, d) = 1
oudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si
is_per_axis_quantized(result)
: <ph type="x-smartling-placeholder">- </ph>
quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si la valeur est
dim(operand, quantization_dimension(operand)) = 1
, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
Exemples
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
coque
Sémantique
Génère le résultat de l'exécution d'une seule fonction de branches
en fonction de la valeur de index
. Plus formellement, result = selected_branch()
où:
selected_branch = branches[index]
si0 <= index < size(branches)
.selected_branch = branches[-1]
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | index |
Tensor à 0 dimensions de type si32 |
|
(I2) | branches |
nombre variadique de fonctions | (C1-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C4) |
Contraintes
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Exemples
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Sémantique
Effectue une opération racine cubique par élément sur le Tensor operand
et génère une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
rootn(x, 3)
d'IEEE-754. - Pour les nombres complexes: racine cubique complexe.
- Pour les types quantifiés:
dequantize_op_quantize(cbrt, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Sémantique
Effectue un ceil par élément du Tensor operand
et produit un Tensor result
.
Elle met en œuvre l'opération roundToIntegralTowardPositive
de la norme IEEE-754.
spécifique. Pour les types quantifiés, exécute
dequantize_op_quantize(ceil, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Sémantique
Calcule la décomposition par Cholesky d'un lot de matrices.
Plus formellement, pour tous les i
de index_space(result)
,
result[i0, ..., iR-3, :, :]
est une décomposition cholesky de
a[i0, ..., iR-3, :, :]
, qui se présente sous la forme d'un triangle inférieur
(si lower
est true
) ou une matrice triangulaire supérieure (si lower
est false
).
Les valeurs de sortie dans le triangle opposé, c'est-à-dire le triangle supérieur strict
le triangle inférieur strict, sont définies par l'implémentation.
S'il existe des valeurs i
où la matrice d'entrée n'est pas de type hermitien
, le comportement n'est pas défini.
Pour les types quantifiés, exécute
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1-C3) |
(I2) | lower |
Constante de Tensor à 0 dimensions de type i1 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Exemples
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
limiter
Sémantique
Bloque chaque élément du Tensor operand
entre une valeur minimale et une valeur maximale
et produit un Tensor result
. Plus formellement, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,
où min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
Pour les types quantifiés,
exécute dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Ordonner des nombres complexes implique une sémantique surprenante, C'est pourquoi nous prévoyons de ne plus prendre en charge les nombres complexes à l'avenir. pour cette opération (#560).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | min |
Tensor ou Tensor quantifié par Tensor | (C1), (C3) |
(I2) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C4) |
(I3) | max |
Tensor ou Tensor quantifié par Tensor | (C2), (C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C4) |
Contraintes
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, envoyez la valeur du
Tensor operand
du processus source aux processus cibles et produit une
Tensor result
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est
définis comme suit:
cross_replica(replica_groups)
sichannel_id <= 0
.cross_partition(replica_groups)
sichannel_id > 0
.
Ensuite, result@process
est donné par:
operand@process_groups[i, 0]
s'il existe uni
de sorte que le processus soit dans le pays suivant :process_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
sinon.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C3) |
(I2) | replica_groups |
nombre variadique de constantes de Tensor unidimensionnelles de type si64 |
(C1), (C2) |
(I3) | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C3) |
Contraintes
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, oùN
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C3)
type(result) = type(operand)
.
Exemples
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Sémantique
Dans chaque groupe de processus de la grille de processus StableHLO, envoie la valeur du
Tensor operand
du processus source au processus cible et produit une
Tensor result
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est
définis comme suit:
cross_replica(source_target_pairs)
sichannel_id <= 0
.cross_partition(source_target_pairs)
sichannel_id > 0
.
Ensuite, result@process
est donné par:
operand@process_groups[i, 0]
, s'il existe uni
tel queprocess_groups[i, 1] = process
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
sinon.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C5) |
(I2) | source_target_pairs |
Constante de Tensor bidimensionnelle de type si64 |
(C1-C4) |
(I3) | channel_id |
constante de type si64 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, oùN
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
num_replicas
sicross_replica
est utilisé.num_partitions
sicross_partition
est utilisé.
- (C5)
type(result) = type(operand)
.
Exemples
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Sémantique
Effectue une comparaison par élément des Tensors lhs
et rhs
en fonction
comparison_direction
et compare_type
, et génère un Tensor result
.
Les valeurs de comparison_direction
et compare_type
ont les valeurs suivantes :
sémantique:
Pour les éléments de type booléen et entier:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Pour les types d'éléments à virgule flottante avec compare_type = FLOAT
, l'opération implémente
les opérations IEEE-754 suivantes:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Pour les types d'éléments à virgule flottante avec compare_type = TOTALORDER
, l'opération
utilise la combinaison des opérations totalOrder
et compareQuietEqual
de
IEEE-754.
Pour les types d'éléments complexes, la comparaison lexicographique des paires (real, imag)
est
effectuées à l'aide des méthodes comparison_direction
et compare_type
fournies.
Ordonner des nombres complexes implique une sémantique surprenante,
C'est pourquoi nous prévoyons de ne plus prendre en charge les nombres complexes à l'avenir.
lorsque comparison_direction
est GE
, GT
, LE
ou LT
(560)
Pour les types quantifiés. exécute dequantize_compare(lhs, rhs,
comparison_direction)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1-C3) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1-C2) |
(I3) | comparison_direction |
énumération de EQ , NE , GE , GT , LE et LT |
|
(I4) | compare_type |
énumération de FLOAT , TOTALORDER , SIGNED et UNSIGNED |
(C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen | (C2) |
Contraintes
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
SIGNED
siis_signed_integer(element_type(lhs))
.UNSIGNED
siis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
ouTOTALORDER
siis_float(element_type(lhs))
.FLOAT
siis_complex(element_type(lhs))
.
Exemples
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complexe
Sémantique
Effectue une conversion au niveau des éléments en une valeur complexe à partir d'une paire de valeurs réelles et
des valeurs imaginaires, lhs
et rhs
, et produit un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type f32 ou f64 |
(C1-C3) |
(I2) | rhs |
Tensor de type f32 ou f64 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type complexe | (C2), (C3) |
Contraintes
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
est de typecomplex<E>
, oùE = element_type(lhs)
Exemples
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
composite
Sémantique
Encapsule une opération composée d'autres opérations StableHLO.
en prenant inputs
et composite_attributes
, et en produisant results
. La
de l'opération est implémentée par l'attribut decomposition
. La
L'opération composite
peut être remplacée par sa décomposition sans modifier le programme
la sémantique. Dans les cas où l'intégration de la décomposition ne fournit pas
utilisez plutôt custom_call
.
Le champ version
(par défaut, 0
) est utilisé pour indiquer quand un élément composite est
changement de sémantique.
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre variadique de valeurs |
(I2) | name |
constante de type string |
(I3) | composite_attributes |
dictionnaire d'attributs |
(I4) | decomposition |
constante de type string |
(I5) | version |
constante de type si32 |
Sorties
Nom | Type |
---|---|
results |
nombre variadique de valeurs |
Contraintes
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Exemples
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Sémantique
Concatène inputs
avec la dimension dimension
dans le même ordre que celui donné
et génère un Tensor result
. Plus formellement,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, où:
id = d0 + ... + dk-1 + kd
.d
est égal àdimension
, etd0
... correspond à lad
e taille de dimension surinputs
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1-C6) |
(I2) | dimension |
constante de type si64 |
(C2), (C4), (C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C5-C6) |
Contraintes
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
, saufdim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
, sauf: <ph type="x-smartling-placeholder">- </ph>
dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Exemples
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Sémantique
Génère un Tensor output
à partir d'une constante value
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | value |
constante | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor ou Tensor quantifié | (C1) |
Contraintes
- (C1)
type(value) = type(output)
.
Exemples
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
d'effectuer une conversion
Sémantique
Effectue une conversion au niveau des éléments d'un type d'élément à un autre sur
operand
et produit un Tensor result
.
Pour les conversions boolean-to-any-supported-type, la valeur false
est
convertie en zéro, et la valeur true
en un. Pour
any-supported-type-to-boolean, une valeur nulle est convertie en
false
et les valeurs non nulles sont converties en true
. Découvrez ci-dessous comment cela
fonctionnent pour des
types complexes.
Pour les conversions impliquant entier en entier, entier en virgule flottante ou floating-point-to-floating-point, si la valeur source peut être exactement représentée dans le type "destination", la valeur du résultat correspond exactement représentation. Sinon, le comportement reste à déterminer (#180)
Pour les conversions impliquant une conversion floating-point-to-integer, la partie fractionnaire est la suivante : tronquées. Si la valeur tronquée ne peut pas être représentée dans le type de destination, comportement à déterminer (#180).
Les conversions complexes à complexes suivent le même comportement : les conversions floating-point-to-floating-point pour convertir des valeurs réelles et des parties imaginaires.
Pour les conversions de type complex-to-any-other-type et complex-to-any-other-type : la valeur imaginaire source est ignorée, ou la valeur imaginaire de destination est mis à zéro, respectivement. La conversion de la partie réelle suit les conversions à virgule flottante.
En principe, cette opération peut exprimer une déquantification (conversion à partir de
Tensors quantifiés en Tensors standards), la quantification (conversion des Tensors
en Tensors quantifiés) et requantification (conversion entre des Tensors quantifiés)
mais pour le moment, nous avons des opérations dédiées.
uniform_dequantize
pour le premier cas d'utilisation et uniform_quantize
pour le
deuxième et troisième cas d'utilisation. À l'avenir, ces deux opérations pourront être fusionnées
dans convert
(#1576).
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor | (C1) |
Contraintes
- (C1)
shape(operand) = shape(result)
.
Exemples
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
Convolution
Sémantique
Calcule les produits scalaires entre les fenêtres de lhs
et des tranches de rhs
, puis produit
result
Le schéma suivant montre comment les éléments de result
sont calculés à partir de
lhs
et rhs
à l'aide d'un exemple concret.
Plus formellement, envisagez de recadrer les entrées ci-dessous en termes de lhs
.
pour pouvoir exprimer des fenêtres de lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
lhs_padding = lhs_shape([0, 0], padding, [0, 0])
lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Ce recadrement utilise les fonctions d'assistance suivantes:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
, oùj[d] = i[permutation[d]]
.
Si feature_group_count = 1
et batch_group_count = 1
, alors pour toutes
output_spatial_index
à index_space(dim(result, output_spatial_dimensions...))
,
result[result_shape(:, output_spatial_index, :)] = dot_product
, où:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
Il semble que cette fonctionnalité ne soit plus utilisée. Nous prévoyons donc de la supprimer à l'avenir. (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Si la valeur est feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
result = concatenate(results, output_feature_dimension)
.
Si la valeur est batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
Pour les types quantifiés, exécute dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Pour les types hybrides quantifiés, exécute hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
(I2) | rhs |
Tensor ou Tensor quantifié | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
(I3) | window_strides |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C3), (C25) |
(I4) | padding |
Constante de Tensor bidimensionnelle de type si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Constante de Tensor unidimensionnelle de type si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Constante de Tensor unidimensionnelle de type si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Constante de Tensor unidimensionnelle de type i1 |
(C9) |
(I8) | input_batch_dimension |
constante de type si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
constante de type si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C12), (C13), (C25) |
(I11). | kernel_input_feature_dimension |
constante de type si64 |
(C14), (C18) |
(I12). | kernel_output_feature_dimension |
constante de type si64 |
(C15-C16), (C18), (C25), (C29) |
(I13). | kernel_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C17-C18), (C25) |
(I14). | output_batch_dimension |
constante de type si64 |
(C20), (C25) |
(I15). | output_feature_dimension |
constante de type si64 |
(C20), (C25), (C30) |
(I16). | output_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
constante de type si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18). | batch_group_count |
constante de type si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19). | precision_config |
nombre variadique d'énumérations de DEFAULT , HIGH et HIGHEST |
(C24) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C25-C28), (C30), (C32-34) |
Contraintes
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Avec
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
: <ph type="x-smartling-placeholder">- </ph>
is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Avec
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
: <ph type="x-smartling-placeholder">- </ph>
is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Donnée
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
: <ph type="x-smartling-placeholder">- </ph>
is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
dim(lhs, input_batch_dimension) / batch_group_count
siresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
siresult_dim = output_feature_dimension
.num_windows
dans les autres cas, où:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
rhs_dim = kernel_spatial_dimensions[spatial_dim]
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si l'opération utilise des Tensors non quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si l'opération utilise des Tensors quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Si
is_per_axis_quantized(rhs)
, puis surquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Si
is_per_axis_quantized(result)
, alorsquantization_dimension(result) = output_feature_dimension
- Si la valeur est
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Si
is_per_tensor_quantized(rhs)
, alorsis_per_tensor_quantized(result)
- Si la valeur est
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Exemples
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
cosinus
Sémantique
Effectue une opération cosinus par élément sur le Tensor operand
et génère une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
cos
d'IEEE-754. - Pour les nombres complexes: cosinus complexe.
- Pour les types quantifiés:
dequantize_op_quantize(cosine, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Sémantique
Effectue un comptage par élément du nombre de bits zéro au début dans operand
.
et produit un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Sémantique
Encapsule une opération call_target_name
définie par l'implémentation qui accepte
inputs
et called_computations
, et génère results
. has_side_effect
,
backend_config
et api_version
peuvent être utilisés pour fournir
définies par l'implémentation.
Pour le moment, cette opération contient une collection assez désorganisée de de métadonnées qui reflètent l'évolution organique de son fonctionnement équivalent dans le compilateur XLA. À l'avenir, nous prévoyons d'unifier ces métadonnées (#741)
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre variadique de valeurs |
(I2) | call_target_name |
constante de type string |
(I3) | has_side_effect |
constante de type i1 |
(I4) | backend_config |
constante de type string ou dictionnaire d'attributs |
(I5) | api_version |
constante de type si32 |
(I6) | called_computations |
nombre variadique de constantes de type string |
Sorties
Nom | Type |
---|---|
results |
nombre variadique de valeurs |
Exemples
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
diviser
Sémantique
Effectue une division par élément des Tensors lhs
et diviseur rhs
.
et génère un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les entiers: division des entiers, qui produit le quotient algébrique avec n'importe lequel partie fractionnaire supprimée.
- Pour les nombres à virgule flottante:
division
d'IEEE-754. - Pour les nombres complexes: division complexe.
- Pour les types quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Sémantique
Calcule les produits scalaires entre les tranches de lhs
et les tranches de rhs
, et génère une
Tensor result
.
Plus formellement, result[result_index] = dot_product
, où:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
oùsize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
etsize(result_rhs_index) = size(rhs_result_dimensions)
transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
Pour les types quantifiés, exécute dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Pour les types hybrides quantifiés, exécute hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
contrôle le compromis entre vitesse et précision pour
sur les backends d'accélérateur. Il peut s'agir de l'un des éléments suivants (au niveau
Dans un premier temps, la sémantique de ces valeurs d'énumération est sous-spécifiée, mais nous n'avons
prévoient de résoudre
ce problème dans
#755):
DEFAULT
: calcul le plus rapide, mais approximation la moins précise du numéro d'origine.HIGH
: calcul plus lent, mais approximation plus précise du numéro d'origine.HIGHEST
: calcul le plus lent, mais approximation la plus précise du numéro d'origine.
Un DotAlgorithm
définit les principales propriétés de l'algorithme utilisé pour implémenter
l'opération par point, qui définit également la précision. Si l'attribut de l'algorithme
sont définis, l'élément precision_config
doit être défini sur DEFAULT
. DotAlgorithms
n'ont pas de valeur par défaut, car les paramètres par défaut sont
définis. Par conséquent, tous les champs de l'algorithme peuvent être définis sur None
pour spécifier une
algorithme vide avec points, qui utilisera plutôt la valeur precision_config
.
Les champs DotAlgorithm
incluent:
lhs_precision_type
etrhs_precision_type
, la précision des fonctions LHS et à droite de l'opération sont arrondies. Les types de précision sont indépendants des de stockage des entrées et des sorties.accumulation_type
est la précision utilisée pour l'accumulation.lhs_component_count
,rhs_component_count
etnum_primitive_operations
s'appliquent lorsque nous faisons un algorithme qui décompose le LHS et/ou le RHS en plusieurs composants et effectue plusieurs tâches sur ces valeurs : généralement pour émuler une précision plus élevée (par exemple, Exploiter le type de données bfloat16 de l'intelligence artificielle pour effectuer des calculs de plus grande précision: bf16_6x, tf32_3x, etc.). Pour les algorithmes sans décomposition, ces valeurs doit être défini sur1
.allow_imprecise_accumulation
pour spécifier si l'accumulation dans une précision inférieure est autorisé pour certaines étapes (par exemple,CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Exemples d'attributs DotAlgorithm
:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
C'est aux implémentations de décider quelles combinaisons sont acceptées. Dans général, il n'est pas garanti que chaque algorithme soit compatible type d'accélérateur par le consommateur de StableHLO. Si un algorithme donné n'est pas une erreur doit être générée, plutôt que de revenir à une alternative. La vérification StableHLO s'efforcera au mieux, empêchant ainsi les algorithmes qui ne sont pas compatibles avec aucun matériel.
Voir xla_data.proto > Algorithm
pour certaines valeurs d'algorithme acceptées. Le ticket n° 2483 capture le plan pour créer un
document centralisé sur les algorithmes pris
en charge par le backend.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
Tensor ou Tensor quantifié | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
nombre variadique d'énumérations de DEFAULT , HIGH et HIGHEST |
(C11), (C21) |
(I8) | lhs_precision_type |
FloatType ou TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType ou TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType ou TensorFloat32 | (C21) |
(I11). | lhs_component_count |
constante de type si32 |
(C21), (C22) |
(I12). | rhs_component_count |
constante de type si32 |
(C21), (C23) |
(I13). | num_primitive_operations |
constante de type si32 |
(C21), (C24) |
(I14). | allow_imprecise_accumulation |
constante de type bool |
(C21) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C12), (C14), (C18-C20) |
Contraintes
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Si l'opération utilise des Tensors non quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Si l'opération utilise des Tensors quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Si
is_per_axis_quantized(rhs)
, alorsquantization_dimension(rhs)
pas dansrhs_contracting_dimensions
. - Si la valeur est
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Si
is_per_tensor_quantized(rhs)
, alorsis_per_tensor_quantized(result)
- Si la valeur est
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Si la valeur est
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
: <ph type="x-smartling-placeholder">- </ph>
- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Exemples
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Sémantique
Cette opération est fonctionnellement identique à
broadcast_in_dim
op, mais la forme du résultat est spécifiée dynamiquement via output_dimensions
.
L'opération accepte également les attributs facultatifs known_expanding_dimensions
et known_non_expanding_dimensions
.
pour exprimer des connaissances statiques sur le comportement d'expansion des dimensions.
Si aucune valeur n'est spécifiée, toutes les dimensions sont supposées pouvoir s'étendre.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Tensor unidimensionnel de type entier | (C7) |
(I3) | broadcast_dimensions |
Tensor constant unidimensionnel de type entier | (C2-C6) |
(I4) | known_expanding_dimensions |
Tensor constant unidimensionnel de type entier | (C8-C9) |
(I5) | known_non_expanding_dimensions |
Tensor constant unidimensionnel de type entier | (C8-C9) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1), (C3), (C5-C7) |
Contraintes
- (C1) La valeur
element_type(result)
est donnée par: <ph type="x-smartling-placeholder">- </ph>
element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, à l'exception dequantization_dimension(operand)
,scales(operand)
etzero_points(operand)
peuvent être différents dequantization_dimension(result)
,scales(result)
etzero_points(result)
resp., sinon.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Pour tous les
d
deaxes(operand)
: <ph type="x-smartling-placeholder">- </ph>
dim(operand, d) = 1
oudim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si
is_per_axis_quantized(result)
: <ph type="x-smartling-placeholder">- </ph>
quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si la valeur est
dim(operand, quantization_dimension(operand)) = 1
, alorsscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_non_expanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_non_expanding_dimensions < rank(operand)
.
Exemples
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Sémantique
Cette opération est fonctionnellement identique à
Convolution
op, mais la marge intérieure est spécifiée dynamiquement via padding
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
(I2) | rhs |
Tensor ou Tensor quantifié | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
Tensor bidimensionnel de type entier | (C4) |
(I4) | window_strides |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C3) |
(I5) | lhs_dilation |
Constante de Tensor unidimensionnelle de type si64 |
(C5-C6) |
(I6) | rhs_dilation |
Constante de Tensor unidimensionnelle de type si64 |
(C7-C8) |
(I7) | window_reversal |
Constante de Tensor unidimensionnelle de type i1 |
(C9) |
(I8) | input_batch_dimension |
constante de type si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
constante de type si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C12), (C13) |
(I11). | kernel_input_feature_dimension |
constante de type si64 |
(C14), (C18) |
(I12). | kernel_output_feature_dimension |
constante de type si64 |
(C15-C16), (C18), (C28) |
(I13). | kernel_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C17-C18) |
(I14). | output_batch_dimension |
constante de type si64 |
(C20) |
(I15). | output_feature_dimension |
constante de type si64 |
(C20), (C29) |
(I16). | output_spatial_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C19-C20) |
(I17) | feature_group_count |
constante de type si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18). | batch_group_count |
constante de type si64 |
(C10), (C15), (C22), (C23) |
(I19). | precision_config |
nombre variadique d'énumérations de DEFAULT , HIGH et HIGHEST |
(C24) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C25-C27), (C29), (C31-C33) |
Contraintes
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Avec
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
: <ph type="x-smartling-placeholder">- </ph>
is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Avec
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
: <ph type="x-smartling-placeholder">- </ph>
is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Donnée
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
: <ph type="x-smartling-placeholder">- </ph>
is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
dim(lhs, input_batch_dimension) / batch_group_count
siresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
siresult_dim = output_feature_dimension
.num_windows
dans les autres cas, où:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
rhs_dim = kernel_spatial_dimensions[spatial_dim]
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si l'opération utilise des Tensors non quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si l'opération utilise des Tensors quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Si
is_per_axis_quantized(rhs)
, puis surquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Si
is_per_axis_quantized(result)
, alorsquantization_dimension(result) = output_feature_dimension
- Si la valeur est
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Si
is_per_tensor_quantized(rhs)
, alorsis_per_tensor_quantized(result)
- Si la valeur est
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Exemples
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Sémantique
Cette opération est fonctionnellement identique à
rassembler
op, avec le slice_sizes
spécifié de manière dynamique en tant que valeur.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
Tensor de type entier | (C2), (C3), (C13) |
(I3) | slice_sizes |
Tensor unidimensionnel de type entier | (C8), (C11-C13) |
(I4) | offset_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Constante de Tensor unidimensionnelle de type si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
constante de type si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
constante de type i1 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C5), (C13-C14) |
Contraintes
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
où: <ph type="x-smartling-placeholder">- </ph>
batch_dim_sizes = shape(start_indices)
, sauf que la dimension destart_indices
correspondant àindex_vector_dim
n'est pas inclus.offset_dim_sizes = shape(slice_sizes)
, sauf que les dimensions dansslice_sizes
correspondant àcollapsed_slice_dims
ne sont pas incluses.combine
placebatch_dim_sizes
sur les axes correspondant àbatch_dims
etoffset_dim_sizes
sur les axes correspondant àoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Exemples
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Sémantique
Cette opération est fonctionnellement identique à
iota
op, mais la forme du résultat est spécifiée dynamiquement via output_shape
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | output_shape |
Tensor unidimensionnel de type entier | (C1), (C2) |
(I2) | iota_dimension |
si64 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C2) |
Contraintes
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Exemples
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Sémantique
Cette opération est fonctionnellement identique à
pad
op, mais avec edge_padding_low
, edge_padding_high
et interior_padding
spécifiées dynamiquement en tant que valeurs.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor à 0 dimensions ou Tensor quantifié par Tensor | (C1) |
(I3) | edge_padding_low |
Tensor unidimensionnel de type entier | (C1), (C4) |
(I4) | edge_padding_high |
Tensor unidimensionnel de type entier | (C1), (C4) |
(I5) | interior_padding |
Tensor unidimensionnel de type entier | (C2-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C3-C6) |
Contraintes
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Exemples
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Sémantique
Cette opération est fonctionnellement identique à
remodeler
op, mais la forme du résultat est spécifiée dynamiquement via output_shape
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C3) |
(I2) | output_shape |
Tensor unidimensionnel de type entier | (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1-C4) |
Contraintes
- (C1) La valeur
element_type(result)
est donnée par: <ph type="x-smartling-placeholder">- </ph>
element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, à l'exception dequantization_dimension(operand)
et Sinon,quantization_dimension(result)
peut être différent.
- (C2)
size(operand) = size(result)
. - (C3) Si
is_per_axis_quantized(operand)
: <ph type="x-smartling-placeholder">- </ph>
reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Sémantique
Extrait une tranche de operand
à l'aide d'index de départ calculés dynamiquement.
et génère un Tensor result
. start_indices
contiennent les index de départ de
la part de chaque dimension susceptible d'être ajustée, et slice_sizes
contenir les tailles du secteur pour chaque dimension. Plus formellement,
result[result_index] = operand[operand_index]
où:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C2), (C4) |
(I2) | start_indices |
nombre variadique de Tensors à 0 dimensions de type entier | (C2), (C3) |
(I3) | slice_sizes |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1), (C5) |
Contraintes
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Exemples
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Sémantique
Génère un Tensor result
égal au Tensor operand
, sauf que
la tranche commençant à start_indices
est mise à jour avec les valeurs de update
.
Plus formellement, result[result_index]
est défini comme suit:
update[update_index]
si0 <= update_index < shape(update)
où: <ph type="x-smartling-placeholder">- </ph>
adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
operand[result_index]
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C4), (C6) |
(I2) | update |
Tensor ou Tensor quantifié par Tensor | (C2), (C3), (C6) |
(I3) | start_indices |
nombre variadique de Tensors à 0 dimensions de type entier | (C4), (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Exemples
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponentiel
Sémantique
Effectue une opération exponentielle au niveau des éléments sur le Tensor operand
et génère une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
exp
d'IEEE-754. - Pour les nombres complexes: exponentiel complexe
- Pour les types quantifiés:
dequantize_op_quantize(exponential, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Sémantique
Effectue une opération exponentielle par élément moins une sur le Tensor operand
et
et génère un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
expm1
d'IEEE-754. - Pour les nombres complexes: exponentiel complexe moins un.
- Pour les types quantifiés:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Sémantique
Effectue les transformations de Fourier avant et inverses pour des valeurs réelles et complexes d'entrées et de sorties.
fft_type
est l'un des éléments suivants :
FFT
: transfert de données FFT complexe à complexe.IFFT
: FFT complexe à complexe inverse.RFFT
: transfert FFT réel vers complexe.IRFFT
: FFT réel-complexe inverse (il s'agit d'une méthode complexe qui renvoie un résultat réel).
Plus formellement, étant donné la fonction fft
, qui accepte des Tensors unidimensionnels de
les types complexes en entrée, produit des Tensors unidimensionnels des mêmes types que
puis calcule la transformation de Fourier discrète:
Pour fft_type = FFT
, result
est défini comme le résultat final d'une série de L
où L = size(fft_length)
. Par exemple, pour L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
De plus, étant donné la fonction ifft
, qui a la même signature de type et
calcule l'inverse de fft
:
Pour fft_type = IFFT
, result
est défini comme l'inverse des calculs.
pour fft_type = FFT
. Par exemple, pour L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
result[i0, ..., :] = ifft(result2[i0, ..., :])
.
De plus, étant donné la fonction rfft
, qui accepte des Tensors unidimensionnels de
types à virgule flottante produit des Tensors unidimensionnels de types complexes des
la même sémantique à virgule flottante et fonctionne comme suit:
rfft(real_operand) = truncated_result
oùcomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Lorsque la transformation de Fourier discrète est calculée pour des opérandes réels, le premier
Les éléments N/2 + 1
du résultat définissent sans ambiguïté le reste du résultat,
Le résultat de rfft
est donc tronqué pour éviter le calcul d'éléments redondants).
Pour fft_type = RFFT
, result
est défini comme le résultat final d'une série de L
où L = size(fft_length)
. Par exemple, pour L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Enfin, pour la fonction irfft
, qui a le même type de signature et
calcule l'inverse de rfft
:
Pour fft_type = IRFFT
, result
est défini comme l'inverse des calculs.
pour fft_type = RFFT
. Par exemple, pour L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type complexe ou à virgule flottante | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
énumération de FFT , IFFT , RFFT et IRFFT |
(C2), (C5) |
(I3) | fft_length |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C3), (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type complexe ou à virgule flottante | (C2), (C4), (C5) |
Contraintes
- (C1)
size(fft_length) <= rank(operand)
. - (C2) La relation entre les types d'éléments
operand
etresult
varie: <ph type="x-smartling-placeholder">- </ph>
- Si
fft_type = FFT
,element_type(operand)
etelement_type(result)
sont du même type complexe. - Si
fft_type = IFFT
,element_type(operand)
etelement_type(result)
sont du même type complexe. - Si la valeur est
fft_type = RFFT
,element_type(operand)
est un type à virgule flottante etelement_type(result)
est un type complexe de la même valeur à virgule flottante la sémantique. - Si la valeur est
fft_type = IRFFT
,element_type(operand)
est un type complexe etelement_type(result)
est un type à virgule flottante de la même valeur la sémantique.
- Si
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Si
operand
etresult
comportent le Tensorreal
d'un type à virgule flottante, puisshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
, sauf: <ph type="x-smartling-placeholder">- </ph>
- Si la valeur est
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
- Si la valeur est
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
- Si la valeur est
Exemples
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
sol
Sémantique
Effectue un prix plancher par élément du Tensor operand
et produit un Tensor result
.
Elle met en œuvre l'opération roundToIntegralTowardNegative
de la norme IEEE-754.
spécifique. Pour les types quantifiés, exécute
dequantize_op_quantize(floor, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
rassembler
Sémantique
Regroupe les tranches du Tensor operand
à partir des décalages spécifiés dans start_indices
.
et génère un Tensor result
.
Le schéma suivant montre comment les éléments de result
sont mappés sur les éléments de
operand
à l'aide d'un exemple concret. Le schéma sélectionne quelques exemples de result
index et explique en détail les indices operand
auxquels ils correspondent.
Plus formellement, result[result_index] = operand[operand_index]
, où:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
start_indices[bi0, ..., :, ..., biN]
oùbi
correspond à des éléments individuels dansbatch_index
et:
sont insérés à l'indexindex_vector_dim
, siindex_vector_dim
<rank(start_indices)
[start_indices[batch_index]]
dans les autres cas.
- Pour
d_operand
dansaxes(operand)
, <ph type="x-smartling-placeholder">- </ph>
full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
sid_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
dans les autres cas.
- Pour
d_operand
dansaxes(operand)
, <ph type="x-smartling-placeholder">- </ph>
full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sid_operand = operand_batching_dims[i_batching]
etd_start = start_indices_batching_dims[i_batching]
full_batching_index[d_operand] = 0
dans les autres cas.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
oùoi
sont des individus dansoffset_index
, et0
est inséré au niveau des index decollapsed_slice_dims
etoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
Si indices_are_sorted
est défini sur true
, l'implémentation peut supposer que
Les éléments start_indices
sont triés par rapport à start_index_map
. Dans le cas contraire, les valeurs
ce comportement n'est pas défini. Plus formellement, pour toutes les i1 < i2
de indices(result)
,
full_start_index(i1) <= full_start_index(i2)
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
Tensor de type entier | (C2-C3), (C14), (C17), (C22) |
(I3) | offset_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C13-C17) |
(I7) | start_index_map |
Constante de Tensor unidimensionnelle de type si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
constante de type si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
Constante de Tensor unidimensionnelle de type si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
constante de type i1 |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C5), (C22-C23) |
Contraintes
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
où: <ph type="x-smartling-placeholder">- </ph>
batch_dim_sizes = shape(start_indices)
, sauf que la dimension destart_indices
correspondant àindex_vector_dim
n'est pas inclus.offset_dim_sizes = slice_sizes
, sauf que les dimensions deslice_sizes
correspondant àcollapsed_slice_dims
et Lesoperand_batching_dims
ne sont pas inclus.combine
placebatch_dim_sizes
sur les axes correspondant àbatch_dims
etoffset_dim_sizes
sur les axes correspondant àoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Exemples
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Sémantique
Génère la taille de l'élément dimension
donné de l'élément operand
. Plus formellement,
result = dim(operand, dimension)
La sémantique ne concerne que la forme
composant du type. Le type d'élément peut être n'importe quel élément.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1) |
(I2) | dimension |
constante de type si64 |
(C1) |
Sorties
Nom | Type |
---|---|
result |
Tensor à 0 dimensions de type si32 |
Contraintes
- (C1)
0 <= dimension < rank(operand)
.
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
<ph type="x-smartling-placeholder"></ph>
Sémantique
Extrait l'élément à la position index
du tuple operand
et génère une
result
Plus formellement, result = operand[index]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
tuple | (C1), (C2) |
(I2) | index |
constante de type si32 |
(C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tout type compatible | (C2) |
Contraintes
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Exemples
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
si
Sémantique
Génère le résultat de l'exécution d'une seule fonction à partir de true_branch
ou
false_branch
en fonction de la valeur de pred
. Plus formellement, result =
pred ? true_branch() : false_branch()
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | pred |
Tensor à 0 dimensions de type i1 |
|
(I2) | true_branch |
fonction | (C1-C3) |
(I3) | false_branch |
fonction | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C3) |
Contraintes
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Exemples
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
image
Sémantique
Extrait la partie imaginaire, élément par élément, de operand
et génère une
Tensor result
. Plus formellement, pour chaque élément x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type complexe ou à virgule flottante | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
complex_element_type(element_type(operand))
siis_complex(operand)
.element_type(operand)
dans les autres cas.
Exemples
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
flux d'alimentation
Sémantique
Lit les données du flux d'entrée et génère results
.
La sémantique de infeed_config
est définie par l'implémentation.
results
est constitué de valeurs de charge utile qui apparaissent en premier et d'un jeton qui vient
en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux
des sorties distinctes pour plus de clarté
(#670)
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
constante de type string |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C1-C3) |
Contraintes
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ouis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Exemples
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
IoTa
Sémantique
Remplit un Tensor output
avec des valeurs par ordre croissant à partir de zéro.
avec la dimension iota_dimension
. Plus formellement,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
0 <= iota_dimension < rank(output)
.
Exemples
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Sémantique
Effectue une vérification au niveau des éléments si la valeur de x
est finie (c'est-à-dire si ni
+Inf, -Inf, or NaN) et génère un Tensor y
. Implémentation de isFinite
de la spécification IEEE-754. Pour les types quantifiés, le résultat est
toujours true
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | x |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
y |
Tensor de type booléen | (C1) |
Contraintes
- (C1)
shape(x) = shape(y)
.
Exemples
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Sémantique
Effectue une opération logarithme par élément sur le Tensor operand
et produit une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
log
d'IEEE-754. - Pour les nombres complexes: logarithme complexe.
- Pour les types quantifiés:
dequantize_op_quantize(log, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Sémantique
Effectue un logarithme par élément plus une opération sur le Tensor operand
.
et génère un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
logp1
d'IEEE-754. - Pour les nombres complexes: logarithme complexe plus un.
- Pour les types quantifiés:
dequantize_op_quantize(log_plus_one, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistique
Sémantique
Effectue des opérations logistiques par élément sur le Tensor operand
et génère une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
division(1, addition(1, exp(-x)))
d'IEEE-754. - Pour les nombres complexes: logistique complexe.
- Pour les types quantifiés:
dequantize_op_quantize(logistic, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
carte
<ph type="x-smartling-placeholder"></ph>
Sémantique
Applique une fonction de carte computation
à inputs
avec dimensions
et
génère un Tensor result
.
Plus formellement, result[result_index] = computation(inputs...[result_index])
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1-C4) |
(I2) | dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C3) |
(I3) | computation |
fonction | (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1), (C4) |
Contraintes
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
est de type(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
oùEi = element_type(inputs[i])
etE' = element_type(result)
.
Exemples
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maximum
Sémantique
Effectue une opération maximale au niveau des éléments sur les Tensors lhs
et rhs
et génère une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: opérateur logique OU.
- Pour les entiers: entier maximum.
- Pour les nombres à virgule flottante:
maximum
d'IEEE-754. - Pour les nombres complexes: maximum lexicographique pour la paire
(real, imaginary)
. Ordonner des nombres complexes implique une sémantique surprenante, C'est pourquoi nous prévoyons de ne plus prendre en charge les nombres complexes à l'avenir. pour cette opération (#560). - Pour les types quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Sémantique
Effectue une opération minimale par élément sur les Tensors lhs
et rhs
et génère une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: l'opérateur logique AND.
- Pour les entiers: entier minimum.
- Pour les nombres à virgule flottante:
minimum
d'IEEE-754. - Pour les nombres complexes: minimum lexicographique pour la paire
(real, imaginary)
. Ordonner des nombres complexes implique une sémantique surprenante, C'est pourquoi nous prévoyons de ne plus prendre en charge les nombres complexes à l'avenir. pour cette opération (#560). - Pour les types quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
multiplier
Sémantique
Effectue le produit élément par élément de deux Tensors lhs
et rhs
, et produit une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: l'opérateur logique AND.
- Pour les entiers: multiplication d'entiers.
- Pour les nombres à virgule flottante:
multiplication
d'IEEE-754. - Pour les nombres complexes: multiplication complexe.
- Pour les types quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
Sémantique
Effectue une négation par élément du Tensor operand
et produit une result
Tensor. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les entiers signés: négation des entiers.
- Pour les entiers non signés: conversion de bits en entier signé, négation d'entier, conversion de bits à un entier non signé.
- Pour les nombres à virgule flottante:
negate
d'IEEE-754. - Pour les nombres complexes: négation complexe.
- Pour les types quantifiés:
dequantize_op_quantize(negate, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
not
Sémantique
Effectue l'opérateur NOT au niveau des éléments du Tensor operand
et génère un Tensor result
.
En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: l'opérateur logique NOT.
- Pour les entiers: l'opérateur NOT au niveau du bit.
Arguments
Nom | Type | Contraintes |
---|---|---|
operand |
Tensor de type booléen ou entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Sémantique
Assurez-vous que les opérations qui produisent l'operand
sont exécutées avant toute
opérations qui dépendent de result
et empêchent les transformations de compilation
de déplacer les opérations au-delà de la barrière. En dehors de cela, l'opération est
une identité, par exemple result = operand
.
Arguments
Nom | Type | Contraintes |
---|---|---|
operand |
nombre variadique de Tensors, Tensors ou jetons quantifiés par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
nombre variadique de Tensors, Tensors ou jetons quantifiés par Tensor | (C1) |
Contraintes
- (C1)
type(operand...) = type(result...)
.
Exemples
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
ou
Sémantique
Effectue l'opération OR par élément de deux Tensors lhs
et rhs
et génère une result
Tensor. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: opérateur logique OU.
- Pour les entiers: OR bit à bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier ou booléen | (C1) |
(I2) | rhs |
Tensor de type entier ou booléen | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier ou booléen | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
sortie
Sémantique
Écrit inputs
dans le flux de sortie et génère un jeton result
.
La sémantique de outfeed_config
est définie par l'implémentation.
Entrées
Libellé | Nom | Type |
---|---|---|
(I1) | inputs |
nombre variadique de Tensors ou de Tensors quantifiés |
(I2) | token |
token |
(I3) | outfeed_config |
constante de type string |
Sorties
Nom | Type |
---|---|
result |
token |
Exemples
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
clavier
Sémantique
Développe operand
avec une marge intérieure autour du Tensor et entre les éléments.
du Tensor avec l'élément padding_value
donné.
edge_padding_low
et edge_padding_high
spécifient la quantité de marge intérieure ajoutée.
à la valeur basse (à côté de l'indice 0) et à la valeur élevée (à côté de l'indice le plus élevé)
chaque dimension respectivement. La quantité de marge intérieure peut être négative,
la valeur absolue de la marge intérieure négative indique le nombre d'éléments à supprimer
de la dimension spécifiée.
interior_padding
spécifie la quantité de marge intérieure ajoutée entre deux
dans chaque dimension, qui ne peuvent pas être négatifs. Une marge intérieure intérieure
avant le remplissage des bords, de sorte que le remplissage négatif du bord supprime les éléments de
l'opérande avec remplissage en intérieur.
Plus formellement, result[result_index]
est défini comme suit:
operand[operand_index]
siresult_index = edge_padding_low + operand_index * (interior_padding + 1)
padding_value
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor à 0 dimensions ou Tensor quantifié par Tensor | (C1) |
(I3) | edge_padding_low |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Constante de Tensor unidimensionnelle de type si64 |
(C1), (C4) |
(I5) | interior_padding |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C3-C6) |
Contraintes
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Exemples
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Sémantique
Génère partition_id
du processus actuel.
Sorties
Nom | Type |
---|---|
result |
Tensor à 0 dimensions de type ui32 |
Exemples
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Sémantique
Effectue un décompte par élément du nombre de bits défini dans le Tensor operand
.
et génère un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(operand) = type(result)
.
Exemples
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
puissance
Sémantique
Effectue une interpolation par élément du Tensor lhs
en fonction du Tensor rhs
et
et génère un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les entiers: exponentielle d'entiers.
- Pour les nombres à virgule flottante:
pow
d'IEEE-754. - Pour les nombres complexes: exponentielle complexe.
- Pour les types quantifiés:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Sémantique
Elle extrait la partie réelle, au niveau des éléments, de operand
et génère une result
.
Tensor. Plus formellement, pour chaque élément x
:
real(x) = is_complex(x) ? real_part(x) : x
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type complexe ou à virgule flottante | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
complex_element_type(element_type(operand))
siis_complex(operand)
.element_type(operand)
dans les autres cas.
Exemples
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
réception
Sémantique
Reçoit les données d'un canal avec channel_id
et génère results
.
Si is_host_transfer
est défini sur true
, l'opération transfère les données du
hôte. Sinon, il transfère les données d'un autre appareil. Cela signifie que
définies par l'implémentation. Cet indicateur duplique
les informations fournies dans
channel_type
. Nous prévoyons donc de n'en conserver qu'un seul à l'avenir.
(#666)
results
est constitué de valeurs de charge utile qui apparaissent en premier et d'un jeton qui vient
en dernier. À l'avenir, nous prévoyons de diviser la charge utile et le jeton en deux
des sorties distinctes pour plus de clarté
(#670)
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
constante de type si64 |
|
(I3) | channel_type |
énumération de DEVICE_TO_DEVICE et HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
constante de type i1 |
(C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C2-C4) |
Contraintes
- (C1)
channel_type
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
HOST_TO_DEVICE
siis_host_transfer = true
,DEVICE_TO_DEVICE
dans les autres cas.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ouis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Exemples
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Sémantique
Applique une fonction de réduction body
à inputs
et init_values
le long de
dimensions
et produit des Tensors results
.
L'ordre des réductions est défini par l'implémentation, ce qui signifie que body
et
init_values
doit former un monoid pour garantir que l'opération produit le
les mêmes résultats pour toutes les entrées
dans toutes les implémentations. Cependant, cette condition
n'est pas valable pour de nombreuses réductions populaires. Exemple : l'addition à virgule flottante pour
body
et zéro pour init_values
ne forment pas un monooïde, car
l'addition à virgule flottante n'est pas associative.
Plus formellement, results...[j0, ..., jR-1] = reduce(input_slices_converted)
, où:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, où:
sont insérés àdimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
pour une arborescence binaireschedule
où: <ph type="x-smartling-placeholder">- </ph>
exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
est une arborescence binaire complète définie par l'implémentation, dont l'ordre le balayage comprend: <ph type="x-smartling-placeholder">- </ph>
input_slices_converted...[index]
valeurs, pour tous lesindex
deindex_space(input_slices_converted)
dans l'ordre lexicographique croissant surindex
.- Entrepôt de quantité définie par l'implémentation
init_values_converted
aux positions définies par l'implémentation
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1-C4), (C6), (C7) |
(I2) | init_values |
nombre variadique de Tensors à 0 dimensions ou de Tensors quantifiés par Tensor | (C2), (C3) |
(I3) | dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C4), (C5), (C7) |
(I4) | body |
fonction | (C6) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C3), (C7), (C8) |
Contraintes
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
oùis_promotable(element_type(inputs[i]), Ei)
- (C7)
shape(results...) = shape(inputs...)
, à la différence près que la dimension Les tailles deinputs...
correspondant àdimensions
ne sont pas incluses. - (C8)
element_type(results[i]) = Ei
pour tous lesi
de[0,N)
.
Exemples
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Sémantique
Effectue une conversion de operand
au niveau des éléments en un autre type à virgule flottante
qui utilise exponent_bits
et mantissa_bits
, puis revient à la version d'origine
à virgule flottante et génère un Tensor output
.
Plus formellement:
- Les bits de mantisse de la valeur d'origine sont mis à jour pour arrondir la valeur d'origine.
à la valeur la plus proche représentable par
mantissa_bits
en utilisantroundToIntegralTiesToEven
. - Ensuite, si les valeurs
mantissa_bits
sont inférieures au nombre de bits de mantisse de la valeur d'origine, les bits de mantisse sont tronqués àmantissa_bits
. - Ensuite, si les bits d'exposant du résultat intermédiaire ne correspondent pas au
plage fournie par
exponent_bits
, le résultat intermédiaire déborde sur l'infini à l'aide du signe d'origine ou le dépassement de zéro à zéro à l'aide de la panneau d'origine. - Pour les types quantifiés, exécute
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
(I2) | exponent_bits |
constante de type si32 |
(C2) |
(I3) | mantissa_bits |
constante de type si32 |
(C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Exemples
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Sémantique
Dans chaque groupe de processus de la grille StableHLO, effectue une réduction,
en utilisant computations
sur les valeurs du Tensor operand
de chaque processus,
divise le résultat de la réduction avec scatter_dimension
en plusieurs parties, puis disperse le résultat
les parties divisées entre les processus pour produire le result
.
L'opération divise la grille de processus StableHLO en process_groups
, qui est
définis comme suit:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Ensuite, dans chaque process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
pour lessender
deprocess_group
, oùreceiver_index = process_group.index(receiver)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
constante de type si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
Constante de Tensor bidimensionnelle de type si64 |
(C3-C5) |
(I4) | channel_id |
constante de type si64 |
(C6) |
(I5) | use_global_device_ids |
constante de type i1 |
(C6) |
(I6) | computation |
fonction | (C7) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C8-C9) |
Contraintes
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
num_replicas
sicross_replica
est utilisé.num_replicas
sicross_replica_and_partition
est utilisé.num_processes
siflattened_ids
est utilisé.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Si
use_global_device_ids = true
, alorschannel_id > 0
. - (C7)
computation
est de type(tensor<E>, tensor<E>) -> (tensor<E>)
, oùis_promotable(element_type(operand), E)
- (C8)
shape(result) = shape(operand)
, sauf: <ph type="x-smartling-placeholder">- </ph>
dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Exemples
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Sémantique
Applique une fonction de réduction body
aux fenêtres de inputs
et init_values
et génère results
.
Le schéma suivant montre comment les éléments de results...
sont calculés à partir de
inputs...
à l'aide d'un exemple concret.
Plus formellement,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(voir reduce) :
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
window_end = window_start + (window_dimensions - 1) * window_dilations + 1
windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
nombre variadique de Tensors à 0 dimensions ou de Tensors quantifiés par Tensor | (C1), (C13) |
(I3) | window_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Constante de Tensor unidimensionnelle de type si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Constante de Tensor unidimensionnelle de type si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Constante de Tensor unidimensionnelle de type si64 |
(C10), (C11), (C15) |
(I7) | padding |
Constante de Tensor bidimensionnelle de type si64 |
(C12), (C15) |
(I8) | body |
fonction | (C13) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1), (C14-C16) |
Contraintes
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
oùis_promotable(element_type(inputs[i]), Ei)
- (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
où: <ph type="x-smartling-placeholder">- </ph>
dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
pour l'ensemble desi
de[0,N)
.
Exemples
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
reste
Sémantique
Effectue le reste par élément des Tensors lhs
et diviseur rhs
et
et génère un Tensor result
.
Plus formellement, le signe du résultat est tiré du dividende, et le
la valeur absolue du résultat est toujours inférieure à la valeur absolue du diviseur.
Le reste est calculé comme suit : lhs - d * rhs
, où d
est calculé comme suit :
- Pour les entiers:
stablehlo.divide(lhs, rhs)
. - Pour les nombres à virgule flottante:
division(lhs, rhs)
conformément à la norme IEEE-754 avec attribut d'arrondiroundTowardZero
- Nombres complexes: à déterminer (#997)
- Pour les types quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Pour les types d'éléments à virgule flottante, cette opération contraste avec la méthode
Opération remainder
de la spécification IEEE-754 où d
est une valeur intégrale
le plus proche de la valeur exacte de lhs/rhs
, avec des liens au nombre pair.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Sémantique
Génère replica_id
du processus actuel.
Sorties
Nom | Type |
---|---|
result |
Tensor à 0 dimensions de type ui32 |
Exemples
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
remodeler
Sémantique
Effectue un remodelage du Tensor operand
en un Tensor result
. Conceptuellement, il
revient à conserver la même représentation canonique, mais peut être amenée
la forme (p.ex., de tensor<2x3xf32>
à tensor<3x2xf32>
ou tensor<6xf32>
.
Plus formellement, result[result_index] = operand[operand_index]
où
result_index
et operand_index
ont la même position dans le texte lexicographique
l'ordre de index_space(result)
et index_space(operand)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1-C3) |
Contraintes
- (C1) La valeur
element_type(result)
est donnée par: <ph type="x-smartling-placeholder">- </ph>
element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, à l'exception dequantization_dimension(operand)
et Sinon,quantization_dimension(result)
peut être différent.
- (C2)
size(operand) = size(result)
. - (C3) Si
is_per_axis_quantized(operand)
: <ph type="x-smartling-placeholder">- </ph>
reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Exemples
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
inverser
Sémantique
Inverse l'ordre des éléments dans operand
le long de la dimensions
spécifiée.
et génère un Tensor result
. Plus formellement,
result[result_index] = operand[operand_index]
où:
operand_index[d] = dim(result, d) - result_index[d] - 1
sid
dansdimensions
.operand_index[d] = result_index[d]
dans les autres cas.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1), (C3) |
(I2) | dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1), (C3) |
Contraintes
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Exemples
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
<ph type="x-smartling-placeholder"></ph>
Sémantique
Génère des nombres aléatoires à l'aide de l'algorithme rng_distribution
et produit un
Tensor result
d'une forme donnée shape
.
Si la valeur est rng_distribution = UNIFORM
, les nombres aléatoires sont générés.
suivant la distribution uniforme sur l'intervalle [a, b)
. Si la valeur est a >= b
,
le comportement n'est pas défini.
Si la valeur est rng_distribution = NORMAL
, les nombres aléatoires sont générés.
suivant la distribution normale avec une moyenne = a
et l'écart type = b
.
Si la valeur est b < 0
, le comportement n'est pas défini.
La manière exacte dont les nombres aléatoires sont générés est définie par l'implémentation. Pour Par exemple, ils peuvent être déterministes ou non, et utiliser ou non état caché.
Lors des conversations avec de nombreux intervenants, cette opération s'est révélée aussi efficace Nous prévoyons donc de les supprimer à l'avenir. (#597)
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor à zéro dimension de type entier, booléen ou à virgule flottante | (C1), (C2) |
(I2) | b |
Tensor à zéro dimension de type entier, booléen ou à virgule flottante | (C1), (C2) |
(I3) | shape |
Constante de Tensor unidimensionnelle de type si64 |
(C3) |
(I4) | rng_distribution |
énumération de UNIFORM et NORMAL |
(C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, booléen ou à virgule flottante | (C1-C3) |
Contraintes
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Si
rng_distribution = NORMAL
, alorsis_float(a)
. - (C3)
shape(result) = shape
.
Exemples
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Sémantique
Renvoie une output
remplie de bits aléatoires uniformes et un état de sortie mis à jour.
output_state
à l'aide de l'algorithme de génération de nombres pseudo-aléatoires rng_algorithm
avec un état initial initial_state
. La sortie est garantie
fonction déterministe de initial_state
, mais sa valeur n'est pas garantie
déterministe entre les implémentations.
rng_algorithm
est l'un des éléments suivants :
DEFAULT
: algorithme défini par l'implémentation.THREE_FRY
: variante définie par l'implémentation de l'algorithme Threefry*.PHILOX
: variante définie par l'implémentation de l'algorithme Philox*.
* Voir Salmon et al. SC 2011. Nombres aléatoires parallèles: aussi simples que 1, 2, 3.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | rng_algorithm |
énumération de DEFAULT , THREE_FRY et PHILOX |
(C2) |
(I2) | initial_state |
Tensor unidimensionnel de type ui64 |
(C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
output_state |
Tensor unidimensionnel de type ui64 |
(C1) |
output |
Tensor de type entier ou à virgule flottante |
Contraintes
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
- définie par l'implémentation si
rng_algorithm = DEFAULT
. 2
sirng_algorithm = THREE_FRY
.2
ou3
sirng_algorithm = PHILOX
.
- définie par l'implémentation si
Exemples
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Sémantique
Effectue un arrondi au niveau des éléments vers l'entier le plus proche, ce qui permet de séparer les relations
à partir de zéro sur le Tensor operand
et produit un Tensor result
. Implémentations
l'opération roundToIntegralTiesToAway
de la spécification IEEE-754. Pour
quantifiés, effectue
dequantize_op_quantize(round_nearest_afz, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Sémantique
Effectue un arrondi au niveau des éléments vers le nombre entier le plus proche, ce qui permet de casser les liens
vers l'entier pair sur le Tensor operand
et produit une result
Tensor. Elle met en œuvre l'opération roundToIntegralTiesToEven
de la norme IEEE-754.
spécifique. Pour les types quantifiés, exécute
dequantize_op_quantize(round_nearest_even, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Sémantique
Effectue une opération de racine carrée réciproque par élément sur le Tensor operand
et
et génère un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
rSqrt
d'IEEE-754. - Pour les nombres complexes: racine carrée réciproque complexe.
- Pour les types quantifiés:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
disperser
Sémantique
Génère des Tensors results
égaux aux Tensors inputs
, sauf que
plusieurs tranches spécifiées par scatter_indices
sont mises à jour avec les valeurs
updates
avec update_computation
.
Le schéma suivant montre comment les éléments de updates...
sont mappés sur les éléments de
results...
à l'aide d'un exemple concret. Le diagramme choisit quelques exemples
updates...
indexe et explique en détail les indices results...
qu'il
correspondent.
Plus formellement, pour tous les update_index
de index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
scatter_indices[si0, ..., :, ..., siN]
oùsi
sont des individus dansupdate_scatter_index
et:
est insérée au niveau Indexindex_vector_dim
, siindex_vector_dim
<rank(scatter_indices)
[scatter_indices[update_scatter_index]]
dans les autres cas.
- Pour
d_input
dansaxes(inputs[0])
, <ph type="x-smartling-placeholder">- </ph>
full_start_index[d_input] = start_index[d_start]
sid_input = scatter_dims_to_operand_dims[d_start]
full_start_index[d_input] = 0
dans les autres cas.
- Pour
d_input
dansaxes(inputs[0])
, <ph type="x-smartling-placeholder">- </ph>
full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sid_input = input_batching_dims[i_batching]
etd_start = scatter_indices_batching_dims[i_batching]
full_batching_index[d_input] = 0
dans les autres cas.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
oùwi
sont des individus dansupdate_window_index
, et0
est inséré au niveau des index deinserted_window_dims
etinput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Par conséquent, results = exec(schedule, inputs)
, où:
schedule
est une permutation définie par l'implémentation deindex_space(updates[0])
exec([update_index, ...], results) = exec([...], updated_results)
où: <ph type="x-smartling-placeholder">- </ph>
- Si
result_index
est dans les limites deshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
est une copie deresults
avecresults...[result_index]
définie surupdated_values...
.- Sinon, procédez comme suit :
updated_results = results
.
- Si
exec([], results) = results
.
Si indices_are_sorted
est défini sur true
, l'implémentation peut supposer que
Les éléments scatter_indices
sont triés par rapport à scatter_dims_to_operand_dims
,
sinon le comportement n'est pas défini. Plus formellement, pour tous les i1 < i2
de
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Si unique_indices
est défini sur true
, l'implémentation peut supposer que tous les
Les index result_index
dispersés sont uniques. Si unique_indices
correspond à
true
, mais les index dispersés ne sont pas uniques, alors le comportement est
non défini.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
Tensor de type entier | (C4), (C15), (C19), (C22) |
(I3) | updates |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C3-C6), (C8) |
(I4) | update_window_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C14-C18). |
(I8) | scatter_dims_to_operand_dims |
Constante de Tensor unidimensionnelle de type si64 |
(C19-C21) |
(I9) | index_vector_dim |
constante de type si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
constante de type i1 |
|
(I11). | unique_indices |
constante de type i1 |
|
(I12). | update_computation |
fonction | (C23) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C24-C25) |
Contraintes
- (C1)
same(shape(inputs...))
. - (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
<ph type="x-smartling-placeholder">
- </ph>
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
où: <ph type="x-smartling-placeholder">- </ph>
update_scatter_dim_sizes = shape(scatter_indices)
sauf que la dimension descatter_indices
correspondant àindex_vector_dim
n'est pas inclus.update_window_dim_sizes <= shape(inputs[0])
sauf que les dimensions deinputs[0]
correspondant àinserted_window_dims
etinput_batching_dims
ne sont pas inclus.combine
placeupdate_scatter_dim_sizes
sur les axes correspondant àupdate_scatter_dims
etupdate_window_dim_sizes
au niveau des axes correspondant àupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
est de type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, oùis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
pour l'ensemble desi
de[0,N)
.
Exemples
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
select
Sémantique
Elle génère un Tensor result
où chaque élément est sélectionné à partir de on_true
ou
Tensor on_false
basé sur la valeur de l'élément correspondant de pred
.
Plus formellement, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, où pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Pour les types quantifiés, exécute
dequantize_select_quantize(pred, on_true, on_false, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | pred |
Tensor de type i1 |
(C1) |
(I2) | on_true |
Tensor ou Tensor quantifié par Tensor | (C1-C2) |
(I3) | on_false |
Tensor ou Tensor quantifié par Tensor | (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C2) |
Contraintes
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Exemples
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Sémantique
Distribue les valeurs du Tensor source
à l'aide de scatter
en fonction du
résultat de la fonction reduce_window
du Tensor input
à l'aide de select
et produit
un Tensor result
.
Le schéma suivant montre comment les éléments de result
sont calculés à partir de
operand
et source
à l'aide d'un exemple concret.
Plus formellement:
selected_values = reduce_window_without_init(...)
avec les entrées suivantes:inputs = [operand].
window_dimensions
,window_strides
etpadding
, qui sont utilisés tels quels.base_dilations = windows_dilations = 1
.body
est défini comme suit:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
où
E = element_type(operand)
etreduce_window_without_init
fonctionnent exactement commereduce_window
, sauf que leschedule
de l'instancereduce
(voir reduce) n'inclut pas les valeurs init. Il est actuellement non spécifié, que se passe-t-il si la fenêtre correspondante ne comporte pas de valeurs ? (#731)result[result_index] = reduce([source_values], [init_value], [0], scatter)
où:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
siselected_values[source_index]
comporte l'élémentoperand
à partir deoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C4), (C6), (C8-C11) |
(I2) | source |
Tensor ou Tensor quantifié par Tensor | (C1), (C2) |
(I3) | init_value |
Tensor à 0 dimensions ou Tensor quantifié par Tensor | (C3) |
(I4) | window_dimensions |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C6), (C7) |
(I6) | padding |
Constante de Tensor bidimensionnelle de type si64 |
(C2), (C8) |
(I7) | select |
fonction | (C9) |
(I8) | scatter |
fonction | (C10) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C11-C12) |
Contraintes
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
où: <ph type="x-smartling-placeholder">- </ph>
padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
est de type(tensor<E>, tensor<E>) -> tensor<i1>
, oùE = element_type(operand)
- (C10)
scatter
est de type(tensor<E>, tensor<E>) -> tensor<E>
, oùis_promotable(element_type(operand), E)
- (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Exemples
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
envoyer
Sémantique
Envoie inputs
à un canal channel_id
et génère un jeton result
.
Si is_host_transfer
est défini sur true
, l'opération transfère les données vers
hôte. Sinon, il transfère les données vers un autre appareil. Cela signifie que
définies par l'implémentation. Cet indicateur duplique
les informations fournies dans
channel_type
. Nous prévoyons donc de n'en conserver qu'un seul à l'avenir.
(#666)
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variadique de Tensors ou de Tensors quantifiés | |
(I2) | token |
token |
|
(I3) | channel_id |
constante de type si64 |
|
(I4) | channel_type |
énumération de DEVICE_TO_DEVICE et DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
constante de type i1 |
(C1) |
Sorties
Nom | Type |
---|---|
result |
token |
Contraintes
- (C1)
channel_type
est défini comme suit: <ph type="x-smartling-placeholder">- </ph>
DEVICE_TO_HOST
siis_host_transfer = true
,DEVICE_TO_DEVICE
dans les autres cas.
Exemples
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Sémantique
Effectue un décalage vers la gauche par élément sur le Tensor lhs
en fonction du nombre de rhs
de bits et génère un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier | (C1) |
(I2) | rhs |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Sémantique
Effectue un décalage arithmétique vers la droite au niveau des éléments sur le Tensor lhs
en
rhs
de bits et produit un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier | (C1) |
(I2) | rhs |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Sémantique
Effectue un décalage logique vers la droite au niveau des éléments sur le Tensor lhs
par rhs
de bits et produit un Tensor result
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier | (C1) |
(I2) | rhs |
Tensor de type entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
signe "=".
Sémantique
Renvoie le signe de l'élément operand
au niveau des éléments et produit un Tensor result
.
Plus formellement, pour chaque élément x
, la sémantique peut être exprimée à l'aide de
Syntaxe Python comme suit:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Pour les types quantifiés, exécute
dequantize_op_quantize(sign, operand, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier signé, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Sémantique
Effectue une opération sinus par élément sur le Tensor operand
et génère une result
.
Tensor. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
sin
d'IEEE-754. - Pour les nombres complexes: sinus complexe.
- Pour les types quantifiés:
dequantize_op_quantize(sine, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Sémantique
Extrait une tranche de operand
à l'aide d'index de départ calculés de manière statique.
et génère un Tensor result
. start_indices
contiennent les index de départ de
la tranche de chaque dimension, limit_indices
, contient les index de fin ;
(exclusif) pour le secteur de chaque dimension et strides
contiennent les pas
pour chaque dimension.
Plus formellement, result[result_index] = operand[operand_index]
où
operand_index = start_indices + result_index * strides
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié par Tensor | (C1-C3), (C5) |
(I2) | start_indices |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C3), (C5) |
(I4) | strides |
Constante de Tensor unidimensionnelle de type si64 |
(C2), (C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié par Tensor | (C1), (C5) |
Contraintes
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Exemples
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
trier
Sémantique
Trie les tranches unidimensionnelles de inputs
le long de la dimension dimension
.
selon un comparator
et produit results
.
Contrairement aux entrées similaires dans d'autres opérations, dimension
autorise les valeurs négatives,
à l'aide de la sémantique décrite ci-dessous. À l'avenir, cela pourra être refusé
pour des raisons de cohérence
(#1377)
Si is_stable
est "true", le tri est stable, c'est-à-dire l'ordre relatif des
éléments considérés comme égaux par le comparateur est conservé. Pour l'étui
où il n'y a qu'une seule entrée, deux éléments e1
et e2
sont considérés comme
par le comparateur si et seulement si
comparator(e1, e2) = comparator(e2, e1) = false
Consultez la formalisation ci-dessous.
sur la généralisation à plusieurs entrées.
Plus formellement, pour tous les result_index
de index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
oùriN
sont des individus éléments dansresult_index
, et:
est inséré au niveau deadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- où
sort
trie un segment unidimensionnel dans l'ordre non décroissant, quecomparator_together
renvoietrue
si l'argument de gauche est inférieur au second argument de droite. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | inputs |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C1-C5) |
(I2) | dimension |
constante de type si64 |
(C4) |
(I3) | is_stable |
constante de type i1 |
|
(I4) | comparator |
fonction | (C5) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors ou Tensors quantifiés par Tensor | (C2), (C3) |
Contraintes
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, oùR = rank(inputs[0])
. - (C5)
comparator
a un type(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, oùEi = element_type(inputs[i])
.
Exemples
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
carré
Sémantique
Effectue une opération de racine carrée par élément sur le Tensor operand
et produit une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
squareRoot
d'IEEE-754. - Pour les nombres complexes: racine carrée complexe.
- Pour les types quantifiés:
dequantize_op_quantize(sqrt, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Sémantique
Effectue la soustraction par élément de deux Tensors lhs
et rhs
, et produit une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les entiers: soustraction d'entiers.
- Pour les nombres à virgule flottante:
subtraction
d'IEEE-754. - Pour les nombres complexes: soustraction complexe.
- Pour les types quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
(I2) | rhs |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type entier, à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Exemples
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Sémantique
Effectue une opération de tangente par élément sur le Tensor operand
et produit une
Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
tan
d'IEEE-754. - Pour les nombres complexes: tangente complexe.
- Pour les types quantifiés:
dequantize_op_quantize(tan, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
Tanh
Sémantique
Effectue une opération de tangente hyperbolique par élément sur le Tensor operand
et
et génère un Tensor result
. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les nombres à virgule flottante:
tanh
d'IEEE-754. - Pour les nombres complexes: tangente hyperbolique complexe.
- Pour les types quantifiés:
<ph type="x-smartling-placeholder">
- </ph>
dequantize_op_quantize(tanh, operand, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_type(operand) = baseline_type(result)
.
Exemples
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transposer
Sémantique
Permute les dimensions du Tensor operand
à l'aide de permutation
et génère une
Tensor result
. Plus formellement, result[result_index] = operand[operand_index]
où result_index[d] = operand_index[permutation[d]]
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor ou Tensor quantifié | (C1-C4) |
(I2) | permutation |
Constante de Tensor unidimensionnelle de type si64 |
(C2-C4) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor ou Tensor quantifié | (C1), (C3-C4) |
Contraintes
- (C1) La valeur
element_type(result)
est donnée par: <ph type="x-smartling-placeholder">- </ph>
element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, à l'exception dequantization_dimension(operand)
et Sinon,quantization_dimension(result)
peut être différent.
- (C2)
permutation
est une permutation derange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Si
is_per_axis_quantized(result)
, alorsquantization_dimension(operand) = permutation(quantization_dimension(result))
Exemples
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Sémantique
Résoudre des lots de systèmes d'équations linéaires avec des triangles inférieurs ou supérieurs matrices de coefficients.
Plus formellement, étant donné a
et b
, result[i0, ..., iR-3, :, :]
est la solution.
vers op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
quand left_side
est
true
ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
lorsque
left_side
correspond à false
, ce qui résout la variable x
où op(a)
est déterminé
par transpose_a
, qui peut prendre l'une des valeurs suivantes:
NO_TRANSPOSE
: effectue l'opération à l'aide dea
tel quel.TRANSPOSE
: effectue une opération sur la transposition dea
.ADJOINT
: effectue l'opération sur la transposition conjuguée dea
.
Les données d'entrée sont lues uniquement à partir du triangle inférieur de a
, si lower
est true
ou
triangle supérieur de a
, sinon. Les données de sortie sont renvoyées dans le même triangle.
les valeurs de l'autre triangle sont définies par l'implémentation.
Si unit_diagonal
est "true", l'implémentation peut supposer que la diagonale
éléments de a
sont égaux à 1. Sinon, le comportement n'est pas défini.
Pour les types quantifiés, exécute
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | a |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1-C3) |
(I2) | b |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1-C4) |
(I3) | left_side |
constante de type i1 |
(C3) |
(I4) | lower |
constante de type i1 |
|
(I5) | unit_diagonal |
constante de type i1 |
|
(I6) | transpose_a |
énumération de NO_TRANSPOSE , TRANSPOSE et ADJOINT |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante ou complexe, ou Tensor quantifié par Tensor | (C1) |
Contraintes
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) La relation entre
shape(a)
etshape(b)
est définie comme suit: <ph type="x-smartling-placeholder">- </ph>
shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Exemples
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
<ph type="x-smartling-placeholder"></ph>
Sémantique
Génère un tuple result
à partir des valeurs val
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | val |
nombre variadique de valeurs | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
tuple | (C1) |
Contraintes
- (C1)
result
est de typetuple<E0, ..., EN-1>
, oùEi = type(val[i])
.
Exemples
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Sémantique
Effectue la conversion élément par élément du Tensor quantifié operand
en
Tensor à virgule flottante result
en fonction des paramètres de quantification définis
par le type operand
.
Plus formellement, result = dequantize(operand)
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor quantifié | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type à virgule flottante | (C1), (C2) |
Contraintes
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Exemples
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Sémantique
Effectue une conversion élément par élément du Tensor à virgule flottante ou quantifié
operand
sur un Tensor quantifié result
en fonction de la quantification
les paramètres définis par le type result
.
Plus formellement,
- Si la valeur est
is_float(operand)
: <ph type="x-smartling-placeholder">- </ph>
result = quantize(operand, type(result))
.
- Si la valeur est
is_quantized(operand)
: <ph type="x-smartling-placeholder">- </ph>
float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
Tensor de type à virgule flottante ou quantifié | (C1), (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor quantifié | (C1), (C2) |
Contraintes
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Exemples
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
tandis que
Sémantique
Génère le résultat de l'exécution de la fonction body
zéro fois ou plus, tandis que la
La fonction cond
génère true
. Plus formellement, la sémantique peut être exprimée
à l'aide de la syntaxe Python suivante:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Le comportement d'une boucle infinie est à déterminer (#383)
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | operand |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C1-C3) |
(I2) | cond |
fonction | (C1) |
(I3) | body |
fonction | (C2) |
Sorties
Nom | Type | Contraintes |
---|---|---|
results |
nombre variadique de Tensors, Tensors quantifiés ou jetons | (C3) |
Contraintes
- (C1)
cond
est de type(T0, ..., TN-1) -> tensor<i1>
, oùTi = type(operand[i])
- (C2)
body
est de type(T0, ..., TN-1) -> (T0, ..., TN-1)
, oùTi = type(operand[i])
- (C3)
type(results...) = type(operand...)
.
Exemples
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
XOR
Sémantique
Effectue une opération XOR par élément de deux Tensors lhs
et rhs
, et produit une result
Tensor. En fonction du type d'élément, effectue les opérations suivantes:
- Pour les valeurs booléennes: opérateur logique "XOR".
- Pour les entiers: opération XOR (OU exclusif) bit à bit.
Entrées
Libellé | Nom | Type | Contraintes |
---|---|---|---|
(I1) | lhs |
Tensor de type booléen ou entier | (C1) |
(I2) | rhs |
Tensor de type booléen ou entier | (C1) |
Sorties
Nom | Type | Contraintes |
---|---|---|
result |
Tensor de type booléen ou entier | (C1) |
Contraintes
- (C1)
type(lhs) = type(rhs) = type(result)
.
Exemples
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Dialect Interop
À l'heure actuelle, les programmes StableHLO contenant parfois des opérations qui ne sont pas définis par StableHLO.
Module, fonction, appel et retour
StableHLO utilise les opérations MLIR en amont pour ModuleOp, FuncOp, CallOp et RenvoiOp. Cela a permis d'améliorer l'interopérabilité avec les machines MLIR existantes, car de nombreux des passes utiles sont écrites ciblant FuncOp et ModuleOp, et de nombreuses compilations les pipelines s'attendent à ce que ces opérations soient présentes. Les garanties de compatibilité totale sont appliquée à ces opérations. En cas de changement dans ces opérations incompatible (par exemple, suppression), des équivalents StableHLO seront ajoutés pour préserver et la compatibilité avec d'autres appareils.
CHLO
L'ensemble d'opérations CHLO contient des opérations de niveau supérieur qui se décomposent en StableHLO. Il n'existe actuellement aucune garantie de compatibilité pour CHLO. Compatibilité les garanties chlo-legalize-to-stablehlo doit être utilisée avant la sérialisation.
Opérations de tracé
Il est courant au sein de la communauté d'avoir recours à certaines opérations
Dialectes MLIR utilisés dans les programmes StableHLO dynamiques pour effectuer des calculs de forme.
Le plus souvent, il s'agit du dialecte shape
.
Opérations telles que shape_of
ou num_elements
, dialecte tensor
comme dim
ou from_elements
, et le type index
intégré.
Le document Dynamism RFC > O2
les indique comme étant hors du champ d'application, mais la prise en charge des types index
est
incluses à des fins d'interopérabilité. Il n'existe aucune garantie de compatibilité
opérations ou types. La commande shape-legalize-to-stablehlo
peut servir à convertir ces opérations en opérations StableHLO entièrement compatibles.
Opérations obsolètes
Plusieurs opérations StableHLO ont été héritées MHLO qui sont obsolètes et seront bientôt supprimés de StableHLO. Les détails complets de ces sont indiquées dans le fichier StableHLO v1.0 Cleanup #2283. Le problème de l'outil de suivi pour ces abandons est le n° 2340.
Ces opérations appartiennent à plusieurs catégories:
- "Pas dans HLO" des opérations StableHLO, car celles-ci faisaient initialement partie
l'opset StableHLO, mais on a par la suite considéré qu'il ne s'adaptait pas correctement:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
torch_index_select
,unary_einsum
(n° 3) - Opérations inutilisées : ces opérations ont peut-être été utiles à un moment donné, mais les opérations
étaient soit sous-développées, soit les pipelines utilisant ces opérations ont été
refactorisées pour ne plus en avoir besoin. Cela inclut
map
,tuple
(#598), comparaisonsget_tuple_element
,rng
,complex
#560, et la convolutionwindow_reversal
(#1181).
Certaines de ces opérations peuvent être facilement
supprimées car elles peuvent être exprimées à l’aide de
opérations existantes (broadcast
, create_token
, cross-replica-sum
, dot
,
unary_einsum
) et seront supprimées après le délai de compatibilité existant
(6 mois). D'autres sont toujours en cours de suppression (einsum
,
get_tuple_element
, map
, torch_index_select
rng
, tuple
et complex
comparaisons, window_reversal
). En attente des commentaires de la communauté,
ces opérations seront soit supprimées, soit ajoutées à la spécification avec une prise en charge totale. Jusqu'au
si ces contrats à terme sont connus, ils ne sont garantis que 6 mois de compatibilité.
Exécution
Exécution séquentielle
Un programme StableHLO est exécuté en fournissant des valeurs d'entrée à la fonction main
et le calcul des valeurs de sortie. Les valeurs de sortie d'une fonction sont calculées
Exécuter le graphe des opérations en mode root dans l'opération return
correspondante
L'ordre d'exécution est défini par l'implémentation, tant qu'il est aligné avec
Dataflow, c'est-à-dire si les opérations sont exécutées avant leur utilisation. Dans StableHLO,
les opérations à effets secondaires consomment un seul jeton et en produisent un (plusieurs jetons
être multiplexé en un seul jeton via after_all
), de sorte que l'ordre d'exécution du côté
est également aligné sur Dataflow. Par exemple, dans le programme ci-dessous,
il existe deux ordres d'exécution possibles: %0
→ %1
→ %2
→ return
et
%1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Plus formellement, un processus StableHLO est une combinaison des éléments suivants:
1) un programme StableHLO, 2) des états des opérations (pas encore exécuté),
(déjà exécuté)) et 3) les valeurs intermédiaires sur lesquelles le processus travaille.
Le processus commence par les valeurs d'entrée de la fonction main
,
le graphique des opérations mettant à jour les états et les valeurs intermédiaires,
se termine par des valeurs de sortie. Plus de formalisation à déterminer
(#484)
Exécution parallèle
Les programmes StableHLO peuvent être exécutés en parallèle et organisés dans une grille de processus 2D.
de num_replicas
par num_partitions
, qui sont tous deux de type ui32
.
Dans la grille de processus StableHLO, num_replicas * num_partitions
de StableHLO
processus s'exécutent
en même temps. Chaque processus a une
process_id = (replica_id, partition_id)
, où
replica_id
dans replica_ids = range(num_replicas)
et
partition_id
dans partition_ids = range(num_partitions)
, qui ont toutes les deux
saisissez ui32
.
La taille de la grille de processus est connue de manière statique pour chaque programme (dans le
à l'avenir, nous prévoyons d'en faire une partie explicite des programmes StableHLO.
#650) et la position
au sein de la grille de processus est connue
de manière statique pour chaque processus. Chaque processus comporte
accès à sa position dans la grille de processus via replica_id
et
partition_id
opérations
Au sein de la grille des processus, les programmes peuvent tous être identiques (dans la section Programme, Données multiples" style), peuvent tous être différents (dans la section "Programmes, Données multiples" ou un autre style d'annonce. À l'avenir, nous prévoyons pour permettre la définition de programmes StableHLO parallèles dans d'autres idiomes, y compris GSPMD (#619).
Dans la grille de processus, les processus sont pour la plupart indépendants les uns des autres : ont des états d'opération distincts et des valeurs d'entrée/intermédiaire/sortie distinctes. La plupart des opérations sont exécutées séparément entre les processus, à l'exception d'un petit nombre d'opérations collectives décrites ci-dessous.
Étant donné que l'exécution de la plupart des opérations n'utilise que des valeurs de la même
processus, il est généralement sans ambiguïté de désigner ces valeurs par leur nom.
Toutefois, cette définition n'est pas suffisante pour décrire la sémantique des opérations collectives.
qui donne lieu à la notation name@process_id
pour faire référence à la valeur name
au cours d'un processus particulier. De ce point de vue, un name
non qualifié peut être
est considéré comme un raccourci pour name@(replica_id(), partition_id())
).
L'ordre d'exécution entre les processus est défini par l'implémentation, à l'exception de synchronisation introduite par la communication point à point et les opérations collectives comme décrit ci-dessous.
Communication point à point
Les processus StableHLO
peuvent communiquer entre eux via
Canaux StableHLO. Une chaîne est représentée par un identifiant positif de type.
si64
Grâce à diverses opérations, il est possible d'envoyer des valeurs aux canaux
les recevoir des chaînes.
Structuration plus poussée, par exemple d'où viennent ces identifiants de critères, les processus et les programmes en prennent conscience et le type de synchronisation introduites par eux, est à déterminer (#484)
Communication en flux continu
Chaque processus StableHLO a accès à deux interfaces de streaming:
- InFeed, qui peut être lu.
- Flux de sortie dans lequel des opérations d'écriture peuvent être effectuées.
Contrairement aux canaux, qui sont utilisés pour communiquer entre les processus et donc des processus, les flux d'entrée et de sortie définies par l'implémentation.
Structuration plus poussée, par exemple comment la communication en flux influence l'exécution et le type de synchronisation qu'il introduit, est à déterminer. (#484)
Opérations collectives
StableHLO comporte six opérations collectives: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
et
reduce_scatter
Toutes ces opérations divisent les processus dans le processus StableHLO
en groupes de processus StableHLO et exécuter un calcul commun dans
chaque groupe de processus, indépendamment des autres groupes de processus.
Au sein de chaque groupe de processus, des opérations collectives peuvent introduire une synchronisation de sécurité. Structuration plus poussée, par exemple à savoir quand exactement la synchronisation se produit, comment exactement les processus parviennent à cet obstacle, et ce qui se passe si ce n'est pas le cas, à déterminer (#484)
Si le groupe de processus implique une communication entre partitions, c'est-à-dire qu'il existe
processus du groupe de processus dont les ID de partition sont différents, puis l'exécution
de l'opération collective a besoin d'une chaîne, et l'opération collective doit fournir une
channel_id
à inclure de type si64
. La communication entre instances répliquées n'a pas besoin
canaux de distribution.
Les calculs effectués par les opérations collectives sont spécifiques à des opérations individuelles. et sont décrites dans les sections dédiées aux opérations individuelles ci-dessus. Cependant, les stratégies dans laquelle la grille de processus est divisée en groupes de processus. Ces groupes sont partagés entre ces opérations. et sont décrits dans cette section. Plus formellement, StableHLO prend en charge en suivant quatre stratégies.
cross_replica
Seules les communications multi-instances répliquées ont lieu au sein de chaque groupe de processus. Ce
utilise replica_groups
(une liste de listes d'ID d'instances répliquées) et calcule
un produit cartésien de replica_groups
par partition_ids
. replica_groups
doit comporter des éléments uniques et couvrir tous les replica_ids
. Plus formellement, en utilisant
Syntaxe Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour replica_groups = [[0, 1], [2, 3]]
et num_partitions = 2
,
cross_replica
produira
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
cross_partition
Seules les communications entre partitions ont lieu au sein de chaque groupe de processus. Ce
utilise partition_groups
, une liste de listes d'ID de partition, et
calcule un produit cartésien de partition_groups
par replica_ids
.
partition_groups
doit comporter des éléments uniques et couvrir tous les partition_ids
.
Plus formellement, en utilisant la syntaxe Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour partition_groups = [[0, 1]]
et num_replicas = 4
,
cross_partition
produira
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
cross_replica_and_partition
Des communications entre instances répliquées et entre partitions peuvent avoir lieu au sein de chaque
de processus. Cette stratégie utilise replica_groups
, une liste de listes
d'instances répliquées - et calcule les produits cartésiens de chaque replica_group
par
partition_ids
replica_groups
doit comporter des éléments uniques et tous les couvrir
replica_ids
Plus formellement, en utilisant la syntaxe Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour replica_groups = [[0, 1], [2, 3]]
et num_partitions = 2
,
cross_replica_and_partition
produira
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
flattened_ids
Cette stratégie utilise flattened_id_groups
, une liste de listes "aplaties"
les identifiants de processus sous la forme replica_id * num_partitions + partition_id
;
les transforme en identifiants de processus. flattened_id_groups
doit comporter des éléments uniques
et couvrent tous les process_ids
. Plus formellement, en utilisant la syntaxe Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Par exemple, pour flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
,
num_replicas = 4
et num_partitions = 2
, flattened_ids
produira
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
Précision
Pour le moment, StableHLO n'offre aucune garantie de précision numérique, mais cela pourrait changer à l'avenir (#1156)
Sémantique d'exécution d'une opération quantifiée
L'interprétation des opérations StableHLO quantifiées peut varier en fonction du la configuration requise et les capacités matérielles. Par exemple, certains matériels peuvent choisir interpréter les opérations quantifiées à l'aide d'une fonction "dequantize, effectuer des opérations à virgule flottante et enfin l'opération de quantification ». stratégie. D'autres peuvent effectuer l'intégralité avec des calculs arithmétiques entiers. Par conséquent, l'interprétation de des opérations StableHLO quantifiées est exclusivement déterminé par la clé la mise en œuvre. Interprétation de la quantification hybride (#1575) doit être basée sur sa sémantique, telle qu'indiquée dans la spécification (via 1792).
Erreurs
Les programmes StableHLO sont validés par un vaste ensemble de contraintes pour des opérations individuelles, ce qui exclut de nombreuses classes d'erreurs avant l'exécution. Toutefois, des conditions d'erreur restent possibles, par exemple via des dépassements d'entiers, accès hors limites, etc. Sauf indication contraire explicite, toutes ces erreurs peut entraîner un comportement défini par l'implémentation, mais cela peut changer au niveau (#1157)
Exceptions à virgule flottante
À titre d'exception à cette règle, les exceptions à virgule flottante dans les programmes StableHLO
ont un comportement bien défini. Les opérations entraînant des exceptions définies par
Norme IEEE-754 (opération non valide, division par zéro, dépassement, dépassement, dépassement ou
exceptions inexactes) génèrent des résultats par défaut (tels que définis dans la norme) et
poursuivre l'exécution sans générer l'indicateur d'état correspondant ; similaire à
Traitement des exceptions raiseNoFlag
par rapport à la norme. Exceptions pour les images non standards
(par exemple, des fonctions arithmétiques complexes et certaines fonctions transcendantes)
définies par l'implémentation.
Incohérences au niveau des formes
StableHLO accepte les Tensors de forme dynamique. Cependant, les formes doivent s'accorder l'environnement d'exécution. Sinon, le comportement n'est pas défini. StableHLO ne fait pas explicitement fournissent une opération qui peut affirmer qu'un Tensor a une forme donnée au moment de l'exécution. La génération du code correct relève de la responsabilité du producteur.
À titre d'exemple, le programme ci-dessous est correct. Cependant, au moment de l'exécution,
les formes exactes de %arg0
et %arg1
doivent être identiques. Sinon, la
le comportement du programme n'est pas défini:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Pour décrire la syntaxe, ce document utilise le type de produit ISO modifié d'EBNF
(ISO/IEC 14977:1996,
Wikipédia),
avec deux modifications: 1) les règles sont définies à l'aide de ::=
au lieu de =
,
2) La concaténation est exprimée à l'aide de la juxtaposition plutôt que de ,
.
Pour décrire la sémantique (c'est-à-dire dans les sections "Types", "Constantes" et "Opérations") nous utilisons des formules basées sur la syntaxe Python, avec prise en charge pour exprimer de manière concise les opérations de tableau, comme décrit ci-dessous. Cela fonctionne bien pour les petits extraits de code, mais dans de rares cas, lorsque de plus grands extraits de code nous utilisons la syntaxe Python vanilla qui est toujours introduite explicitement.
Formules
Examinons le fonctionnement des formules à partir d'un exemple tiré de dot_general
.
spécifique. L'une des contraintes de cette opération se présente comme suit:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
Les noms utilisés dans cette formule proviennent de deux sources: 1) des fonctions globales,
Exemple : dim
, 2) définitions des membres de l'élément du programme correspondant :
Entrées lhs
, lhs_batching_dimensions
, rhs
et rhs_batching_dimensions
défini dans la section "Entrées" de dot_general
.
Comme indiqué ci-dessus, la syntaxe de cette formule est basée sur Python avec certaines extensions axées sur la concision. Pour donner un sens à la formule, transformons en syntaxe Python vanilla.
A) Dans ces formules, nous utilisons =
pour représenter l'égalité. La première étape
pour obtenir la syntaxe Python, remplacez =
par ==
, comme suit:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
B) De plus, ces formules prennent en charge les points de suspension (...
) qui transforment les expressions scalaires
en expressions de Tensor. En résumé, f(xs...)
signifie plus ou moins "pour chaque
x
scalaire dans le Tensor xs
, calculer une valeur f(x)
scalaire, puis renvoyer toutes les valeurs
ces résultats scalaires ensemble sous la forme d'un résultat de Tensor". Dans la syntaxe Python vanilla,
notre exemple de formule
devient:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
Grâce aux points de suspension, il est souvent possible d'éviter de travailler au niveau
des scalaires individuels. Toutefois, dans certains cas délicats, les données
la syntaxe peut être utilisée comme dans la formule start_indices[bi0, ..., :, ..., biN]
de la spécification gather
. Au service de la concision, nous ne faisons pas
fournir un formalisme exact pour traduire
cette syntaxe en Python vanilla, dans
espère qu'elle sera toujours compréhensible de manière intuitive au cas par cas.
Veuillez nous indiquer si certaines formules spécifiques semblent opaques et nous essaierons de
les améliorer.
De plus, vous remarquerez que les formules utilisent des points de suspension pour développer toutes sortes de listes, des Tensors, des listes de Tensors (p.ex., qui peuvent provenir d'un variadique (nombre de Tensors), etc. Il s'agit d'un autre domaine dans lequel nous ne fournissons pas le formalisme (par exemple, les listes ne font même pas partie du système de types StableHLO) et s’appuient plutôt sur une compréhension intuitive.
C) La dernière notation remarquable que nous utilisons est implicite la diffusion d'annonces. Bien que l'opset StableHLO ne prenne pas en charge la diffusion implicite, des formules, également au service de la concision. En résumé, si une valeur scalaire est utilisée dans un contexte où un Tensor est attendu, le scalaire est diffusé à la forme attendue.
Pour poursuivre l'exemple dot_general
, voici une autre contrainte:
0 <= lhs_batching_dimensions < rank(lhs)
Tel que défini dans les dot_general
spécification, lhs_batching_dimensions
est un Tensor, mais 0
et
Les rank(lhs)
sont des scalaires. Après avoir appliqué la diffusion implicite, la formule
devient [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Lorsqu'elle est appliquée à une opération dot_general
spécifique, cette formule
correspond à un Tensor de valeurs booléennes. Lorsque des formules sont utilisées comme contraintes, le
la contrainte indique si la formule renvoie la valeur true
ou un Tensor qui
ne contient que des éléments true
.
Noms
Dans les formules, la portée lexicale inclut: 1) les fonctions globales, 2) les définitions de membres,
3) définitions locales. Vous trouverez ci-dessous la liste des fonctions globales. La liste de définitions d'éléments dépend de l'élément du programme auquel la notation est appliqué à:
- Pour les opérations, les définitions des membres incluent les noms introduits dans "Entrées" et "Sorties" .
- Pour tout le reste, les définitions des membres incluent les parties structurelles du
élément du programme, portant le nom des non-terminaux EBNF correspondants. La plupart des
le nom de ces parties structurelles est obtenu en convertissant
Noms des non-terminaux à utiliser avec snake case (par exemple,
IntegerLiteral
=>integer_literal
), mais il arrive que les noms soient abrégés (par exemple,QuantizationStorageType
=>storage_type
), auquel cas les noms sont introduit explicitement de la même manière que "Entrées" / "Sorties" sections en activité caractéristiques techniques. - De plus, les définitions des membres incluent toujours
self
pour faire référence au l'élément de programme correspondant.
Valeurs
Lorsque les formules sont évaluées, elles fonctionnent avec les types de valeurs suivants:
1) Value
(valeurs réelles, par exemple dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
ils connaissent toujours leurs types),
2) Placeholder
(valeurs futures, par exemple lhs
, rhs
ou result
; leurs valeurs réelles)
les valeurs ne sont pas encore connues, seuls leurs types sont connus),
3) Type
(types définis dans la section "Types"),
4) Function
(fonctions globales telles que définies dans la section "Fonctions").
Selon le contexte, les noms peuvent faire référence à différentes valeurs. Plus
en particulier la "sémantique" pour les opérations (et ses équivalents pour les autres programmes
) définit la logique d'exécution, de sorte que toutes les entrées sont disponibles en tant que Value
.
En revanche, la colonne "Contraintes" des opérations (et leurs équivalents) définit
"compile-time" c'est-à-dire quelque chose qui est
généralement exécuté avant l'exécution,
Ainsi, seules les entrées constantes sont disponibles en tant que Value
, et les autres entrées sont
disponible uniquement en tant que Placeholder
.
Noms | Dans "Sémantique" | Dans "Contraintes" |
---|---|---|
Fonctions globales | Function |
Function |
Entrées constantes | Value |
Value |
Entrées non constantes | Value |
Placeholder |
Sorties | Value |
Placeholder |
Définitions locales | Dépend de la définition | Dépend de la définition |
Prenons un exemple d'opération transpose
:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Pour cette opération, permutation
est une constante. Il est donc disponible en tant que Value
.
tant au niveau de la sémantique que des contraintes. En revanche, operand
et result
sont
disponible en tant que Value
en sémantique, mais uniquement en tant que Placeholder
dans les contraintes.
Fonctions
Construction de types
Aucune fonction ne peut être utilisée pour créer des types. Au lieu de cela, nous
directement
utilisez la syntaxe du type,
car elle est généralement plus concise. Exemple :
(tensor<E>, tensor<E>) -> (tensor<E>)
au lieu de function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Fonctions sur les types
element_type
est défini sur les types de Tensors et les types de Tensors quantifiés. renvoie respectivementTensorElementType
ouQuantizedTensorElementType
. une partie duTensorType
ou duQuantizedTensorType
correspondant.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
vérifie si le typex
peut être promu pour saisiry
. Lorsquex
ety
sont desQuantizedTensorElementType
, la promotion ne s'applique qu'àstorage_type
. Cette version spécifique de la promotion est actuellement utilisé dans le contexte du calcul de réduction (reportez-vous à RFC pour en savoir plus).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
est un raccourci pouris_quantized_tensor_element_type(x)
is_type_name(x: Value | Placeholder | Type) -> Value
Disponible pour tous de données. Par exemple,is_float(x)
renvoietrue
six
est de typeFloatType
. Six
est une valeur ou un espace réservé, cette fonction est un raccourci pouris_type_name(type(x))
max_value(x: Type) -> Value
renvoie la valeur maximale d'unTensorElementType
Six
n'est pas de typeTensorElementType
, la fonction renvoieNone
.min_value(x: Type) -> Value
renvoie la valeur minimale possible d'uneTensorElementType
Six
n'est pas de typeTensorElementType
, la fonction renvoieNone
.member_name(x: Value | Placeholder | Type) -> Any
Disponible pour tous les membres définitionsmember_name
de tous types. Exemple :tensor_element_type(x)
renvoie la partieTensorElementType
d'unTensorType
correspondant. Six
est une valeur ou un espace réservé, cette fonction est un raccourci pourmember_name(type(x))
Six
n'est pas un type disposant d'un membre approprié, ou une valeur ou un espace réservé de ce type, renvoieNone
.is_empty_algorithm(*args: Type)
vérifie si tous les champs de l'algorithme à points sont définis àNone
. Cela est nécessaire, car l'implémentation est définie pour les algorithmes . comportements par défaut. Par conséquent, spécifier une valeur par défaut serait incorrect.
Construction des valeurs
operation_name(*xs: Value | Type) -> Value
Disponible pour toutes les opérations. Par exemple,add(lhs, rhs)
prend deux valeurs de Tensorlhs
etrhs
, et renvoie le résultat de l'évaluation de l'opérationadd
avec ces entrées. Pour certaines opérations, par exemplebroadcast_in_dim
, les types de leurs sorties sont les suivants : "portant", c'est-à-dire nécessaire pour évaluer une opération. Dans ce cas, la fonction utilise ces types comme arguments.
Fonctions sur les valeurs
Tous les opérateurs et fonctions Python sont disponibles. Exemple : les deux abonnement et le segmentage les notations de Python peuvent être indexées sous forme de Tensors quantifiés, et les tuples.
to_destination_type(x: Value, destination_type: Type) -> Value
est défini le et renvoie la valeur convertie dex
en fonction destype(x)
etdestination_type
comme suit:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
La fusion de convert
, uniform_quantize
et
Opérations uniform_dequantize
(#1576).
Après la fusion, nous n'avons plus besoin de la fonction ci-dessus et pouvons utiliser le nom de l'opération
pour convert
à la place.
is_nan(x: Value) -> Value
est défini sur les Tensors et renvoietrue
si tous les éléments dex
sontNaN
oufalse
dans les autres cas. Six
n'est pas un Tensor, renvoieNone
.is_sorted(x: Value) -> Value
est défini sur les Tensors et renvoietrue
si les éléments dex
sont triés par ordre croissant l'ordre lexicographique de leurs index, oufalse
dans les autres cas. Six
n'est pas un renvoieNone
.is_unique(x: Value) -> Value
est défini sur les Tensors et renvoietrue
six
. ne contient pas d'éléments en double, nifalse
dans le cas contraire. Six
n'est pas un Tensor, renvoieNone
.member_name(x: Value) -> Any
est défini pour toutes les définitions de membre.member_name
de l'ensemble des valeurs. Par exemple,real_part(x)
renvoieRealPart
partie d'unComplexConstant
correspondant. Six
n'est pas une valeur ayant un membre approprié, renvoieNone
.same(x: Value) -> Value
est défini sur les Tensors et renvoietrue
si les éléments dex
sont tous égaux les uns par rapport aux autres, oufalse
dans le cas contraire. Si le Tensor ne comporte aucun élément, c'est-à-dire "tous égaux les uns aux autres". renvoietrue
. Six
n'est pas un Tensor, la fonction renvoieNone
.split(x: Value, num_results: Value, axis: Value) -> Value
est défini le et renvoienum_results
tranches dex
le long de l'axeaxis
. Six
n'est pas un Tensor oudim(x, axis) % num_results != 0
, renvoieNone
.is_defined_in_parent_scope(x: Value) -> Value
est défini sur des chaînes et renvoietrue
six
est le nom d'une fonction définie dans le même champ d'application. en tant que fonction parente de l'opération concernée.is_namespaced_op_name(x: Value) -> Value
est défini sur des chaînes et renvoietrue
six
est un nom d'opération valide, c'est-à-dire qu'il respecte le code expression:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Calculs de formes
axes(x: Value | Placeholder | Type) -> Value
est un raccourci pourrange(rank(x))
dim(x: Value | Placeholder | Type, axis: Value) -> Value
est un raccourci pourshape(x)[axis]
dims(x: Value | Placeholder | Type, axes: List) -> List
est un raccourci pourlist(map(lambda axis: dim(x, axis), axes))
index_space(x: Value | Placeholder | Type) -> Value
est défini sur les Tensors. et renvoie les indexsize(x)
pour la colonneTensorType
correspondante, triée de la façon suivante : ordre lexicographique croissant (par exemple,[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
Six
n'est pas un type de Tensor, un type de Tensor quantifié ou une valeur ou un espace réservé de l'un de ces types, renvoieNone
.rank(x: Value | Placeholder | Type) -> Value
est un raccourci poursize(shape(x))
shape(x: Value | Placeholder | Type) -> Value
est défini dans la colonne "Functions" sur les types" viamember_name
.size(x: Value | Placeholder | Type) -> Value
est un raccourci pourreduce(lambda x, y: x * y, shape(x))
Calculs de quantification
def baseline_element_type(x: Value | Placeholder | Type) -> Type
est un raccourci pourelement_type(baseline_type(x))
.baseline_type
est défini sur les types de Tensors et les types de Tensors quantifiés. les transforme en "référence", c'est-à-dire un type ayant la même forme, mais ayant les valeurs par défaut des paramètres de quantification du type d'élément sont rétablis. C'est est une astuce pratique pour comparer les types de Tensors et quantifiés de manière uniforme, ce qui est assez souvent nécessaire. Pour les types quantifiés, cela permet comparer des types en ignorant les paramètres de quantification, c'est-à-direshape
,storage_type
,expressed_type
,storage_min
,storage_max
etquantization_dimension
(pour le type quantifié par axe) doit tous correspondre, mais Les valeurs entrescales
etzero points
peuvent être différentes.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
est défini sur des types de Tensor quantifiés et les transforme en types de Tensor à virgule flottante. Cela se produit via la conversion d'éléments quantifiés qui représentent des valeurs entières du type de stockage valeurs à virgule flottante du type exprimé en utilisant le point zéro et l'échelle associé au type d'élément quantifié.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
est défini sur les types de Tensors à virgule flottante et les transforme en types de Tensor quantifiés. Cela se produit via la conversion des valeurs à virgule flottante du type exprimé en valeurs entières correspondantes du type de stockage. en utilisant le point zéro et l'échelle associées au type d'élément quantifié.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
permet de spécifier des calculs au niveau des éléments sur les Tensors quantifiés. Elle déquantifie, c'est-à-dire transforme les éléments quantifiés en des types exprimés, puis effectue une opération, puis effectue une quantification les résultats dans leurs types de stockage. Pour le moment, cette fonction ne fonctionne pour la quantification par Tensor. La quantification par axe est en cours de développement (#1574)
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
permet de spécifier une quantification pondérée uniquement pour Opération hybride qui accepte lhs en virgule flottante et rh en types quantifiés. Il déquantifie les entrées quantifiées en types exprimés et effectue des calculs en float. Type d'élément du Tensor lhs flottant et type exprimé de rh quantifiée doit être identique.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Calculs en grille
cross_partition(replica_groups: Value) -> Value
Consultez l'instance "cross_replica" ci-dessus.cross_replica(replica_groups: Value) -> Value
Consultez l'instance "cross_replica" ci-dessus.cross_replica_and_partition(replica_groups: Value) -> Value
Consultez le "cross_replica_and_partition" ci-dessus.flattened_ids(replica_groups: Value) -> Value
Voir la colonne "flattened_ids" ci-dessus.
Dynamique
Les valeurs StableHLO peuvent avoir des tailles de dimension dynamiques (par exemple, tensor<?xi64>
Toutefois, les valeurs StableHLO ne peuvent pas comporter de nombres dynamiques de dimensions (non classées
dynamisme, par exemple tensor<*xi64>
). Les opérandes et les résultats sont autorisés à utiliser des
des dimensions, même s'il existe des contraintes sur les tailles. Les contraintes seront
vérifiées statiquement si possible, sinon ils sont différés à l'exécution et
incohérences entraîneront un comportement non défini. Vous trouverez des exemples ci-dessous.
Incohérences de forme pour les opérations unaires par élément
Prenons l'exemple du programme de jouets suivant:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Un tel programme est inhabituel, car il n'est pas courant de connaître la forme du
mais pas la forme de l'entrée. Il s'agit toutefois d'un StableHLO valide
programme. Il n'est pas possible de valider de manière statique l'opération abs
dans cette
programme, car la forme exacte de l'opérande est inconnue. Cependant, les formes
sont certainement compatibles, et vous pouvez le vérifier de manière statique: ?
pourrait s'avérer
serait 2
au moment de l'exécution, et il n'y aurait aucun problème. Toutefois, ?
pourrait
s'avèrent également être un autre nombre entier, auquel cas le comportement n'est pas défini.
Notez que si une taille de dimension est dynamique dans le résultat, il ne peut pas un comportement indéfini. En effet, il n'y a pas de "présence" Il ne peut donc pas y avoir ne correspondent pas.
Incohérences de forme pour les opérations binaires par élément
Prenons l'exemple du programme de jouets suivant:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Dans le cas d'opérations binaires par élément, la forme des entrées et le le résultat doit concorder au moment de l'exécution. Au moment de la compilation, les dimensions statiques doivent être égales. sinon ils ont juste besoin d'être compatibles. Si n'importe quelle dimension est dynamique dans les entrées, il est possible qu'il n'y ait pas de définition au moment de l'exécution, car il est possible que la taille dynamique ne corresponde pas taille dans l'autre opérande (statique ou dynamique). Si toutes les entrées sont statique, le fait que le résultat soit dynamique ou non n'a pas d'importance: statiquement, les dimensions connues sont vérifiées de manière statique, contrairement aux dimensions dynamiques imposent des contraintes.
Incohérences de formes pour les opérations dont la forme de sortie est un opérande
Prenons l'exemple du programme de jouets suivant:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Les valeurs de l'opérande de forme au moment de l'exécution doivent correspondre à la forme du résultat.
sinon le comportement n'est pas défini. Autrement dit, au moment de l'exécution, %arg0
doit avoir une
la valeur de dense<[3, 4]> : tensor<2xi32>
. Si l'opérande de forme est constant, cette
peut être vérifiée de manière statique. Si la forme du résultat est entièrement dynamique,
ne peut pas être une incohérence.