StableHLO ist ein Satz von High-Level-Operationen (HLO) in Modellen für maschinelles Lernen (ML). StableHLO dient als Portabilitätsschicht zwischen verschiedenen ML-Frameworks und ML-Compilern: ML-Frameworks, die StableHLO-Programme generieren, sind mit ML-Compilern kompatibel, die StableHLO-Programme verwenden.
Unser Ziel ist es, die ML-Entwicklung zu vereinfachen und zu beschleunigen, indem wir die Interoperabilität zwischen verschiedenen ML-Frameworks (z. B. TensorFlow, JAX und PyTorch) und ML-Compilern (z. B. XLA und IREE) verbessern. Zu diesem Zweck enthält dieses Dokument eine Spezifikation für die Programmiersprache StableHLO.
Diese Spezifikation enthält drei Hauptabschnitte. Im Abschnitt Programme wird zuerst die Struktur von StableHLO-Programmen beschrieben, die aus StableHLO-Funktionen bestehen, die wiederum aus StableHLO-Vorgängen bestehen. Innerhalb dieser Struktur wird im Abschnitt Ops die Semantik der einzelnen Vorgänge angegeben. Der Abschnitt Ausführung enthält die Semantik für alle diese Vorgänge, die gemeinsam in einem Programm ausgeführt werden. Im Abschnitt Notation wird die in der gesamten Spezifikation verwendete Notation erläutert.
Öffnen Sie das Repository für den gewünschten getaggten Release, um die Spezifikation einer früheren StableHLO-Version anzusehen. Beispiel: StableHLO-Version 0.19.0. Informationen zu den Änderungen bei jeder Minor-Version von StableHLO finden Sie im Versionsprotokoll in VhloDialect.td.
Programme
Program ::= {Func}
StableHLO-Programme bestehen aus einer beliebigen Anzahl von StableHLO-Funktionen.
Unten sehen Sie ein Beispielprogramm mit einer Funktion @main
mit drei Eingaben (%image
, %weights
und %bias
) und einer Ausgabe. Der Hauptteil der Funktion
hat 6 Operationen.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funktionen
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
StableHLO-Funktionen (auch benannte Funktionen genannt) haben eine Kennung, Eingaben/Ausgaben und einen Body. Künftig planen wir, zusätzliche Metadaten für Funktionen einzuführen, um eine bessere Kompatibilität mit HLO zu erreichen (#425, #626, #740, #744).
IDs
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
StableHLO-Kennungen ähneln den Kennungen in vielen Programmiersprachen, mit zwei Besonderheiten: 1) Alle Kennungen haben Sigils, die verschiedene Arten von Kennzeichnungen unterscheiden, 2) Wertkennungen können vollständig numerisch sein, um die Erstellung von StableHLO-Programmen zu vereinfachen.
Typen
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
StableHLO-Typen werden in Werttypen kategorisiert, die auch First-Class-Typen genannt werden. Sie stellen stabileHLO-Werte dar und Typen ohne Wert, die andere Programmelemente beschreiben. StableHLO-Typen ähneln den Typen vieler Programmiersprachen. Die Hauptbesonderheit ist die domänenspezifische Natur von StableHLO, die zu einigen ungewöhnlichen Ergebnissen führt (z. B. sind Skalartypen keine Werttypen).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Tensor-Typen stellen Tensoren dar, also mehrdimensionale Arrays. Sie haben eine Form und einen Elementtyp, wobei eine Form nicht negative oder unbekannte Dimensionsgrößen in aufsteigender Reihenfolge der entsprechenden Abmessungen (auch Achsen genannt) darstellt, die von 0
bis R-1
nummeriert sind. Die Anzahl der Dimensionen R
wird als Rang bezeichnet. tensor<2x3xf32>
ist beispielsweise ein Tensortyp mit der Form 2x3
und dem Elementtyp f32
. Es hat zwei Dimensionen (oder zwei Achsen): die Nullte Dimension und die Erste Dimension, deren Größe 2 und 3 ist. Sein Rang ist 2.
Formen können teilweise oder vollständig unbekannt (dynamisch) sein, z. B. tensor<?x2xf64>
teilweise unbekannt und tensor<?x?xf64>
vollständig unbekannt. Die Größe dynamischer Dimensionen wird durch ein ?
dargestellt. Die Rangfolge von Formen kann nicht aufgehoben werden.
In Zukunft möchten wir Tensortypen über Dimensionsgrößen und Elementtypen hinaus erweitern, z. B. um Layouts (#629) und Sparsity (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Name | Typ | Einschränkungen |
---|---|---|
storage_type |
Ganzzahltyp | (C1-C3), (C8) |
storage_min |
Ganzzahlkonstante | (C1), (C3), (C7) |
storage_max |
Ganzzahlkonstante | (C2), (C3), (C7) |
expressed_type |
Gleitkommatyp | (C4) |
quantization_dimension |
optionale Ganzzahlkonstante | (C10–C12) |
scales |
Variadische Anzahl von Gleitkommakonstanten | (C4-C6), (C9), (C10), (C13) |
zero_points |
variadische Anzahl ganzzahliger Konstanten | (C7–C9) |
Quantisierte Elementtypen stellen Ganzzahlwerte eines Speichertyps im Bereich von storage_min
bis storage_max
(einschließlich) dar, die Gleitkommawerten eines ausgedrückten Typs entsprechen. Für einen gegebenen Ganzzahlwert i
kann der entsprechende Gleitkommawert f
als f = (i - zero_point) * scale
berechnet werden, wobei scale
und zero_point
als Quantisierungsparameter bezeichnet werden. storage_min
und storage_max
sind in der Grammatik optional, haben aber die Standardwerte min_value(storage_type)
und max_value(storage_type)
. Quantisierte Elementtypen unterliegen den folgenden Einschränkungen:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Wenn
is_empty(quantization_dimension)
, dannsize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Derzeit ist QuantizationScale
eine Gleitkommakonstante. Es besteht jedoch ein großes Interesse an ganzzahlbasierten Skalen, die durch Multiplikatoren und Verschiebungen dargestellt werden. Wir planen, dies in naher Zukunft zu untersuchen (#1404).
Es gibt eine laufende Diskussion über die Semantik von QuantizationZeroPoint
, einschließlich des Typs, der Werte und der Frage, ob es in einem quantisierten Tensortyp nur einen oder potenziell mehrere Nullpunkte geben kann. Basierend auf den Ergebnissen dieser Diskussion kann sich die Spezifikation für Nullpunkte in Zukunft ändern (#1405).
Eine weitere laufende Diskussion betrifft die Semantik von QuantizationStorageMin
und QuantizationStorageMax
, um festzustellen, ob für diese Werte und die Werte quantisierter Tensoren Einschränkungen gelten sollten (#1406).
Außerdem möchten wir die Darstellung unbekannter Skalen und Nullpunkte untersuchen, ähnlich wie wir die Darstellung unbekannter Dimensionsgrößen untersuchen (#1407).
Quantisierte Tensortypen stellen Tensoren mit quantisierten Elementen dar. Diese Tensoren sind genau wie normale Tensoren, mit der Ausnahme, dass ihre Elemente quantisierte Elementtypen anstelle von regulären Elementtypen haben.
Bei quantisierten Tensoren kann die Quantisierung pro Tensor erfolgen, d. h., es gibt ein scale
und ein zero_point
für den gesamten Tensor. Sie kann aber auch pro Achse erfolgen, d. h., es gibt mehrere scales
und zero_points
, jeweils ein Paar pro Scheibe einer bestimmten Dimension quantization_dimension
. Formal gibt es in einem Tensor t
mit Achsenquantisierung dim(t, quantization_dimension)
-Segmente des quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
usw. Alle Elemente im i
-ten Segment verwenden scales[i]
und zero_points[i]
als Quantisierungsparameter. Für quantisierte Tensortypen gelten die folgenden Einschränkungen:
- Für die Quantisierung pro Tensor:
- Es gelten keine zusätzlichen Einschränkungen.
- Für die Quantisierung pro Achse:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Tokentypen repräsentieren Tokens, also nicht-transparente Werte, die von einigen Vorgängen erzeugt und verbraucht werden. Mithilfe von Tokens wird die Ausführungsreihenfolge von Vorgängen festgelegt, wie im Abschnitt Ausführung beschrieben.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Tupeltypen stellen Tupel dar, also heterogene Listen. Tupel sind eine ältere Funktion, die nur zur Kompatibilität mit HLO vorhanden ist. In HLO werden Tupel verwendet, um variadische Ein- und Ausgaben darzustellen. In StableHLO werden variadische Eingaben und Ausgaben nativ unterstützt. Die einzige Verwendung von Tupeln in StableHLO besteht darin, die HLO-ABI umfassend darzustellen, wobei sich z. B. T
, tuple<T>
und tuple<tuple<T>>
je nach Implementierung erheblich unterscheiden können. Wir planen in Zukunft Änderungen am HLO ABI, die es uns ermöglichen könnten, Tupeltypen aus StableHLO zu entfernen (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Elementtypen stellen Elemente von Tensortypen dar. Im Gegensatz zu vielen Programmiersprachen sind diese Typen in StableHLO keine Erstklassigen. Das bedeutet, dass Werte dieser Typen in StableHLO-Programmen nicht direkt dargestellt werden können. Daher ist es üblich, Skalarwerte vom Typ T
mit 0-dimensionalen Tensorwerten vom Typ tensor<T>
darzustellen.
- Der Boolesche Typ stellt die booleschen Werte
true
undfalse
dar. - Ganzzahltypen können vorzeichenbehaftet (
si
) oder vorzeichenlos (ui
) sein und haben eine der unterstützten Bitbreiten (2
,4
,8
,16
,32
oder64
). VorzeichenbehaftetesiN
-Typen stellen Ganzzahlwerte von-2^(N-1)
bis2^(N-1)-1
dar, vorzeichenloseuiN
-Typen Ganzzahlwerte von0
bis2^N-1
. - Gleitkommatypen können einen der folgenden Werte haben:
f8E3M4
,f8E4M3
undf8E5M2
sind 8‑Bit-Gleitkommazahlen gemäß IEEE-754-Konventionen.- Die Typen
f8E4M3FN
undf8E5M2
entsprechen denE4M3
- undE5M2
-Codierungen des FP8-Formats, die in FP8-Formate für Deep Learning beschrieben sind. f8E4M3FNUZ
- undf8E5M2FNUZ
-Typen, die der CodierungE4M3
undE5M2
der FP8-Formate entsprechen, die unter Numerische 8-Bit-Formate für neuronale Deep-Learning-Netzwerke beschrieben werden.f8E4M3B11FNUZ
-Typ, der derE4M3
-Codierung der FP8-Formate entspricht, die unter Hybrid-8-Bit-Gleitkommazahl (HFP8) Training und Inferenz für neuronale Deep-Learning-Netzwerke beschrieben werden.bf16
-Typ, der dembfloat16
-Format entspricht, das in BFloat16: Das Geheimnis der hohen Leistung auf Cloud TPUs beschrieben wird.- Die Typen
f16
,f32
undf64
entsprechen den Formatenbinary16
(„halbe Genauigkeit“),binary32
(„einfache Genauigkeit“) undbinary64
(„doppelte Genauigkeit“), die im IEEE 754-Standard beschrieben sind. - Der Typ
tf32
entspricht dem TensorFloat32-Format und wird in StableHLO nur eingeschränkt unterstützt. f4E2M1FN
,f6E2M3FN
,f6E3M2FN
undf8E8M0FNU
MX-Typen (Microscaling) gemäß der OCP-Spezifikation für Microscaling-Formate
- Komplexe Typen stellen komplexe Werte dar, die einen reellen Teil und einen imaginären Teil desselben Elementtyps haben. Unterstützte komplexe Typen sind
complex<f32>
(beide Teile vom Typf32
) undcomplex<f64>
(beide Teile vom Typf64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Funktionstypen können sowohl benannte als auch anonyme Funktionen sein. Sie haben Eingabetypen (die Liste der Typen links von ->
) und Ausgabetypen (die Liste der Typen rechts von ->
). In vielen Programmiersprachen sind Funktionstypen erstklassig, aber nicht in StableHLO.
StringType ::= 'string'
Der Stringtyp steht für Bytefolgen. Anders als in vielen Programmiersprachen ist der Stringtyp in StableHLO nicht erstklassig und wird nur zum Angeben statischer Metadaten für Programmelemente verwendet.
Vorgänge
StableHLO-Vorgänge (auch Ops genannt) stellen eine geschlossene Gruppe von Hochebenen-Vorgängen in Modellen für maschinelles Lernen dar. Wie bereits erwähnt, ist die StableHLO-Syntax stark von MLIR inspiriert. MLIR ist nicht unbedingt die innovativste Alternative, eignet sich aber wohl am besten für das Ziel von StableHLO, mehr Interoperabilität zwischen ML-Frameworks und ML-Compilern zu schaffen.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
StableHLO-Vorgänge (auch Ops genannt) haben einen Namen, Eingaben/Ausgaben und eine Signatur. Der Name besteht aus dem Präfix stablehlo.
und einer Mnemonik, die eine der unterstützten Vorgänge eindeutig identifiziert. Unten finden Sie eine vollständige Liste aller unterstützten Vorgänge.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Vorgänge verbrauchen Eingaben und erzeugen Ausgaben. Eingaben werden in Eingabewerte (während der Ausführung berechnet), Eingabefunktionen (statisch bereitgestellt, da Funktionen in StableHLO keine Erstklassigen Werte sind) und Eingabeattribute (auch statisch bereitgestellt) unterteilt. Die Art der Eingaben und Ausgaben, die von einer Operation verbraucht und erzeugt werden, hängt von ihrer Mnemonic ab. Der Operator add
benötigt beispielsweise zwei Eingabewerte und liefert einen Ausgabewert. Im Vergleich dazu verbraucht die Operation select_and_scatter
3 Eingabewerte, 2 Eingabefunktionen und 3 Eingabeattribute.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Eingabefunktionen (auch anonyme Funktionen genannt) ähneln benannten Funktionen, mit folgenden Ausnahmen: 1) Sie haben keine Kennung (daher der Name „anonym“). 2) Sie deklarieren keine Ausgabetypen. Diese werden aus der return
-Operation innerhalb der Funktion abgeleitet.
Die Syntax für Eingabefunktionen enthält einen derzeit nicht verwendeten Teil (siehe Unused
-Produktion oben), der aus Gründen der Kompatibilität mit MLIR vorhanden ist. In MLIR gibt es das allgemeinere Konzept von „Regionen“, die mehrere „Blöcke“ von Anweisungen haben können, die über Sprunganweisungen miteinander verbunden sind. Diese Blöcke haben IDs, die der Unused
-Produktion entsprechen, damit sie voneinander unterschieden werden können.
StableHLO hat keine Sprungbefehle. Daher wird der entsprechende Teil der MLIR-Syntax nicht verwendet, ist aber weiterhin vorhanden.
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Eingabeattribute haben einen Namen und einen Wert, der eine der unterstützten Konstanten ist. Sie sind die primäre Möglichkeit, statische Metadaten für Programmelemente anzugeben. Bei der Operation concatenate
wird beispielsweise das Attribut dimension
verwendet, um die Dimension anzugeben, entlang derer die Eingabewerte zusammengefügt werden. Ähnlich werden bei der slice
-Operation mehrere Attribute wie start_indices
und limit_indices
verwendet, um die Grenzen anzugeben, die zum Schneiden des Eingabewerts verwendet werden.
Derzeit enthalten StableHLO-Programme in der Praxis manchmal Attribute, die in diesem Dokument nicht beschrieben werden. Künftig werden wir diese Attribute entweder in den StableHLO-Opset aufnehmen oder ihre Verwendung in StableHLO-Programmen verbieten. In der Zwischenzeit finden Sie hier eine Liste dieser Attribute:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619)output_operand_aliases
(#740).- Standortmetadaten (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Die Operatorsignatur besteht aus den Typen aller Eingabewerte (die Liste der Typen auf der linken Seite von ->
) und den Typen aller Ausgabewerte (die Liste der Typen auf der rechten Seite von ->
). Streng genommen sind die Eingabetypen redundant und die Ausgabetypen fast immer auch, da bei den meisten StableHLO-Operatoren die Ausgabetypen aus den Eingaben abgeleitet werden können. Die Operatorsignatur ist jedoch bewusst Teil der StableHLO-Syntax, um mit MLIR kompatibel zu sein.
Unten sehen Sie ein Beispiel für eine Operation mit der mnemonischen Taste select_and_scatter
. Sie nimmt drei Eingabewerte (%operand
, %source
und %init_value
), zwei Eingabefunktionen und drei Eingabeattribute (window_dimensions
, window_strides
und padding
) auf. Beachten Sie, dass die Signatur der Operation nur die Typen der Eingabewerte enthält, aber nicht die Typen der Eingabefunktionen und Attribute, die inline angegeben werden.
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Konstanten
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
StableHLO-Konstanten haben ein Literal und einen Typ, die zusammen einen StableHLO-Wert darstellen. Generell ist der Typ Teil der konstanten Syntax, es sei denn, er ist eindeutig. Eine boolesche Konstante hat z. B. den Typ i1
, während eine Ganzzahlkonstante mehrere mögliche Typen haben kann.
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Boolesche Konstanten repräsentieren die booleschen Werte true
und false
. Boolesche Konstanten haben den Typ i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Ganzzahlkonstanten stellen Ganzzahlwerte über Strings in Dezimal- oder Hexadezimalschreibweise dar. Andere Basen, z. B. binär oder oktal, werden nicht unterstützt. Für Ganzzahlkonstanten gelten die folgenden Einschränkungen:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Gleitkommakonstanten stellen Gleitkommawerte über Strings dar, die die Dezimal- oder wissenschaftliche Schreibweise verwenden. Außerdem können Sie mit der Hexadezimalschreibweise die zugrunde liegenden Bits direkt im Gleitkommaformat des entsprechenden Typs angeben. Für Gleitkommakonstanten gelten die folgenden Einschränkungen:
- (C1) Wenn keine hexadezimal Schreibweise verwendet wird,
is_wellformed(float_literal, float_type)
. - (C2) Bei Verwendung der Hexadezimalschreibweise:
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Komplexe Konstanten stellen komplexe Werte mithilfe von Listen aus einem reellen Teil (an erster Stelle) und einem imaginären Teil (an zweiter Stelle) dar. (1.0, 0.0) : complex<f32>
steht beispielsweise für 1.0 + 0.0i
und (0.0, 1.0) : complex<f32>
für 0.0 + 1.0i
. Die Reihenfolge, in der diese Teile dann im Arbeitsspeicher gespeichert werden, ist implementierungsabhängig. Für komplexe Konstanten gelten die folgenden Einschränkungen:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Tensorkonstanten stellen Tensorwerte mit verschachtelten Listen dar, die über die NumPy-Notation angegeben werden. dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
steht beispielsweise für einen Tensorwert mit der folgenden Zuordnung von Indizes zu Elementen: {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
, {1, 2} => 6
. Die Reihenfolge, in der diese Elemente dann im Arbeitsspeicher gespeichert werden, ist von der Implementierung abhängig. Für Tensorkonstanten gelten die folgenden Einschränkungen:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, wobei:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, wobei gilt:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- Andernfalls
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Quantisierte Tensorkonstanten stellen quantisierte Tensorwerte mit derselben Notation wie Tensorkonstanten, wobei Elemente als Konstanten ihres Speichertyps angegeben sind. Für quantisierte Tensorkonstanten gelten die folgenden Einschränkungen:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Stringliterale bestehen aus Bytes, die mit ASCII-Zeichen und Escape-Sequenzen angegeben werden. Sie sind unabhängig von der Codierung, sodass die Interpretation dieser Bytes von der Implementierung abhängt. Stringliterale haben den Typ string
.
Operativer Betrieb
abs
Semantik
Führt einen elementweisen Absolutwert-Vorgang auf den operand
-Tensor aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Bei vorzeichenbehafteten Ganzzahlen: Ganzzahlmodul.
- Für Gleitkommazahlen:
abs
von IEEE-754. - Bei komplexen Zahlen: Komplexmodulus.
- Für quantisierte Typen:
dequantize_op_quantize(abs, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor eines vorzeichenbehafteten Ganzzahl-, Gleitkomma- oder komplexer Typ oder eines quantisierten Tensors pro Tensor | (C1-C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahl- oder Gleitkommatensor oder quantisierter Tensor pro Tensor | (C1-C2) |
Einschränkungen
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
ist definiert als:complex_element_type(element_type(operand))
, wennis_complex(operand)
.- Andernfalls
baseline_element_type(operand)
.
Beispiele
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
Hinzufügen
Semantik
Führt die elementweise Addition von zwei Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für boolesche Werte: Logisches OR.
- Für Ganzzahlen: Addition von Ganzzahlen.
- Für Gleitkommazahlen:
addition
von IEEE-754. - Bei komplexen Zahlen: komplexe Addition.
- Für quantisierte Typen:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor | (C1–C6) |
(I2) | rhs |
Tensor oder quantisierter Tensor | (C1–C5), (C7) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1–C7) |
Einschränkungen
- Wenn für die Operation nicht quantisierte Tensoren verwendet werden:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Wenn für den Vorgang quantisierte Tensoren verwendet werden:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Wenn
is_per_axis_quantized(lhs)
, dannquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Wenn
is_per_axis_quantized(rhs)
, dannquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Beispiele
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantik
Sorgt dafür, dass die Vorgänge, die die inputs
erzeugen, vor Vorgängen ausgeführt werden, die von result
abhängig sind. Die Ausführung dieses Vorgangs hat keine Auswirkungen. Er dient nur dazu, Datenabhängigkeiten von result
zu inputs
herzustellen.
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | inputs |
variadische Zahl von token |
Ausgaben
Name | Typ |
---|---|
result |
token |
Beispiele
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantik
Verkettet innerhalb jeder Prozessgruppe im StableHLO-Prozess-Raster die Werte der operands
-Tensoren aus jedem Prozess entlang all_gather_dim
und erzeugt results
-Tensoren.
Dabei wird das StableHLO-Prozess-Raster in process_groups
aufgeteilt, was so definiert ist:
cross_replica(replica_groups)
wennchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
wennchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
wennchannel_id > 0 and use_global_device_ids = true
.
Gehen Sie anschließend in jedem process_group
so vor:
operands...@receiver = [operand@sender for sender in process_group]
für allereceiver
inprocess_group
results...@process = concatenate(operands...@process, all_gather_dim)
für alleprocess
inprocess_group
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operands |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C1), (C6) |
(I2) | all_gather_dim |
Konstante vom Typ si64 |
(C1), (C6) |
(I3) | replica_groups |
2-dimensionale Tensorkonstante vom Typ si64 |
(C2-C4) |
(I4) | channel_id |
Konstante vom Typ si64 |
(C5) |
(I5) | use_global_device_ids |
Konstante vom Typ i1 |
(C5) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C6) |
Einschränkungen
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
wird so definiert:num_replicas
, wenncross_replica
verwendet wird.num_replicas
, wenncross_replica_and_partition
verwendet wird.num_processes
, wennflattened_ids
verwendet wird.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Wenn
use_global_device_ids = true
, dannchannel_id > 0
. - (C6)
type(results...) = type(operands...)
, mit folgenden Ausnahmen:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Beispiele
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantik
Innerhalb jeder Prozessgruppe im StableHLO-Prozess-Raster wird eine Reduzierungsfunktion computation
auf die Werte der operands
-Tensoren aus jedem Prozess angewendet und es werden results
-Tensoren erzeugt.
Dabei wird das StableHLO-Prozess-Raster in process_groups
aufgeteilt, was so definiert ist:
cross_replica(replica_groups)
wennchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
wennchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
wennchannel_id > 0 and use_global_device_ids = true
.
Gehen Sie anschließend in jedem process_group
so vor:
results...@process[result_index] = exec(schedule)
für einen Binärbaumschedule
wobei:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
ist ein implementierungsdefinierter Binärbaum, dessen sequenzielle Durchlaufreihenfolgeto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
ist.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operands |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C5), (C6) |
(I2) | replica_groups |
Variadische Anzahl von eindimensionalen Tensorkonstanten vom Typ si64 |
(C1-C3) |
(I3) | channel_id |
Konstante vom Typ si64 |
(C4) |
(I4) | use_global_device_ids |
Konstante vom Typ i1 |
(C4) |
(I5) | computation |
Funktion | (C5) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C6-C7) |
Einschränkungen
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
ist definiert als:num_replicas
, wenncross_replica
verwendet wird.num_replicas
, wenncross_replica_and_partition
verwendet wird.num_processes
, wennflattened_ids
verwendet wird.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Wenn
use_global_device_ids = true
, dannchannel_id > 0
. - (C5)
computation
hat den Typ(tensor<E>, tensor<E>) -> (tensor<E>)
, wobeiis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Beispiele
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantik
Teilt innerhalb jeder Prozessgruppe im StableHLO-Prozess-Raster die Werte der operands
-Tensoren entlang von split_dimension
in Teile auf, verteilt die aufgeteilten Teile auf die Prozesse, verkettet die verteilten Teile entlang von concat_dimension
und erzeugt results
-Tensoren.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, was so definiert ist:
cross_replica(replica_groups)
, wennchannel_id <= 0
.cross_partition(replica_groups)
wennchannel_id > 0
.
Gehen Sie anschließend in jedem process_group
so vor:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
für allesender
inprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
, wobeireceiver_index = process_group.index(receiver)
.results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operands |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C1–C3), (C9) |
(I2) | split_dimension |
Konstante vom Typ si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
Konstante vom Typ si64 |
(C3), (C9) |
(I4) | split_count |
Konstante vom Typ si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
2-dimensionale Tensorkonstante vom Typ si64 |
(C5-C8) |
(I6) | channel_id |
Konstante vom Typ si64 |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C9) |
Einschränkungen
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
wird so definiert:num_replicas
, wenncross_replica
verwendet wird.num_partitions
, wenncross_partition
verwendet wird.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
, außer wennsplit_dimension != concat_dimension
:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Beispiele
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
und
Semantik
Führt die elementweise UND-Verknüpfung der beiden Tensoren lhs
und rhs
durch und erzeugt einen result
-Tensor. Gehen Sie je nach Elementtyp so vor:
- Für boolesche Werte: Logisches AND.
- Für Ganzzahlen: Bitweises AND.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor des booleschen oder Ganzzahltyps | (C1) |
(I2) | rhs |
Tensor des booleschen oder Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor des booleschen oder Ganzzahltyps | (C1) |
Einschränkungen
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantik
Führt eine elementweise atan2-Operation auf den Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
atan2
von IEEE-754. - Für komplexe Zahlen: komplexer Atan2
- Für quantisierte Typen:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
(I2) | rhs |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantik
Berechnet die Gradienten mehrerer Eingaben von batch_norm_training
, die von grad_output
zurückgepropagiert werden, und erzeugt grad_operand
-, grad_scale
- und grad_offset
-Tensoren. Formeller ausgedrückt kann dieser Vorgang als Zerlegung in vorhandene StableHLO-Vorgänge mit der folgenden Python-Syntax ausgedrückt werden:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Bei quantisierten Typen wird dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C1-C3), (C5) |
(I2) | scale |
Eindimensionaler Tensor vom Typ „Gleitkomma“ oder „pro Tensor quantisiert“ | (C2), (C4), (C5) |
(I3) | mean |
1-dimensionaler Tensor eines Gleitkomma- oder pro Tensor quantisierten Typs | (C2), (C4) |
(I4) | variance |
Eindimensionaler Tensor vom Typ „Gleitkomma“ oder „pro Tensor quantisiert“ | (C2), (C4) |
(I5) | grad_output |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C2), (C3) |
(I6) | epsilon |
Konstante vom Typ f32 |
|
(I7) | feature_index |
Konstante vom Typ si64 |
(C1), (C5) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
grad_operand |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C2), (C3) |
grad_scale |
Eindimensionaler Tensor vom Typ „Gleitkomma“ oder „pro Tensor quantisiert“ | (C2), (C4) |
grad_offset |
Eindimensionaler Tensor vom Typ „Gleitkomma“ oder „pro Tensor quantisiert“ | (C2), (C4) |
Einschränkungen
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
undgrad_offset
haben dieselbebaseline_element_type
. - (C3)
operand
,grad_output
undgrad_operand
haben dieselbe Form. - (C4)
scale
,mean
,variance
,grad_scale
undgrad_offset
haben die gleiche Form. - (C5)
size(scale) = dim(operand, feature_index)
.
Beispiele
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantik
Normalisiert den operand
-Tensor in allen Dimensionen mit Ausnahme der feature_index
-Dimension und erzeugt einen result
-Tensor. Formaler kann dieser Vorgang als Zerlegung vorhandener StableHLO-Vorgänge unter Verwendung der Python-Syntax so ausgedrückt werden:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Bei quantisierten Typen wird dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierter Tensor pro Tensor | (C1–C7) |
(I2) | scale |
Eindimensionaler Tensor vom Typ „Gleitkomma“ oder „pro Tensor quantisiert“ | (C2), (C3) |
(I3) | offset |
1-dimensionaler Tensor eines Gleitkomma- oder pro Tensor quantisierten Typs | (C2), (C4) |
(I4) | mean |
1-dimensionaler Tensor eines Gleitkomma- oder pro Tensor quantisierten Typs | (C5) |
(I5) | variance |
1-dimensionaler Tensor eines Gleitkomma- oder pro Tensor quantisierten Typs | (C2), (C6) |
(I6) | epsilon |
Konstante vom Typ f32 |
|
(I7) | feature_index |
Konstante vom Typ si64 |
(C1), (C3–C6) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C2), (C7) |
Einschränkungen
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
undresult
haben dieselbebaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantik
Mittelwert und Varianz werden für alle Dimensionen außer der feature_index
-Dimension berechnet. Der operand
-Tensor wird normalisiert, wodurch die Tensoren output
, batch_mean
und batch_var
entstehen. Formeller ausgedrückt kann dieser Vorgang als Zerlegung in vorhandene StableHLO-Vorgänge mit der Python-Syntax so ausgedrückt werden:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Bei quantisierten Typen wird dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierter Tensor pro Tensor | (C1) |
(I2) | scale |
1-dimensionaler Tensor mit Gleitkommazahlen oder pro Tensor quantisiert | (C2), (C3) |
(I3) | offset |
1-dimensionaler Tensor mit Gleitkommazahlen oder pro Tensor quantisiert | (C2), (C4) |
(I4) | epsilon |
Konstante vom Typ f32 |
(C1), (C3-C6) |
(I5) | feature_index |
Konstante vom Typ si64 |
(C1), (C3–C6) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
output |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C7) |
batch_mean |
1-dimensionaler Tensor mit Gleitkommazahlen oder pro Tensor quantisiert | (C2), (C5) |
batch_var |
1-dimensionaler Tensor mit Gleitkommazahlen oder pro Tensor quantisiert | (C2), (C6) |
Einschränkungen
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
undoutput
haben dieselbebaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Beispiele
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantik
Führt eine Bitcast-Operation für den operand
-Tensor aus und erzeugt einen result
-Tensor, bei dem die Bits des gesamten operand
-Tensors mit dem Typ des result
-Tensors neu interpretiert werden.
Formeller gesprochener Text bei E = element_type(operand)
, E' = element_type(result)
und R = rank(operand)
:
- Wenn
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Wenn
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Wenn
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
gibt die In-Memory-Darstellung eines bestimmten Werts zurück. Das Verhalten ist implementierungsabhängig, da die genaue Darstellung von Tensoren und Elementtypen implementierungsabhängig ist.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1-C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1-C2) |
Einschränkungen
- (C1) Für
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
undR = rank(operand)
:- Wenn
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Wenn
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
für alle0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Wenn
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
für alle0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Wenn
- (C2) Wenn
is_complex(operand) or is_complex(result)
, dannis_complex(operand) and is_complex(result)
.
Beispiele
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantik
Erweitert die Dimensionen und/oder den Rang eines Eingabetensors, indem die Daten im operand
-Tensor dupliziert werden. Es wird ein result
-Tensor erzeugt. Formal gilt: result[result_index] = operand[operand_index]
, wo für alle d
in axes(operand)
Folgendes gilt:
operand_index[d] = 0
, wenndim(operand, d) = 1
.- Andernfalls
operand_index[d] = result_index[broadcast_dimensions[d]]
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1–C2), (C5–C6) |
(I2) | broadcast_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2 bis C6) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1), (C3), (C5-C6) |
Einschränkungen
- (C1)
element_type(result)
ist gegeben durch:element_type(operand)
, wenn!is_per_axis_quantized(operand)
.element_type(operand)
, wobei sichquantization_dimension(operand)
,scales(operand)
undzero_points(operand)
vonquantization_dimension(result)
,scales(result)
undzero_points(result)
unterscheiden können.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Für alle
d
inaxes(operand)
:dim(operand, d) = 1
oderdim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Wenn
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Wenn
dim(operand, quantization_dimension(operand)) = 1
, dannscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Beispiele
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
Supportanfrage
Semantik
Erzeugt die Ausgabe aus der Ausführung genau einer Funktion aus branches
, je nach Wert von index
. Formeller ausgedrückt: result = selected_branch()
wobei:
selected_branch = branches[index]
, wenn0 <= index < size(branches)
.- Andernfalls
selected_branch = branches[-1]
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | index |
Nulldimensionaler Tensor vom Typ si32 |
|
(I2) | branches |
Variadische Anzahl von Funktionen | (C1-C4) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens | (C4) |
Einschränkungen
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Beispiele
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Semantik
Führt eine elementweise Kubikwurzel-Operation auf dem operand
-Tensor aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
rootn(x, 3)
aus IEEE-754. - Für komplexe Zahlen: komplexe kubische Wurzel.
- Für quantisierte Typen:
dequantize_op_quantize(cbrt, operand, type(result))
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
Ceil
Semantik
Führt eine elementweise Aufrundung des operand
-Tensors aus und erzeugt einen result
-Tensor.
Implementiert den roundToIntegralTowardPositive
-Vorgang aus der IEEE-754-Spezifikation. Bei quantisierten Typen wird dequantize_op_quantize(ceil, operand, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Semantik
Berechnet die Cholesky-Zerlegung einer Reihe von Matrizen.
Formeller ausgedrückt: Für alle i
in index_space(result)
ist result[i0, ..., iR-3, :, :]
eine Cholesky-Zerlegung von a[i0, ..., iR-3, :, :]
in Form einer unter- oder obertriangularen Matrix (wenn lower
true
oder false
ist).
Die Ausgabewerte im gegenüberliegenden Dreieck, d. h. im strikten oberen oder unteren Dreieck, sind implementierungsabhängig.
Wenn es i
gibt, bei dem die Eingabematrix keine hermitische positiv-definierte Matrix ist, ist das Verhalten nicht definiert.
Bei quantisierten Typen wird dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | a |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1-C3) |
(I2) | lower |
Nulldimensionale Tensorkonstante vom Typ i1 |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Beispiele
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
einschränken
Semantik
Hiermit wird jedes Element des operand
-Tensors zwischen einem Mindest- und einem Höchstwert begrenzt und ein result
-Tensor erzeugt. Formal gilt: result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
, wobei min_element = rank(min) = 0 ? min[] : min[result_index]
und max_element = rank(max) = 0 ? max[] : max[result_index]
sind. Bei quantisierten Typen wird dequantize_op_quantize(clamp, min, operand, max, type(result))
ausgeführt.
Die Sortierung komplexer Zahlen führt zu überraschenden Ergebnissen. Daher werden wir die Unterstützung für komplexe Zahlen für diesen Vorgang in Zukunft entfernen (#560).
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | min |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C3) |
(I2) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1–C4) |
(I3) | max |
Tensor oder pro Tensor quantisierter Tensor | (C2), (C3) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C4) |
Einschränkungen
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantik
Senden Sie innerhalb jeder Prozessgruppe im StableHLO-Prozessraster den Wert des Tensors operand
vom Quellprozess an die Zielprozesse und erzeugen Sie einen result
-Tensor.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, was so definiert ist:
cross_replica(replica_groups)
, wennchannel_id <= 0
.cross_partition(replica_groups)
, wennchannel_id > 0
.
Danach ergibt sich result@process
aus:
operand@process_groups[i, 0]
, wenn es einei
gibt, sodass sich der Prozess inprocess_groups[i]
befindet.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
Andernfalls
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C3) |
(I2) | replica_groups |
Variadische Anzahl von eindimensionalen Tensorkonstanten vom Typ si64 |
(C1), (C2) |
(I3) | channel_id |
Konstante vom Typ si64 |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C3) |
Einschränkungen
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, wobeiN
so definiert ist:num_replicas
, wenncross_replica
verwendet wird.num_partitions
, wenncross_partition
verwendet wird.
- (C3)
type(result) = type(operand)
.
Beispiele
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantik
Sendet innerhalb jeder Prozessgruppe im StableHLO-Prozess-Raster den Wert des operand
-Tensors vom Quellprozess an den Zielprozess und erzeugt einen result
-Tensor.
Dabei wird das StableHLO-Prozess-Raster in process_groups
aufgeteilt, was so definiert ist:
cross_replica(source_target_pairs)
, wennchannel_id <= 0
.cross_partition(source_target_pairs)
wennchannel_id > 0
.
Danach ergibt sich result@process
aus:
operand@process_groups[i, 0]
, wenn eini
vorhanden ist, sodassprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
Andernfalls
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C5) |
(I2) | source_target_pairs |
2-dimensionale Tensorkonstante vom Typ si64 |
(C1-C4) |
(I3) | channel_id |
Konstante vom Typ si64 |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, wobeiN
so definiert ist:num_replicas
, wenncross_replica
verwendet wird.num_partitions
, wenncross_partition
verwendet wird.
- (C5)
type(result) = type(operand)
.
Beispiele
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
vergleichen
Semantik
Führt einen elementweisen Vergleich der Tensoren lhs
und rhs
gemäß comparison_direction
und compare_type
durch und erzeugt einen result
-Tensor.
Die Werte von comparison_direction
und compare_type
haben die folgende Semantik:
Für boolesche und Ganzzahlelementtypen:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Für Gleitkommaelementtypen mit compare_type = FLOAT
implementiert der Operator die folgenden IEEE-754-Operationen:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Für Gleitkomma-Elementtypen mit compare_type = TOTALORDER
wird die Kombination aus totalOrder
- und compareQuietEqual
-Vorgängen aus IEEE-754 verwendet.
Bei komplexen Elementtypen wird der lexikografische Vergleich von (real, imag)
-Paaren anhand der angegebenen comparison_direction
- und compare_type
-Werte durchgeführt.
Die Sortierung komplexer Zahlen führt zu überraschenden Ergebnissen. Daher werden wir in Zukunft die Unterstützung für komplexe Zahlen entfernen, wenn comparison_direction
= GE
, GT
, LE
oder LT
ist (#560).
Für quantisierte Typen wird dequantize_compare(lhs, rhs,
comparison_direction)
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor oder pro Tensor quantisierter Tensor | (C1-C3) |
(I2) | rhs |
Tensor oder pro Tensor quantisierter Tensor | (C1-C2) |
(I3) | comparison_direction |
enum von EQ , NE , GE , GT , LE und LT |
|
(I4) | compare_type |
enum von FLOAT , TOTALORDER , SIGNED und UNSIGNED |
(C3) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor vom Typ „Boolescher Wert“ | (C2) |
Einschränkungen
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
ist definiert als:SIGNED
, wennis_signed_integer(element_type(lhs))
.UNSIGNED
wennis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
oderTOTALORDER
, wennis_float(element_type(lhs))
.FLOAT
, wennis_complex(element_type(lhs))
.
Beispiele
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
Komplex
Semantik
Führt eine elementweise Umwandlung in einen komplexen Wert aus einem Paar reeller und imaginärer Werte, lhs
und rhs
, aus und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor vom Typ f32 oder f64 |
(C1-C3) |
(I2) | rhs |
Tensor vom Typ f32 oder f64 |
(C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor des komplexen Typs | (C2), (C3) |
Einschränkungen
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
hat den Typcomplex<E>
, wobeiE = element_type(lhs)
.
Beispiele
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
Zusammengesetzt
Semantik
Kapselt einen Vorgang ein, der aus anderen StableHLO-Vorgängen besteht, wobei inputs
und composite_attributes
als Eingabe und results
als Ausgabe dienen. Die Semantik der Operation wird durch das Attribut decomposition
implementiert. Die composite
-Operation kann durch ihre Dekomposition ersetzt werden, ohne die Programmsemantik zu ändern. Wenn die Einbettung der Dekomposition nicht dieselbe Operatorsemantik bietet, verwenden Sie custom_call
.
Das Feld version
(Standardwert: 0
) gibt an, wann sich die Semantik eines Composites ändert.
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | inputs |
variadische Anzahl von Werten |
(I2) | name |
Konstante vom Typ string |
(I3) | composite_attributes |
Attributwörterbuch |
(I4) | decomposition |
Konstante vom Typ string |
(I5) | version |
Konstante vom Typ si32 |
Ausgaben
Name | Typ |
---|---|
results |
Variadische Anzahl von Werten |
Einschränkungen
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Beispiele
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Semantik
Verkettet inputs
entlang der Dimension dimension
in derselben Reihenfolge wie die angegebenen Argumente und erzeugt einen result
-Tensor. Formeller ausgedrückt:
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, wobei:
id = d0 + ... + dk-1 + kd
.d
ist gleichdimension
undd0
sind died
. Dimensionsgrößen voninputs
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C1-C6) |
(I2) | dimension |
Konstante vom Typ si64 |
(C2), (C4), (C6) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C5 bis C6) |
Einschränkungen
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
, mit Ausnahme vondim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
, mit folgenden Ausnahmen:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Beispiele
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
Konstante
Semantik
Erzeugt einen output
-Tensor aus einer Konstante value
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | value |
Konstante | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
output |
Tensor oder quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
type(value) = type(output)
.
Beispiele
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
eine Conversion ausführen
Semantik
Führt eine elementweise Konvertierung von einem Elementtyp in einen anderen mit dem operand
-Tensor durch und erzeugt einen result
-Tensor.
Bei Conversions vom Typ boolean-to-any-supported-type wird der Wert false
in null und der Wert true
in eins konvertiert. Bei Konvertierungen vom Typ any-supported-type-to-boolean wird ein Nullwert in false
und Werte, die nicht null sind, in true
konvertiert. Unten finden Sie weitere Informationen dazu, wie das bei komplexen Typen funktioniert.
Bei Conversions vom Typ Ganzzahl zu Ganzzahl, Ganzzahl zu Gleitkomma oder Gleitkomma zu Gleitkomma wird der Ergebniswert genau so dargestellt, wie der Quellwert im Zieltyp dargestellt werden kann. Andernfalls ist das Verhalten noch nicht festgelegt (#180).
Bei Conversions vom Typ Gleitkomma in Ganzzahl wird der Bruchteil abgeschnitten. Wenn der abgeschnittene Wert nicht im Zieltyp dargestellt werden kann, ist das Verhalten noch nicht festgelegt (#180).
Bei der Konvertierung von komplexen Zahlen in komplexe Zahlen werden reelle und imaginäre Teile wie bei der Konvertierung von Gleitkommazahlen in Gleitkommazahlen konvertiert.
Bei Umwandlungen vom Typ Komplex in einen beliebigen anderen Typ und Beliebiger anderer Typ in komplex wird der Imaginäre Wert der Quelle bzw. des Ziels ignoriert. Die Umwandlung des reellen Teils erfolgt nach den Regeln für Gleitkommaumwandlungen.
Im Prinzip könnte dieser Vorgang die Entquantisierung (Umwandlung von quantisierten Tensoren in reguläre Tensoren), die Quantisierung (Umwandlung von regulären Tensoren in quantisierte Tensoren) und die Neuquantisierung (Umwandlung zwischen quantisierten Tensoren) ausdrücken. Derzeit haben wir jedoch spezielle Vorgänge dafür: uniform_dequantize
für den ersten Anwendungsfall und uniform_quantize
für den zweiten und dritten Anwendungsfall. In Zukunft werden diese beiden Vorgänge möglicherweise in convert
zusammengeführt (#1576).
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor | (C1) |
Einschränkungen
- (C1)
shape(operand) = shape(result)
.
Beispiele
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
Faltung
Semantik
Hiermit werden Punktprodukte zwischen Fenstern von lhs
und Scheiben von rhs
berechnet und result
ausgegeben. Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result
aus lhs
und rhs
berechnet werden.
Formeller ausgedrückt: Betrachten Sie die folgenden Umformulierungen der Eingaben in Bezug auf lhs
, um Zeitfenster für lhs
angeben zu können:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Für diese Neuausrichtung werden die folgenden Hilfsfunktionen verwendet:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
, wobeij[d] = i[permutation[d]]
.
Wenn feature_group_count = 1
und batch_group_count = 1
, dann für alle output_spatial_index
in index_space(dim(result, output_spatial_dimensions...))
result[result_shape(:, output_spatial_index, :)] = dot_product
, wobei:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Dieses Feature wird anscheinend noch nicht verwendet, deshalb planen wir, es in Zukunft zu entfernen (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Wenn feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Wenn batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Für quantisierte Typen wird dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
ausgeführt.
Für hybrid quantisierte Typen wird hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
(I2) | rhs |
Tensor oder quantisierter Tensor | (C1), (C14–C16), (C25), (C27–C29), (C31–C34) |
(I3) | window_strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2-C3), (C25) |
(I4) | padding |
2-dimensionale Tensorkonstante vom Typ si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Eindimensionale Tensorkonstante vom Typ si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Eindimensionale Tensorkonstante vom Typ si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Eindimensionale Tensorkonstante vom Typ i1 |
(C9) |
(I8) | input_batch_dimension |
Konstante vom Typ si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
Konstante vom Typ si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
Konstante vom Typ si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
Konstante vom Typ si64 |
(C15-C16), (C18), (C25), (C29) |
(I13) | kernel_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
Konstante vom Typ si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
Konstante vom Typ si64 |
(C20), (C25), (C30) |
(I16) | output_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
Konstante vom Typ si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
Konstante vom Typ si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
Variadische Anzahl von DEFAULT -, HIGH - und HIGHEST -Enum-Typen |
(C24) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C25–C28), (C30), (C32–34) |
Einschränkungen
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Gegeben ist
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Gegeben ist
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Gegeben ist
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
wird so definiert:dim(lhs, input_batch_dimension) / batch_group_count
, wennresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
, wennresult_dim = output_feature_dimension
.num_windows
, andernfalls gilt:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Wenn für die Operation nicht quantisierte Tensoren verwendet werden:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Wenn für den Vorgang quantisierte Tensoren verwendet werden:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Wenn
is_per_axis_quantized(rhs)
, dannquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Wenn
is_per_axis_quantized(result)
, dannquantization_dimension(result) = output_feature_dimension
. - Wenn
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Wenn
is_per_tensor_quantized(rhs)
, dannis_per_tensor_quantized(result)
. - Wenn
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Beispiele
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
Kosinus
Semantik
Führt eine elementweise Kosinus-Operation auf den Tensor operand
aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
cos
von IEEE-754. - Bei komplexen Zahlen: komplexer Kosinus
- Für quantisierte Typen:
dequantize_op_quantize(cosine, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
leading_zeros_count
Semantik
Führt eine elementweise Zählung der Anzahl der führenden Nullbits im operand
-Tensor aus und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Ganzzahltensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahltensor | (C1) |
Einschränkungen
- (C1)
type(operand) = type(result)
.
Beispiele
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantik
Kapselt einen implementierungsdefinierten Vorgang call_target_name
ein, der inputs
und called_computations
als Eingaben nimmt und results
als Ausgabe liefert. has_side_effect
, backend_config
und api_version
können verwendet werden, um zusätzliche implementierungsdefinierte Metadaten anzugeben.
Derzeit enthält dieser Vorgang eine ziemlich unorganisierte Sammlung von Metadaten, die die organische Entwicklung des entsprechenden Vorgangs im XLA-Compiler widerspiegelt. Wir planen, diese Metadaten in Zukunft zu vereinheitlichen (#741).
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | inputs |
variadische Anzahl von Werten |
(I2) | call_target_name |
Konstante vom Typ string |
(I3) | has_side_effect |
Konstante vom Typ i1 |
(I4) | backend_config |
Konstante vom Typ string oder Attributwörterbuch |
(I5) | api_version |
Konstante vom Typ si32 |
(I6) | called_computations |
variadische Anzahl von Konstanten vom Typ string |
Ausgaben
Name | Typ |
---|---|
results |
Variadische Anzahl von Werten |
Beispiele
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
Dividieren
Semantik
Führt eine elementweise Division des Dividenden-Tensors lhs
durch den Divisor-Tensor rhs
aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Ganzzahlen: Ganzzahldivision, die den algebraischen Quotienten mit verworfenen Bruchteilen erzeugt.
- Für Gleitkommazahlen:
division
von IEEE-754. - Für komplexe Zahlen: komplexe Division.
- Für quantisierte Typen:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantik
Berechnet Punktprodukte zwischen Scheiben von lhs
und Scheiben von rhs
und erzeugt einen result
-Tensor.
Formeller result[result_index] = dot_product
, wobei Folgendes gilt:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
wobeisize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
undsize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Für quantisierte Typen wird dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
ausgeführt.
Führt für hybrid quantisierte Typen hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
aus.
precision_config
steuert den Kompromiss zwischen Geschwindigkeit und Genauigkeit bei Berechnungen in Accelerator-Back-Ends. Dies kann einer der folgenden sein (derzeit ist die Semantik dieser Enum-Werte zu niedrig angegeben, aber wir planen, dies in #755 zu beheben):
DEFAULT
: Schnellste Berechnung, aber ungenaueste Näherung an die ursprüngliche Zahl.HIGH
: langsamere Berechnung, aber genauere Annäherung an die ursprüngliche Zahl.HIGHEST
: Die langsamste Berechnung, aber die genaueste Annäherung an die ursprüngliche Zahl.
Ein DotAlgorithm
definiert die Haupteigenschaften des Algorithmus, der zur Implementierung der Punktoperation verwendet wird, und definiert auch die Genauigkeit. Wenn die Attributfelder des Algorithmus festgelegt sind, muss precision_config
DEFAULT
sein. DotAlgorithms
haben keinen Standardwert, da die Standardparameter von der Implementierung definiert werden. Daher können alle Felder für den Punktalgorithmus auf None
gesetzt werden, um einen leeren Punktalgorithmus anzugeben, für den stattdessen der Wert precision_config
verwendet wird.
Zu den DotAlgorithm
-Feldern gehören:
lhs_precision_type
undrhs_precision_type
: die Genauigkeiten, auf die die rechte und linke Seite der Operation gerundet werden. Die Genauigkeitstypen sind unabhängig von den Speichertypen der Ein- und Ausgabe.accumulation_type
ist die für die Akkumulation verwendete Genauigkeit.lhs_component_count
,rhs_component_count
undnum_primitive_operations
werden angewendet, wenn ein Algorithmus die linke und/oder rechte Seite in mehrere Komponenten zerlegt und mehrere „primitive“ Punktprodukte auf diese Werte ausführt, in der Regel um eine höhere Genauigkeit zu emulieren (z. B. Der bfloat16-Datentyp für künstliche Intelligenz für Berechnungen mit höherer Genauigkeit: bf16_6x, tf32_3x usw.). Bei Algorithmen ohne Zerlegung sollten diese Werte auf1
gesetzt werden.allow_imprecise_accumulation
, um anzugeben, ob für einige Schritte eine Akkumulation mit niedrigerer Genauigkeit zulässig ist (z.B.CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Beispiele für DotAlgorithm
-Attribute:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Welche Kombinationen unterstützt werden, hängt von der Implementierung ab. Im Allgemeinen kann nicht garantiert werden, dass jeder Algorithmus von jedem Beschleunigertyp vom Nutzer des StableHLO unterstützt wird. Wenn ein bestimmter Algorithmus nicht unterstützt wird, sollte ein Fehler ausgegeben werden, anstatt auf eine Alternative zurückzugreifen. Die StableHLO-Überprüfung bietet eine Best-Effort-Überprüfung und verhindert Algorithmen, die bekanntermaßen auf keiner Hardware unterstützt werden.
Einige unterstützte Algorithmuswerte finden Sie unter xla_data.proto > Algorithm
. In Ticket 2483 wird der Plan zum Erstellen einer zentralen Dokumentation zu den unterstützten Algorithmen nach Backend beschrieben.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor oder pro Tensor quantisierter Tensor | (C5–C6), (C9–C10), (C12–C14), (C17–C18), (C20) |
(I2) | rhs |
Tensor oder quantisierter Tensor | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
Variadische Anzahl von DEFAULT -, HIGH - und HIGHEST -Enum-Typen |
(C11), (C21) |
(I8) | lhs_precision_type |
FloatType oder TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType oder TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType oder TensorFloat32 | (C21) |
(I11) | lhs_component_count |
Konstante vom Typ si32 |
(C21), (C22) |
(I12) | rhs_component_count |
Konstante vom Typ si32 |
(C21), (C23) |
(I13) | num_primitive_operations |
Konstante vom Typ si32 |
(C21), (C24) |
(I14) | allow_imprecise_accumulation |
Konstante vom Typ bool |
(C21) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C12), (C14), (C18-C20) |
Einschränkungen
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Wenn für die Operation nicht quantisierte Tensoren verwendet werden:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Wenn für den Vorgang quantisierte Tensoren verwendet werden:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Wenn
is_per_axis_quantized(rhs)
, dannquantization_dimension(rhs)
nicht inrhs_contracting_dimensions
. - Wenn
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Wenn
is_per_tensor_quantized(rhs)
, dannis_per_tensor_quantized(result)
. - Wenn
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Wenn
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Beispiele
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantik
Dieser Vorgang ist funktional mit der Operation broadcast_in_dim identisch. Die Ergebnisform wird jedoch dynamisch über output_dimensions
angegeben.
Der Vorgang akzeptiert auch die optionalen Attribute known_expanding_dimensions
und known_nonexpanding_dimensions
, um statisches Wissen über das Maximierungsverhalten von Dimensionen auszudrücken.
Wenn Sie keine Angabe machen, wird davon ausgegangen, dass alle Dimensionen möglicherweise erweitert werden.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Eindimensionaler Tensor vom Typ „Ganzzahl“ | (C7) |
(I3) | broadcast_dimensions |
1-dimensionaler konstanter Tensor vom Typ „Ganzzahl“ | (C2-C6) |
(I4) | known_expanding_dimensions |
1-dimensionaler konstanter Tensor vom Typ „Ganzzahl“ | (C8-C9) |
(I5) | known_nonexpanding_dimensions |
1-dimensionaler Konstantentensor vom Ganzzahltyp | (C8-C9) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1), (C3), (C5-C7) |
Einschränkungen
- (C1)
element_type(result)
ist gegeben durch:element_type(operand)
, wenn!is_per_axis_quantized(operand)
.element_type(operand)
, wobei sichquantization_dimension(operand)
,scales(operand)
undzero_points(operand)
vonquantization_dimension(result)
,scales(result)
undzero_points(result)
unterscheiden können.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Für alle
d
inaxes(operand)
:dim(operand, d) = 1
oderdim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Wenn
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Wenn
dim(operand, quantization_dimension(operand)) = 1
, dannscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
.
Beispiele
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantik
Funktional ist dieser Vorgang mit dem Vorgang Faltung identisch. Das Padding wird jedoch dynamisch über padding
angegeben.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
(I2) | rhs |
Tensor oder quantisierter Tensor | (C1), (C14–C16), (C26–C28), (C30–C33) |
(I3) | padding |
2-dimensionaler Tensor vom Typ „Ganzzahl“ | (C4) |
(I4) | window_strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2–C3) |
(I5) | lhs_dilation |
Eindimensionale Tensorkonstante vom Typ si64 |
(C5 bis C6) |
(I6) | rhs_dilation |
Eindimensionale Tensorkonstante vom Typ si64 |
(C7-C8) |
(I7) | window_reversal |
Eindimensionale Tensorkonstante vom Typ i1 |
(C9) |
(I8) | input_batch_dimension |
Konstante vom Typ si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
Konstante vom Typ si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C12), (C13) |
(I11) | kernel_input_feature_dimension |
Konstante vom Typ si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
Konstante vom Typ si64 |
(C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C17-C18) |
(I14) | output_batch_dimension |
Konstante vom Typ si64 |
(C20) |
(I15) | output_feature_dimension |
Konstante vom Typ si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C19-C20) |
(I17) | feature_group_count |
Konstante vom Typ si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
Konstante vom Typ si64 |
(C10), (C15), (C22), (C23) |
(I19) | precision_config |
Variadische Anzahl von DEFAULT -, HIGH - und HIGHEST -Enum-Typen |
(C24) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C25-C27), (C29), (C31-C33) |
Einschränkungen
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Gegeben ist
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Gegeben ist
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Gegeben ist
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
wird so definiert:dim(lhs, input_batch_dimension) / batch_group_count
, wennresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
, wennresult_dim = output_feature_dimension
.num_windows
, andernfalls gilt:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Wenn für die Operation nicht quantisierte Tensoren verwendet werden:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Wenn für den Vorgang quantisierte Tensoren verwendet werden:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Wenn
is_per_axis_quantized(rhs)
, dannquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Wenn
is_per_axis_quantized(result)
, dannquantization_dimension(result) = output_feature_dimension
. - Wenn
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Wenn
is_per_tensor_quantized(rhs)
, dannis_per_tensor_quantized(result)
. - Wenn
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Beispiele
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantik
Dieser Vorgang ist funktional mit dem Vorgang gather identisch. Dabei wird slice_sizes
dynamisch als Wert angegeben.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
Tensor des Ganzzahltyps | (C2), (C3), (C13) |
(I3) | slice_sizes |
Eindimensionaler Tensor vom Typ „Ganzzahl“ | (C8), (C11–C13) |
(I4) | offset_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Eindimensionale Tensorkonstante vom Typ si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
Konstante vom Typ si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
Konstante vom Typ i1 |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C5), (C13-C14) |
Einschränkungen
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
wobei:batch_dim_sizes = shape(start_indices)
, mit der Ausnahme, dass die Dimensionsgröße vonstart_indices
, dieindex_vector_dim
entspricht, nicht enthalten ist.offset_dim_sizes = shape(slice_sizes)
, mit der Ausnahme, dass die Dimensionsgrößen inslice_sizes
, diecollapsed_slice_dims
entsprechen, nicht enthalten sind.- Bei
combine
wirdbatch_dim_sizes
auf Achsen platziert, diebatch_dims
entsprechen, undoffset_dim_sizes
auf Achsen, dieoffset_dims
entsprechen.
- (C14)
element_type(operand) = element_type(result)
.
Beispiele
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantik
Dieser Vorgang ist funktional mit der Operation iota identisch. Die Ergebnisform wird jedoch dynamisch über output_shape
angegeben.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | output_shape |
Eindimensionaler Tensor vom Typ „Ganzzahl“ | (C1), (C2) |
(I2) | iota_dimension |
si64 |
(C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor von Ganzzahl, Gleitkomma oder komplexer Typ oder quantisierter Tensor pro Tensor | (C2) |
Einschränkungen
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Beispiele
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantik
Dieser Vorgang ist funktional mit dem Vorgang pad identisch, bei dem edge_padding_low
, edge_padding_high
und interior_padding
jedoch dynamisch als Werte angegeben werden.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Nulldimensionaler Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I3) | edge_padding_low |
Eindimensionaler Tensor vom Typ „Ganzzahl“ | (C1), (C4) |
(I4) | edge_padding_high |
Eindimensionaler Tensor vom Typ „Ganzzahl“ | (C1), (C4) |
(I5) | interior_padding |
Eindimensionaler Tensor vom Typ „Ganzzahl“ | (C2-C4) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C3–C6) |
Einschränkungen
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Beispiele
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantik
Funktional ist dieser Vorgang mit dem Vorgang Umformung identisch, aber die Ergebnisform wird dynamisch über output_shape
angegeben.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1-C3) |
(I2) | output_shape |
Eindimensionaler Tensor vom Typ „Ganzzahl“ | (C4) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1-C4) |
Einschränkungen
- (C1)
element_type(result)
ist gegeben durch:element_type(operand)
, wenn!is_per_axis_quantized(operand)
.element_type(operand)
, mit der Ausnahme, dass sichquantization_dimension(operand)
undquantization_dimension(result)
möglicherweise unterscheiden.
- (C2)
size(operand) = size(result)
. - (C3) Wenn
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Beispiele
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantik
Extrahiert ein Segment aus dem operand
mit dynamisch berechneten Startindexen und erzeugt einen result
-Tensor. start_indices
enthalten die Startindexe des Segments für jede Dimension, die einer möglichen Anpassung unterliegt, und slice_sizes
die Größen des Segments für jede Dimension. Formeller ausgedrückt:
result[result_index] = operand[operand_index]
wobei:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C2), (C4) |
(I2) | start_indices |
Variadische Anzahl von 0-dimensionalen Tensoren vom Typ „Ganzzahl“ | (C2), (C3) |
(I3) | slice_sizes |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C5) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C5) |
Einschränkungen
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Beispiele
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantik
Erzeugt einen result
-Tensor, der dem operand
-Tensor entspricht, mit der Ausnahme, dass das Segment, das bei start_indices
beginnt, mit den Werten in update
aktualisiert wird.
Formal ist result[result_index]
so definiert:
update[update_index]
wenn0 <= update_index < shape(update)
wobei:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- Andernfalls
operand[result_index]
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1–C4), (C6) |
(I2) | update |
Tensor oder pro Tensor quantisierter Tensor | (C2), (C3), (C6) |
(I3) | start_indices |
Variadische Anzahl von nulldimensionalen Tensoren vom Ganzzahltyp | (C4), (C5) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Beispiele
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
Exponentialfunktionen
Semantik
Führt eine elementweise exponentielle Operation auf den operand
-Tensor aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
exp
von IEEE-754. - Für komplexe Zahlen: komplexe Exponentialfunktion
- Für quantisierte Typen:
dequantize_op_quantize(exponential, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantik
Führt eine elementweise Exponential-Subtraktion-Eins-Operation auf den operand
-Tensor aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
expm1
von IEEE-754. - Für komplexe Zahlen: komplexe Exponentialzahlen minus eins.
- Für quantisierte Typen:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Semantik
Führt die Vorwärts- und Rückwärts-Fouriertransformation für reelle und komplexe Eingaben/Ausgaben aus.
fft_type
ist einer der folgenden Werte:
FFT
: Vorwärts-FFT von komplex zu komplex.IFFT
: Inverse komplexe-zu-komplexe FFT.RFFT
: Vorwärts-FFT (reell-zu-komplex).IRFFT
: Inverse reell-zu-komplexe FFT (d. h. nimmt komplexe Werte an, gibt reelle Werte zurück).
Formaler wird mit der Funktion fft
, die eindimensionale Tensoren komplexer Typen als Eingabe verwendet, eindimensionale Tensoren desselben Typs wie die Ausgabe erzeugt und die diskrete Fourier-Transformation berechnet:
Bei fft_type = FFT
wird result
als Endergebnis einer Reihe von L-Berechnungen definiert, bei denen L = size(fft_length)
. Zum Beispiel für L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Außerdem gibt es die Funktion ifft
mit derselben Typsignatur, die das Inverse von fft
berechnet:
Für fft_type = IFFT
ist result
als Kehrwert der Berechnungen für fft_type = FFT
definiert. Beispiel für L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Darüber hinaus erzeugt die Funktion rfft
, die eindimensionale Tensoren von Gleitkommatypen verwendet, eindimensionale Tensoren komplexer Typen derselben Gleitkommasemantik und funktioniert so:
rfft(real_operand) = truncated_result
wobeicomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
Wenn die diskrete Fourier-Transformation für reelle Operanden berechnet wird, definieren die ersten N/2 + 1
Elemente des Ergebnisses eindeutig den Rest des Ergebnisses. Daher wird das Ergebnis von rfft
abgeschnitten, um die Berechnung redundanter Elemente zu vermeiden.
Bei fft_type = RFFT
wird result
als Endergebnis einer Reihe von L-Berechnungen definiert, bei denen L = size(fft_length)
. Beispiel für L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Schließlich die Funktion irfft
mit derselben Typsignatur, die das Inverse von rfft
berechnet:
Bei fft_type = IRFFT
wird result
als Kehrwert der Berechnungen für fft_type = RFFT
definiert. Beispiel für L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor vom Gleitkomma- oder komplexen Typ | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
Aufzählung von FFT , IFFT , RFFT und IRFFT |
(C2), (C5) |
(I3) | fft_length |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C3), (C4) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor vom Gleitkomma- oder komplexen Typ | (C2), (C4), (C5) |
Einschränkungen
- (C1)
size(fft_length) <= rank(operand)
. - (C2) Die Beziehung zwischen den Elementtypen
operand
undresult
variiert:- Wenn
fft_type = FFT
,element_type(operand)
undelement_type(result)
denselben komplexen Typ haben. - Wenn
fft_type = IFFT
,element_type(operand)
undelement_type(result)
denselben komplexen Typ haben. - Wenn
fft_type = RFFT
,element_type(operand)
ein Gleitkommatyp ist undelement_type(result)
ein komplexer Typ mit derselben Gleitkommasemantik. - Wenn
fft_type = IRFFT
,element_type(operand)
ein komplexer Typ undelement_type(result)
ein Gleitkommatyp mit derselben Gleitkommasemantik ist.
- Wenn
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Wenn unter
operand
undresult
ein Tensorreal
eines Gleitkommatyps vorhanden ist, dannshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
, mit folgenden Ausnahmen:- Wenn
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Wenn
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Wenn
Beispiele
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
Boden
Semantik
Führt eine elementweise Untergrenze des operand
-Tensors aus und erzeugt einen result
-Tensor.
Implementiert den roundToIntegralTowardNegative
-Vorgang aus der IEEE-754-Spezifikation. Bei quantisierten Typen wird dequantize_op_quantize(floor, operand, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
zusammentragen
Semantik
Erfasst Scheiben aus dem operand
-Tensor anhand der in start_indices
angegebenen Abweichungen und erzeugt einen result
-Tensor.
Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result
auf Elemente in operand
abgebildet werden. Im Diagramm werden einige result
-Beispielindizes ausgewählt und es wird ausführlich erläutert, welchen operand
-Indizes sie entsprechen.
Formeller result[result_index] = operand[operand_index]
, wobei Folgendes gilt:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
wird so definiert:start_indices[bi0, ..., :, ..., biN]
, wobeibi
einzelne Elemente inbatch_index
sind und:
in den Indexindex_vector_dim
eingefügt wird, wennindex_vector_dim
<rank(start_indices)
ist.- Andernfalls
[start_indices[batch_index]]
.
- Für
d_operand
inaxes(operand)
:full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
wennd_operand = start_index_map[d_start]
.- Andernfalls
full_start_index[d_operand] = 0
.
- Für
d_operand
inaxes(operand)
:full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
wennd_operand = operand_batching_dims[i_batching]
undd_start = start_indices_batching_dims[i_batching]
.- Andernfalls
full_batching_index[d_operand] = 0
.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, wobeioi
einzelne Elemente inoffset_index
sind und0
an den Indexencollapsed_slice_dims
undoperand_batching_dims
eingefügt wird.operand_index = full_start_index + full_batching_index + full_offset_index
.
Wenn indices_are_sorted
true
ist, kann die Implementierung davon ausgehen, dass start_indices
nach start_index_map
sortiert ist. Andernfalls ist das Verhalten nicht definiert. Formeller ausgedrückt: Für alle i1 < i2
von indices(result)
,
full_start_index(i1) <= full_start_index(i2)
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
Ganzzahltensor | (C2-C3), (C14), (C17), (C22) |
(I3) | offset_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C6–C9), (C22) |
(I5) | operand_batching_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C13-C17) |
(I7) | start_index_map |
Eindimensionale Tensorkonstante vom Typ si64 |
(C3), (C18–C19) |
(I8) | index_vector_dim |
Konstante vom Typ si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
Eindimensionale Tensorkonstante vom Typ si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
Konstante vom Typ i1 |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C5), (C22-C23) |
Einschränkungen
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
wobei:batch_dim_sizes = shape(start_indices)
, mit der Ausnahme, dass die Dimensionsgröße vonstart_indices
, dieindex_vector_dim
entspricht, nicht enthalten ist.offset_dim_sizes = slice_sizes
mit dem Unterschied, dass die Dimensionsgrößen inslice_sizes
, diecollapsed_slice_dims
undoperand_batching_dims
entsprechen, nicht enthalten sind.combine
platziertbatch_dim_sizes
an den Achsenbatch_dims
undoffset_dim_sizes
an die Achsen, dieoffset_dims
entsprechen.
- (C23)
element_type(operand) = element_type(result)
.
Beispiele
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantik
Gibt die Größe der angegebenen dimension
der operand
zurück. Formeller ausgedrückt:
result = dim(operand, dimension)
. Die Semantik bezieht sich nur auf die Formkomponente des Typs. Der Elementtyp kann beliebig sein.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1) |
(I2) | dimension |
Konstante vom Typ si64 |
(C1) |
Ausgaben
Name | Typ |
---|---|
result |
Nulldimensionaler Tensor vom Typ si32 |
Einschränkungen
- (C1)
0 <= dimension < rank(operand)
.
Beispiele
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantik
Hiermit wird das Element an der Position index
des operand
-Tupels extrahiert und eine result
erstellt. Formeller: result = operand[index]
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tupel | (C1), (C2) |
(I2) | index |
Konstante vom Typ si32 |
(C1), (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Beliebiger unterstützter Typ | (C2) |
Einschränkungen
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Beispiele
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
wenn
Semantik
Erzeugt die Ausgabe aus der Ausführung genau einer Funktion aus true_branch
oder false_branch
, je nach Wert von pred
. Formeller: result =
pred ? true_branch() : false_branch()
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | pred |
Nulldimensionaler Tensor vom Typ i1 |
|
(I2) | true_branch |
Funktion | (C1–C3) |
(I3) | false_branch |
Funktion | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens | (C3) |
Einschränkungen
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Beispiele
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
Semantik
Extrahiert den imaginären Teil elementweise aus operand
und erzeugt einen result
-Tensor. Formeller ausgedrückt: Für jedes Element x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor des Gleitkommatyps | (C1), (C2) |
Einschränkungen
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
ist definiert als:complex_element_type(element_type(operand))
, wennis_complex(operand)
.- Andernfalls
element_type(operand)
.
Beispiele
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
Infeed
Semantik
Liest Daten aus dem Infeed und generiert results
.
Die Semantik von infeed_config
ist implementierungsabhängig.
results
besteht aus Nutzlastwerten, die zuerst kommen, und einem Token, das als Letztes kommt. Künftig werden wir die Nutzlast und das Token in zwei separate Ausgaben aufteilen, um für mehr Klarheit zu sorgen (#670).
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
Konstante vom Typ string |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens | (C1-C3) |
Einschränkungen
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
oderis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Beispiele
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
ITA
Semantik
Füllt einen output
-Tensor mit Werten in aufsteigender Reihenfolge ab null entlang der iota_dimension
-Dimension. Formeller gesprochen
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
output |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
0 <= iota_dimension < rank(output)
.
Beispiele
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantik
Es wird elementweise geprüft, ob der Wert in x
endlich ist (d. h. weder +Inf, -Inf noch NaN), und es wird ein y
-Tensor erstellt. Implementiert den isFinite
-Vorgang aus der IEEE-754-Spezifikation. Bei quantisierten Typen ist das Ergebnis immer true
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | x |
Tensor des Gleitkommatyps oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
y |
Tensor vom Typ „Boolescher Wert“ | (C1) |
Einschränkungen
- (C1)
shape(x) = shape(y)
.
Beispiele
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantik
Führt einen elementweisen Logarithmus-Vorgang auf den Tensor operand
aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
log
von IEEE-754. - Für komplexe Zahlen: komplexer Logarithmus.
- Für quantisierte Typen:
dequantize_op_quantize(log, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantik
Führt den elementweisen Logarithmus mit einer zusätzlichen Operation auf dem Tensor operand
aus und erzeugt einen Tensor vom Typ result
. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
logp1
von IEEE-754. - Bei komplexen Zahlen: Komplexer Logarithmus plus 1.
- Für quantisierte Typen:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
Logistik
Semantik
Führt einen elementweisen logistischer Vorgang auf den operand
-Tensor aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
division(1, addition(1, exp(-x)))
von IEEE-754. - Für komplexe Zahlen: komplexe Logistik.
- Für quantisierte Typen:
dequantize_op_quantize(logistic, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
Karte
Semantik
Wendet eine Abbildungsfunktion computation
auf inputs
entlang der dimensions
an und erzeugt einen result
-Tensor.
Formeller: result[result_index] = computation(inputs...[result_index])
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C1-C4) |
(I2) | dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C3) |
(I3) | computation |
Funktion | (C4) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C4) |
Einschränkungen
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
hat den Typ(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, wobeiEi = element_type(inputs[i])
undE' = element_type(result)
.
Beispiele
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
Maximum
Semantik
Führt einen elementweisen Max-Vorgang auf den Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Gehen Sie je nach Elementtyp so vor:
- Für boolesche Werte: Logisches OR.
- Für Ganzzahlen: Ganzzahlmaximum.
- Für Gleitkommazahlen:
maximum
von IEEE-754. - Bei komplexen Zahlen: Lexikalisches Maximum für das
(real, imaginary)
-Paar. Die Sortierung komplexer Zahlen führt zu überraschenden Ergebnissen. Daher werden wir die Unterstützung für komplexe Zahlen für diesen Vorgang in Zukunft entfernen (#560). - Für quantisierte Typen:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
Minimum
Semantik
Führt eine elementweise MIN-Operation auf den Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für boolesche Werte: Logisches AND.
- Für Ganzzahlen: Ganzzahlminimum.
- Für Gleitkommazahlen:
minimum
aus IEEE-754. - Bei komplexen Zahlen: Lexikalisches Minimum für das
(real, imaginary)
-Paar. Das Aufstellen einer Sortierung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, in Zukunft die Unterstützung für komplexe Zahlen für diesen Vorgang einzustellen (#560). - Für quantisierte Typen:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
multiplizieren
Semantik
Führt das elementweise Produkt von zwei Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Gehen Sie je nach Elementtyp so vor:
- Für boolesche Werte: Logisches AND.
- Für Ganzzahlen: Ganzzahlmultiplikation.
- Für Gleitkommazahlen:
multiplication
von IEEE-754. - Bei komplexen Zahlen: komplexe Multiplikation.
- Für quantisierte Typen:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
Semantik
Führt eine elementweise Negation des operand
-Tensors aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Bei vorzeichenbehafteten Ganzzahlen: Ganzzahlnegation.
- Für vorzeichenlose Ganzzahlen: Bitcast zu vorzeichenbehafteter Ganzzahl, Ganzzahlnegation, Bitcast zurück zu vorzeichenloser Ganzzahl.
- Für Gleitkommazahlen:
negate
von IEEE-754. - Bei komplexen Zahlen: komplexe Negation.
- Für quantisierte Typen:
dequantize_op_quantize(negate, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
nicht
Semantik
Führt elementweise NOT von Tensor operand
aus und erzeugt einen result
-Tensor.
Gehen Sie je nach Elementtyp so vor:
- Für boolesche Werte: logisches NOT.
- Bei Ganzzahlen: Bitweises NOT.
Argumente
Name | Typ | Einschränkungen |
---|---|---|
operand |
Tensor des booleschen oder Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor des booleschen oder Ganzzahltyps | (C1) |
Einschränkungen
- (C1)
type(operand) = type(result)
.
Beispiele
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantik
Sorgt dafür, dass die Vorgänge, die die operand
generieren, vor allen Vorgängen ausgeführt werden, die von der result
abhängen, und verhindert, dass Compilertransformationen Vorgänge über die Barriere hinweg verschieben. Ansonsten ist der Vorgang eine Identität, also result = operand
.
Argumente
Name | Typ | Einschränkungen |
---|---|---|
operand |
variadische Anzahl von Tensoren, pro Tensor quantisierte Tensoren oder Tokens | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
variadische Anzahl von Tensoren, pro Tensor quantisierte Tensoren oder Tokens | (C1) |
Einschränkungen
- (C1)
type(operand...) = type(result...)
.
Beispiele
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
oder
Semantik
Führt ein elementweises OR von zwei Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für boolesche Werte: Logisches OR.
- Bei Ganzzahlen: Bitweises OR.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor vom Typ „Ganzzahl“ oder „Boolescher Wert“ | (C1) |
(I2) | rhs |
Tensor vom Typ „Ganzzahl“ oder „Boolescher Wert“ | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor des Ganzzahl- oder booleschen Typs | (C1) |
Einschränkungen
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
Outfeed
Semantik
Schreibt inputs
in den Outfeed und generiert ein result
-Token.
Die Semantik von outfeed_config
ist implementierungsabhängig.
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | inputs |
Variadische Anzahl von Tensoren oder quantisierten Tensoren |
(I2) | token |
token |
(I3) | outfeed_config |
Konstante vom Typ string |
Ausgaben
Name | Typ |
---|---|
result |
token |
Beispiele
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
Pad
Semantik
Erweitert operand
, indem der Tensor und die Elemente des Tensors mit der angegebenen padding_value
umgeben werden.
Mit edge_padding_low
und edge_padding_high
wird die Größe des Abstands angegeben, der am unteren Ende (neben Index 0) bzw. am oberen Ende (neben dem höchsten Index) jeder Dimension hinzugefügt wird. Der Wert für den Abstand kann negativ sein. Der absolute Wert des negativen Abstands gibt die Anzahl der Elemente an, die aus der angegebenen Dimension entfernt werden sollen.
Mit interior_padding
wird der Abstand zwischen zwei Elementen in jeder Dimension angegeben. Er darf nicht negativ sein. Der Innenrand erfolgt vor dem Randrand, sodass Elemente aus dem mit Innenrand versehenen Operanden entfernt werden.
Formell ist result[result_index]
so definiert:
operand[operand_index]
, wennresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.- Andernfalls
padding_value
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Nulldimensionaler Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I3) | edge_padding_low |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4) |
(I5) | interior_padding |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2-C4) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C3–C6) |
Einschränkungen
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Beispiele
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantik
Erzeugt partition_id
des aktuellen Prozesses.
Ausgaben
Name | Typ |
---|---|
result |
0-dimensionaler Tensor vom Typ ui32 |
Beispiele
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
PopCnt
Semantik
Führt eine elementweise Zählung der im operand
-Tensor festgelegten Bits durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Ganzzahltensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahltensor | (C1) |
Einschränkungen
- (C1)
type(operand) = type(result)
.
Beispiele
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
Leistung
Semantik
Führt eine elementweise Exponentiation des lhs
-Tensors mit dem rhs
-Tensor aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Ganzzahlen: Ganzzahlexponentiation.
- Für Gleitkommazahlen:
pow
aus IEEE-754. - Für komplexe Zahlen: komplexe Exponentie
- Für quantisierte Typen:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semantik
Hiermit wird der reelle Teil elementweise aus dem operand
extrahiert und ein result
-Tensor erzeugt. Formeller ausgedrückt: Für jedes Element x
:
real(x) = is_complex(x) ? real_part(x) : x
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor des Gleitkommatyps | (C1), (C2) |
Einschränkungen
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
ist definiert als:complex_element_type(element_type(operand))
, wennis_complex(operand)
.- Andernfalls
element_type(operand)
.
Beispiele
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semantik
Empfängt Daten von einem Kanal mit channel_id
und erstellt results
.
Wenn is_host_transfer
true
ist, werden beim Vorgang Daten vom Host übertragen. Andernfalls werden Daten von einem anderen Gerät übertragen. Was das bedeutet, ist implementierungsabhängig. Dieses Flag dupliziert die Informationen unter channel_type
. Wir planen daher, in Zukunft nur eines davon beizubehalten (#666).
results
besteht aus Nutzlastwerten, die zuerst kommen, und einem Token, das als Letztes kommt. Künftig werden wir die Nutzlast und das Token in zwei separate Ausgaben aufteilen, um für mehr Klarheit zu sorgen (#670).
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
Konstante vom Typ si64 |
|
(I3) | channel_type |
Aufzählung von DEVICE_TO_DEVICE und HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
Konstante vom Typ i1 |
(C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens | (C2–C4) |
Einschränkungen
- (C1)
channel_type
ist definiert als:HOST_TO_DEVICE
wennis_host_transfer = true
,- Andernfalls
DEVICE_TO_DEVICE
.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
oderis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Beispiele
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semantik
Wendet eine Reduktionsfunktion body
auf inputs
und init_values
entlang der dimensions
an und erzeugt results
-Tensoren.
Die Reihenfolge der Reduktionen ist implementierungsdefiniert. Das bedeutet, dass body
und init_values
eine Monoide bilden müssen, um zu gewährleisten, dass der Vorgang für alle Eingaben in allen Implementierungen die gleichen Ergebnisse liefert. Diese Bedingung gilt jedoch nicht für viele beliebte Reduzierungen. Beispielsweise bilden die Addition von Gleitkommazahlen für body
und Null für init_values
kein Monoid, da die Addition von Gleitkommazahlen nicht assoziativ ist.
Formeller ausgedrückt: results...[j0, ..., jR-1] = reduce(input_slices_converted)
, wobei:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, wobei:
andimensions
eingefügt werden.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
für einen Binärbaumschedule
wobei:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
ist ein implementierungsdefinierter vollständiger Binärbaum, dessen In-Order-Durchlauf aus folgenden Schritten besteht:input_slices_converted...[index]
-Werte für alleindex
inindex_space(input_slices_converted)
in aufsteigender lexikografischer Reihenfolge vonindex
.- Dazwischen gibt es einen von der Implementierung definierten Betrag von
init_values_converted
an den von der Implementierung definierten Positionen.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1-C4), (C6), (C7) |
(I2) | init_values |
Variadische Anzahl von nulldimensionalen Tensoren oder pro Tensor quantisierte Tensoren | (C2), (C3) |
(I3) | dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C4), (C5), (C7) |
(I4) | body |
Funktion | (C6) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C3), (C7), (C8) |
Einschränkungen
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
hat den Typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, wobeiis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, mit der Ausnahme, dass die Dimensionsgrößen voninputs...
, diedimensions
entsprechen, nicht enthalten sind. - (C8)
element_type(results[i]) = Ei
für allei
in[0,N)
.
Beispiele
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantik
Führt eine elementweise Umwandlung von operand
in einen anderen Gleitkommatyp durch, der exponent_bits
und mantissa_bits
verwendet, und zurück zum ursprünglichen Gleitkommatyp. Dabei wird ein output
-Tensor erzeugt.
Formeller:
- Die Mantissenbits des ursprünglichen Werts werden so aktualisiert, dass der ursprüngliche Wert unter Verwendung der
roundToIntegralTiesToEven
-Semantik auf den nächsten mitmantissa_bits
dargestellten Wert gerundet wird. - Wenn
mantissa_bits
kleiner als die Anzahl der Mantissenbits des ursprünglichen Werts ist, werden die Mantissenbits aufmantissa_bits
gekürzt. - Wenn die Exponentenbits des Zwischenergebnisses nicht in den von
exponent_bits
angegebenen Bereich passen, kommt es zu einem Überlauf auf unendlich mit dem ursprünglichen Vorzeichen oder zu einem Unterlauf auf null mit dem ursprünglichen Vorzeichen. - Für quantisierte Typen wird
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierter Tensor pro Tensor | (C1) |
(I2) | exponent_bits |
Konstante vom Typ si32 |
(C2) |
(I3) | mantissa_bits |
Konstante vom Typ si32 |
(C3) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
output |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Beispiele
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantik
Führt in jeder Prozessgruppe im StableHLO-Prozessraster eine Reduzierung mit computations
über die Werte des Tensors operand
aus jedem Prozess durch, teilt das Reduktionsergebnis entlang von scatter_dimension
in Teile auf und verteilt die Teilteile auf die Prozesse, um das result
zu erzeugen.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, was so definiert ist:
cross_replica(replica_groups)
wennchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
wennchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
wennchannel_id > 0 and use_global_device_ids = true
.
Danach gilt in jeder process_group
Folgendes:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
für allesender
inprocess_group
, wobeireceiver_index = process_group.index(receiver)
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
Konstante vom Typ si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
2-dimensionale Tensorkonstante vom Typ si64 |
(C3-C5) |
(I4) | channel_id |
Konstante vom Typ si64 |
(C6) |
(I5) | use_global_device_ids |
Konstante vom Typ i1 |
(C6) |
(I6) | computation |
Funktion | (C7) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C8-C9) |
Einschränkungen
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
wird so definiert:num_replicas
, wenncross_replica
verwendet wird.num_replicas
, wenncross_replica_and_partition
verwendet wird.num_processes
, wennflattened_ids
verwendet wird.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Wenn
use_global_device_ids = true
, dannchannel_id > 0
. - (C7)
computation
hat den Typ(tensor<E>, tensor<E>) -> (tensor<E>)
, wobeiis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
, mit folgenden Ausnahmen:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Beispiele
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantik
Wendet eine Reduktionsfunktion body
auf Fenster von inputs
und init_values
an und gibt results
zurück.
Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in results...
aus inputs...
berechnet werden.
Formal gilt Folgendes: results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(siehe Reduzieren), wobei Folgendes gilt:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
Variadische Anzahl von nulldimensionalen Tensoren oder pro Tensor quantisierte Tensoren | (C1), (C13) |
(I3) | window_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Eindimensionale Tensorkonstante vom Typ si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Eindimensionale Tensorkonstante vom Typ si64 |
(C10), (C11), (C15) |
(I7) | padding |
2-dimensionale Tensorkonstante vom Typ si64 |
(C12), (C15) |
(I8) | body |
Funktion | (C13) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C1), (C14–C16) |
Einschränkungen
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
hat den Typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, wobeiis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
wobei:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
für allei
in[0,N)
.
Beispiele
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
Rest
Semantik
Er führt elementweise den Rest der Division von Dividend lhs
und Divisor rhs
durch und erzeugt einen result
-Tensor.
Formal wird das Vorzeichen des Ergebnisses von der Dividende entnommen, und der absolute Wert des Ergebnisses ist immer kleiner als der absolute Wert des Divisors.
Der Rest wird als lhs - d * rhs
berechnet, wobei d
durch Folgendes gegeben ist:
- Für Ganzzahlen:
stablehlo.divide(lhs, rhs)
. - Für Gleitkommazahlen:
division(lhs, rhs)
nach IEEE-754 mit dem RundungsattributroundTowardZero
. - Für komplexe Zahlen: TBD (#997).
- Für quantisierte Typen:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Bei Gleitkommaelementtypen steht dieser Vorgang im Gegensatz zum remainder
-Vorgang der IEEE-754-Spezifikation, bei dem d
der Ganzzahlwert ist, der dem genauen Wert von lhs/rhs
am nächsten kommt, wobei bei Gleichstand eine gerade Zahl gewählt wird.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantik
Erzeugt replica_id
des aktuellen Prozesses.
Ausgaben
Name | Typ |
---|---|
result |
0-dimensionaler Tensor vom Typ ui32 |
Beispiele
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
reshape
Semantik
Führt die Umformung des operand
-Tensors in einen result
-Tensor durch. Konzeptionell bedeutet dies, dass dieselbe kanonische Darstellung beibehalten wird, aber möglicherweise die Form geändert wird, z.B. von tensor<2x3xf32>
zu tensor<3x2xf32>
oder tensor<6xf32>
.
Formeller ausgedrückt: result[result_index] = operand[operand_index]
, wobei result_index
und operand_index
dieselbe Position in der lexikografischen Reihenfolge von index_space(result)
und index_space(operand)
haben.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1-C3) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1–C3) |
Einschränkungen
- (C1)
element_type(result)
ist gegeben durch:element_type(operand)
, wenn!is_per_axis_quantized(operand)
.element_type(operand)
, mit der Ausnahme, dass sichquantization_dimension(operand)
undquantization_dimension(result)
möglicherweise unterscheiden.
- (C2)
size(operand) = size(result)
. - (C3) Wenn
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Beispiele
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
reverse
Semantik
Kehrt die Reihenfolge der Elemente im operand
entlang der angegebenen dimensions
um und erzeugt einen result
-Tensor. Formeller ausgedrückt:
result[result_index] = operand[operand_index]
wobei:
operand_index[d] = dim(result, d) - result_index[d] - 1
wennd
indimensions
.- Andernfalls
operand_index[d] = result_index[d]
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1), (C3) |
(I2) | dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C3) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C3) |
Einschränkungen
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Beispiele
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
Rng
Semantik
Generiert Zufallszahlen mit dem rng_distribution
-Algorithmus und erzeugt einen result
-Tensor mit einer bestimmten Form shape
.
Wenn rng_distribution = UNIFORM
, werden die Zufallszahlen gemäß der Gleichverteilung über das Intervall [a, b)
generiert. Bei a >= b
ist das Verhalten nicht definiert.
Wenn rng_distribution = NORMAL
, werden die Zufallszahlen gemäß der Normalverteilung mit Mittelwert = a
und Standardabweichung = b
generiert.
Wenn b < 0
, ist das Verhalten nicht definiert.
Die genaue Art und Weise, wie Zufallszahlen generiert werden, ist implementierungsdefiniert. Sie können beispielsweise deterministisch oder nicht deterministisch sein und einen versteckten Zustand verwenden oder nicht.
In Gesprächen mit vielen Stakeholdern wurde festgestellt, dass diese Operation praktisch nicht mehr verwendet wird. Wir planen daher, sie in Zukunft zu entfernen (#597).
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | a |
0-dimensionaler Tensor vom Typ „Ganzzahl“, „Boolescher Wert“ oder „Gleitkommazahl“ | (C1), (C2) |
(I2) | b |
0-dimensionaler Tensor vom Typ „Ganzzahl“, „Boolescher Wert“ oder „Gleitkommazahl“ | (C1), (C2) |
(I3) | shape |
Eindimensionale Tensorkonstante vom Typ si64 |
(C3) |
(I4) | rng_distribution |
Aufzählung von UNIFORM und NORMAL |
(C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor vom Typ „Ganzzahl“, „Boolescher Wert“ oder „Gleitkomma“ | (C1-C3) |
Einschränkungen
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Wenn
rng_distribution = NORMAL
, dannis_float(a)
. - (C3)
shape(result) = shape
.
Beispiele
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantik
Gibt eine output
mit gleichmäßig verteilten Zufallsbits und einen aktualisierten Ausgabestatus output_state
mit dem Pseudozufallszahlengenerator rng_algorithm
bei einem Anfangsstatus initial_state
zurück. Die Ausgabe ist garantiert eine deterministische Funktion von initial_state
, aber nicht unbedingt zwischen den Implementierungen.
rng_algorithm
ist einer der folgenden Werte:
DEFAULT
: Implementierungsdefinierter Algorithmus.THREE_FRY
: Implementierungsdefinierte Variante des Threefry-Algorithmus.*PHILOX
: Implementierungsdefinierte Variante des Philox-Algorithmus.*
* Siehe Salmon et al. SC 2011. Parallele Zufallszahlen: So einfach wie das Einmaleins.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | rng_algorithm |
Aufzählung von DEFAULT , THREE_FRY und PHILOX |
(C2) |
(I2) | initial_state |
Eindimensionaler Tensor vom Typ ui64 |
(C1), (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
output_state |
Eindimensionaler Tensor vom Typ ui64 |
(C1) |
output |
Tensor vom Typ „Ganzzahl“ oder „Gleitkomma“ |
Einschränkungen
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
ist definiert als:- Implementierungsabhängig, wenn
rng_algorithm = DEFAULT
. 2
, wennrng_algorithm = THREE_FRY
.2
oder3
, wennrng_algorithm = PHILOX
.
- Implementierungsabhängig, wenn
Beispiele
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantik
Führt eine elementweise Rundung auf die nächste Ganzzahl durch, wobei Bindungen von Null entfernt werden, auf dem Tensor operand
und erzeugt einen result
-Tensor. Implementiert den roundToIntegralTiesToAway
-Vorgang aus der IEEE-754-Spezifikation. Bei quantisierten Typen wird dequantize_op_quantize(round_nearest_afz, operand, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantik
Führt eine elementweise Rundung auf die nächste ganze Zahl aus. Bei Unentschieden wird die gerade Ganzzahl gewählt. Der operand
-Tensor wird in einen result
-Tensor umgewandelt. Implementiert den roundToIntegralTiesToEven
-Vorgang aus der IEEE-754-Spezifikation. Bei quantisierten Typen wird dequantize_op_quantize(round_nearest_even, operand, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor vom Gleitkommatyp oder quantisierter Tensor pro Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantik
Führt eine elementweise reziproke Quadratwurzeloperation für den operand
-Tensor durch und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
rSqrt
aus IEEE-754. - Für komplexe Zahlen: komplexe reziproke Quadratwurzel.
- Für quantisierte Typen:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
Streuung
Semantik
Erzeugt results
-Tensoren, die inputs
-Tensoren entsprechen, mit der Ausnahme, dass mehrere durch scatter_indices
angegebene Segmente mit den Werten updates
mithilfe von update_computation
aktualisiert werden.
Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in updates...
auf Elemente in results...
abgebildet werden. Im Diagramm werden einige Beispielupdates...
-Indizes ausgewählt und es wird ausführlich erläutert, welchen results...
-Indizes sie entsprechen.
Formeller für alle update_index
in index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
wird so definiert:scatter_indices[si0, ..., :, ..., siN]
, wobeisi
einzelne Elemente inupdate_scatter_index
sind und:
an der Positionindex_vector_dim
eingefügt wird, wennindex_vector_dim
<rank(scatter_indices)
.- Andernfalls
[scatter_indices[update_scatter_index]]
.
- Für
d_input
inaxes(inputs[0])
:full_start_index[d_input] = start_index[d_start]
wennd_input = scatter_dims_to_operand_dims[d_start]
.- Andernfalls
full_start_index[d_input] = 0
.
- Für
d_input
inaxes(inputs[0])
:full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
wennd_input = input_batching_dims[i_batching]
undd_start = scatter_indices_batching_dims[i_batching]
.- Andernfalls
full_batching_index[d_input] = 0
.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, wobeiwi
einzelne Elemente inupdate_window_index
sind und0
an den Indexen ausinserted_window_dims
undinput_batching_dims
eingefügt wird.result_index = full_start_index + full_batching_index + full_window_index
.
Daher gilt: results = exec(schedule, inputs)
, wobei:
schedule
ist eine implementierungsdefinierte Permutation vonindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
wobei:- Wenn
result_index
innerhalb der Grenzen fürshape(results...)
liegt updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
ist eine Kopie vonresults
, bei derresults...[result_index]
aufupdated_values...
gesetzt ist.- Andernfalls:
updated_results = results
.
- Wenn
exec([], results) = results
.
Wenn indices_are_sorted
true
ist, kann die Implementierung davon ausgehen, dass scatter_indices
nach scatter_dims_to_operand_dims
sortiert ist. Andernfalls ist das Verhalten nicht definiert. Formeller ausgedrückt: Für alle i1 < i2
von indices(result)
gilt full_start_index(i1)
<= full_start_index(i2)
.
Wenn unique_indices
true
ist, kann die Implementierung davon ausgehen, dass alle result_index
-Indexe, in die die Daten verteilt werden, eindeutig sind. Wenn unique_indices
true
ist, die Indexe, auf die verteilt wird, aber nicht eindeutig sind, ist das Verhalten nicht definiert.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C1), (C2), (C4–C6), (C11), (C13), (C18), (C21), (C23–C24) |
(I2) | scatter_indices |
Tensor des Ganzzahltyps | (C4), (C15), (C19), (C22) |
(I3) | updates |
variadische Anzahl von Tensoren oder pro Tensor quantisierte Tensoren | (C3-C6), (C8) |
(I4) | update_window_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C19-C21) |
(I9) | index_vector_dim |
Konstante vom Typ si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
Konstante vom Typ i1 |
|
(I11) | unique_indices |
Konstante vom Typ i1 |
|
(I12) | update_computation |
Funktion | (C23) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C24–C25) |
Einschränkungen
- (C1)
same(shape(inputs...))
. - (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
wobei:update_scatter_dim_sizes = shape(scatter_indices)
, mit der Ausnahme, dass die Dimensionsgröße vonscatter_indices
, dieindex_vector_dim
entspricht, nicht enthalten ist.update_window_dim_sizes <= shape(inputs[0])
, mit der Ausnahme, dass die Dimensionsgrößen ininputs[0]
, dieinserted_window_dims
undinput_batching_dims
entsprechen, nicht enthalten sind.- Bei
combine
wirdupdate_scatter_dim_sizes
auf Achsen platziert, dieupdate_scatter_dims
entsprechen, undupdate_window_dim_sizes
auf Achsen, dieupdate_window_dims
entsprechen.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
hat den Typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, wobeiis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
für allei
in[0,N)
.
Beispiele
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
auswählen
Semantik
Erzeugt einen result
-Tensor, bei dem jedes Element basierend auf dem Wert des entsprechenden Elements von pred
aus dem on_true
- oder on_false
-Tensor ausgewählt wird.
Formeller: result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, wobei pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Bei quantisierten Typen wird dequantize_select_quantize(pred, on_true, on_false, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | pred |
Tensor vom Typ i1 |
(C1) |
(I2) | on_true |
Tensor oder pro Tensor quantisierter Tensor | (C1-C2) |
(I3) | on_false |
Tensor oder pro Tensor quantisierter Tensor | (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C2) |
Einschränkungen
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Beispiele
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantik
Die Werte aus dem source
-Tensor werden mit scatter
basierend auf dem Ergebnis von reduce_window
des input
-Tensors mit select
verstreut und es wird ein result
-Tensor erzeugt.
Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result
aus operand
und source
berechnet werden.
Formeller:
selected_values = reduce_window_without_init(...)
durch die folgenden Eingaben:inputs = [operand].
window_dimensions
,window_strides
undpadding
, die unverändert verwendet werden.base_dilations = windows_dilations = 1
.body
ist definiert als:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
Dabei funktionieren
E = element_type(operand)
undreduce_window_without_init
genau wiereduce_window
, mit der Ausnahme, dass dieschedule
der zugrunde liegendenreduce
(siehe reduce) keine Initialisierungswerte enthält. Derzeit ist nicht festgelegt, was passiert, wenn das entsprechende Fenster keine Werte enthält (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
wobei:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
, wennselected_values[source_index]
das Elementoperand
vonoperand_index
enthält.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder pro Tensor quantisierter Tensor | (C1-C4), (C6), (C8-C11) |
(I2) | source |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C2) |
(I3) | init_value |
0-dimensionaler Tensor oder quantisierter Tensor pro Tensor | (C3) |
(I4) | window_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C6), (C7) |
(I6) | padding |
2-dimensionale Tensorkonstante vom Typ si64 |
(C2), (C8) |
(I7) | select |
Funktion | (C9) |
(I8) | scatter |
Funktion | (C10) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder pro Tensor quantisierter Tensor | (C11-C12) |
Einschränkungen
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
wobei:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
hat den Typ(tensor<E>, tensor<E>) -> tensor<i1>
, wobeiE = element_type(operand)
. - (C10)
scatter
hat den Typ(tensor<E>, tensor<E>) -> tensor<E>
, wobeiis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Beispiele
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
senden
Semantik
Sendet inputs
an einen Kanal channel_id
und generiert ein result
-Token.
Wenn is_host_transfer
den Wert true
hat, werden Daten an den Host übertragen. Andernfalls werden die Daten auf ein anderes Gerät übertragen. Was das bedeutet, ist implementierungsabhängig. Dieses Flag dupliziert die in channel_type
angegebenen Informationen. Daher planen wir, in Zukunft nur eines davon beizubehalten (#666).
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | inputs |
Variadische Anzahl von Tensoren oder quantisierten Tensoren | |
(I2) | token |
token |
|
(I3) | channel_id |
Konstante vom Typ si64 |
|
(I4) | channel_type |
Aufzählung von DEVICE_TO_DEVICE und DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
Konstante vom Typ i1 |
(C1) |
Ausgaben
Name | Typ |
---|---|
result |
token |
Einschränkungen
- (C1)
channel_type
wird so definiert:DEVICE_TO_HOST
wennis_host_transfer = true
,- Andernfalls
DEVICE_TO_DEVICE
.
Beispiele
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantik
Führt eine elementweise Linksverschiebung am lhs
-Tensor um die rhs
-Anzahl von Bits durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Ganzzahltensor | (C1) |
(I2) | rhs |
Tensor des Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahltensor | (C1) |
Einschränkungen
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantik
Führt eine elementweise arithmetische Operation zur Rechtsverschiebung für den lhs
-Tensor um die rhs
-Anzahl von Bits durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Ganzzahltensor | (C1) |
(I2) | rhs |
Tensor des Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahltensor | (C1) |
Einschränkungen
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantik
Führt einen elementweisen logischen Rechtsverschiebevorgang auf dem lhs
-Tensor um rhs
Bits aus und gibt einen result
-Tensor zurück.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Ganzzahltensor | (C1) |
(I2) | rhs |
Tensor des Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahltensor | (C1) |
Einschränkungen
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
Signieren
Semantik
Gibt das Vorzeichen von operand
elementweise zurück und erzeugt einen result
-Tensor.
Formal kann die Semantik für jedes Element x
mit der Python-Syntax so ausgedrückt werden:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Bei quantisierten Typen wird dequantize_op_quantize(sign, operand, type(result))
ausgeführt.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor eines vorzeichenbehafteten Ganzzahl-, Gleitkomma- oder komplexer Typ oder eines quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder quantisierter Tensor pro Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
Sinus
Semantik
Führt eine elementweise Sinusoperation auf dem operand
-Tensor aus und erzeugt einen result
-Tensor. Gehen Sie je nach Elementtyp so vor:
- Für Gleitkommazahlen:
sin
von IEEE-754. - Bei komplexen Zahlen: komplexer Sinus.
- Für quantisierte Typen:
dequantize_op_quantize(sine, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
Slice
Semantik
Extrahiert ein Segment aus dem operand
mit statisch berechneten Startindexen und erzeugt einen result
-Tensor. start_indices
enthält die Startindizes des Ausschnitts für jede Dimension, limit_indices
die Endindizes (exklusiv) des Ausschnitts für jede Dimension und strides
die Schritte für jede Dimension.
Formal gilt: result[result_index] = operand[operand_index]
, wobei operand_index = start_indices + result_index * strides
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1-C3), (C5) |
(I2) | start_indices |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C3), (C5) |
(I4) | strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C5) |
Einschränkungen
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Beispiele
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
Sortieren
Semantik
Sortiert eindimensionale Scheiben von inputs
entlang der Dimension dimension
nach einer comparator
und gibt results
zurück.
Im Gegensatz zu ähnlichen Eingaben bei anderen Vorgängen sind bei dimension
negative Werte zulässig. Die Semantik ist unten beschrieben. Künftig ist dies aus Gründen der Einheitlichkeit möglicherweise nicht mehr zulässig (#1377).
Wenn is_stable
„true“ ist, ist die Sortierung stabil, d. h., die relative Reihenfolge der Elemente, die vom Vergleichsoperator als gleich eingestuft werden, wird beibehalten. Bei einer einzelnen Eingabe werden zwei Elemente e1
und e2
vom Vergleicher nur dann als gleich betrachtet, wenn comparator(e1, e2) = comparator(e2, e1) = false
. In der folgenden Formalisierung wird dargestellt, wie dies auf mehrere Eingaben verallgemeinert wird.
Formeller für alle result_index
in index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, wobeiriN
einzelne Elemente inresult_index
sind und:
beiadjusted_dimension
eingefügt wird.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- Dabei sortiert
sort
ein eindimensionales Segment in nicht absteigender Reihenfolge. Dabei wird davon ausgegangen, dasscomparator_together
true
zurückgibt, wenn das linke Argument kleiner als das rechte zweite Argument ist. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1–C5) |
(I2) | dimension |
Konstante vom Typ si64 |
(C4) |
(I3) | is_stable |
Konstante vom Typ i1 |
|
(I4) | comparator |
Funktion | (C5) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C2), (C3) |
Einschränkungen
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, wobeiR = rank(inputs[0])
. - (C5)
comparator
hat den Typ(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, wobeiEi = element_type(inputs[i])
.
Beispiele
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantik
Führt eine elementweise Quadratwurzel-Operation auf dem operand
-Tensor aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
squareRoot
von IEEE-754. - Bei komplexen Zahlen: komplexe Quadratwurzel
- Für quantisierte Typen:
dequantize_op_quantize(sqrt, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semantik
Führt eine elementweise Subtraktion der zwei Tensoren lhs
und rhs
durch und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Ganzzahlen: Subtraktion von Ganzzahlen.
- Für Gleitkommazahlen:
subtraction
von IEEE-754. - Bei komplexen Zahlen: komplexe Subtraktion.
- Für quantisierte Typen:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Ganzzahl-, Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semantik
Führt eine elementweise Tangentenoperation am operand
-Tensor durch und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
tan
von IEEE-754. - Bei komplexen Zahlen: komplexer Tangens.
- Für quantisierte Typen:
dequantize_op_quantize(tan, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semantik
Führt einen elementweisen hyperbolischen Tangens-Vorgang auf dem operand
-Tensor aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für Gleitkommazahlen:
tanh
aus IEEE-754. - Für komplexe Zahlen: komplexer hyperbolischer Tangens
- Für quantisierte Typen:
dequantize_op_quantize(tanh, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
Transponieren
Semantik
Permutet die Abmessungen des operand
-Tensors mit permutation
an und erzeugt einen result
-Tensor. Formeller ausgedrückt: result[result_index] = operand[operand_index]
, wobei result_index[d] = operand_index[permutation[d]]
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1–C4) |
(I2) | permutation |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2-C4) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1), (C3-C4) |
Einschränkungen
- (C1)
element_type(result)
ist gegeben durch:element_type(operand)
, wenn!is_per_axis_quantized(operand)
.element_type(operand)
, mit der Ausnahme, dass sichquantization_dimension(operand)
undquantization_dimension(result)
möglicherweise unterscheiden.
- (C2)
permutation
ist eine Permutation vonrange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Wenn
is_per_axis_quantized(result)
, dannquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Beispiele
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantik
Löst Batches von linearen Gleichungssystemen mit unter- oder oberdiagonalen Koeffizientenmatrizen.
Formeller ausgedrückt: Bei a
und b
ist result[i0, ..., iR-3, :, :]
die Lösung für op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
, wenn left_side
true
ist, oder für x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
, wenn left_side
false
ist. Dabei wird die Variable x
berechnet, wobei op(a)
durch transpose_a
bestimmt wird. transpose_a
kann eine der folgenden Optionen sein:
NO_TRANSPOSE
: Vorgang mita
ausführenTRANSPOSE
: Vorgang auf der Transponierung vona
ausführen.ADJOINT
: Operation für die konjugierte Transponierung vona
ausführen.
Eingabedaten werden nur aus dem unteren Dreieck von a
gelesen, wenn lower
= true
, andernfalls aus dem oberen Dreieck von a
. Die Ausgabedaten werden im selben Dreieck zurückgegeben. Die Werte im anderen Dreieck sind implementierungsspezifisch.
Wenn unit_diagonal
wahr ist, kann die Implementierung davon ausgehen, dass die diagonalen Elemente von a
= 1 sind. Andernfalls ist das Verhalten nicht definiert.
Führt für quantisierte Typen dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | a |
Tensor des Gleitkomma- oder komplexen Typs oder quantisierter Tensor pro Tensor | (C1-C3) |
(I2) | b |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1-C4) |
(I3) | left_side |
Konstante vom Typ i1 |
(C3) |
(I4) | lower |
Konstante vom Typ i1 |
|
(I5) | unit_diagonal |
Konstante vom Typ i1 |
|
(I6) | transpose_a |
enum von NO_TRANSPOSE , TRANSPOSE und ADJOINT |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Gleitkomma- oder komplexer Tensor oder pro Tensor quantisierter Tensor | (C1) |
Einschränkungen
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) Die Beziehung zwischen
shape(a)
undshape(b)
ist wie folgt definiert:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Beispiele
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
Tupel
Semantik
Erzeugt ein result
-Tupel aus den Werten val
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | val |
Variadische Anzahl von Werten | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tupel | (C1) |
Einschränkungen
- (C1)
result
hat den Typtuple<E0, ..., EN-1>
, wobeiEi = type(val[i])
.
Beispiele
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantik
Führt eine elementweise Umwandlung des quantisierten Tensors operand
in einen Gleitkommatensor result
gemäß den Quantisierungsparametern aus, die vom Typ operand
definiert sind.
Formeller: result = dequantize(operand)
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
quantisierter Tensor | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor des Gleitkommatyps | (C1), (C2) |
Einschränkungen
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Beispiele
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantik
Führt eine elementweise Konvertierung eines Gleitkommatensors oder quantisierten Tensors operand
in einen quantisierten Tensor result
gemäß den Quantisierungsparametern aus, die vom Typ result
definiert sind.
Formeller gesprochen
- Wenn
is_float(operand)
:result = quantize(operand, type(result))
.
- Wenn
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
Tensor vom Gleitkomma- oder Quantisierungstyp | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
quantisierter Tensor | (C1), (C2) |
Einschränkungen
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Beispiele
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
while
Semantik
Erzeugt die Ausgabe nach einmaliger oder häufiger Ausführung der body
-Funktion, während die Funktion cond
true
ausgibt. Formeller ausgedrückt, kann die Semantik mithilfe der Python-Syntax so ausgedrückt werden:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Das Verhalten einer Endlosschleife ist noch nicht festgelegt (#383).
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | operand |
variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens | (C1-C3) |
(I2) | cond |
Funktion | (C1) |
(I3) | body |
Funktion | (C2) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens | (C3) |
Einschränkungen
- (C1)
cond
hat den Typ(T0, ..., TN-1) -> tensor<i1>
, wobeiTi = type(operand[i])
. - (C2)
body
hat den Typ(T0, ..., TN-1) -> (T0, ..., TN-1)
, wobeiTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Beispiele
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semantik
Führt ein elementweises XOR von zwei Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Je nach Elementtyp gilt Folgendes:
- Für boolesche Werte: Logisches XOR.
- Für Ganzzahlen: bitweises XOR.
Eingaben
Label | Name | Typ | Einschränkungen |
---|---|---|---|
(I1) | lhs |
Tensor des booleschen oder Ganzzahltyps | (C1) |
(I2) | rhs |
Tensor des booleschen oder Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkungen |
---|---|---|
result |
Tensor des booleschen oder Ganzzahltyps | (C1) |
Einschränkungen
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Dialekt-Interop
Derzeit enthalten StableHLO-Programme manchmal Vorgänge, die nicht von StableHLO definiert sind.
Modul, Funktion, Aufruf und Rückgabe
StableHLO verwendet vorgelagerte MLIR-Vorgänge für ModuleOp, FuncOp, CallOp und ReturnOp. Dies dient der besseren Interoperabilität mit vorhandenen MLIR-Maschinen, da viele nützliche Karten/Tickets für FuncOp und ModuleOp geschrieben werden und viele Kompilierungspipelines erwarten, dass diese Operationen vorhanden sind. Für diese Vorgänge gelten vollständige Kompatibilitätsgarantien. Sollten sich diese Vorgänge auf inkompatible Weise ändern (z.B. Entfernung), werden StableHLO-Äquivalente hinzugefügt, um die Kompatibilität zu erhalten.
CHLO
Das CHLO-Opset enthält Vorgänge auf höherer Ebene, die in StableHLO zerlegt werden. Derzeit gibt es keine Kompatibilitätsgarantien für CHLO. Für Kompatibilitätsgarantien muss vor der Serialization der Pass chlo-legalize-to-stablehlo verwendet werden.
Formvorgänge
In der Community ist es üblich, bestimmte Operationen aus den Kern-MLIR-Dialekten in dynamischen StableHLO-Programmen zu verwenden, um Formberechnungen durchzuführen.
Dazu gehören in der Regel shape
-Dialekt-Befehle wie shape_of
oder num_elements
, tensor
-Dialekt-Befehle wie dim
oder from_elements
und der vordefinierte Typ index
.
Im Dynamism RFC > O2 werden diese als nicht relevant eingestuft. Aus Gründen der Interoperabilität wird jedoch eine gewisse Unterstützung für index
-Typen angeboten. Für diese Optionen oder Typen gibt es keine Kompatibilitätsgarantien. Mit dem Durchlauf shape-legalize-to-stablehlo können diese Vorgänge in vollständig unterstützte StableHLO-Vorgänge konvertiert werden.
Veraltete Vorgänge
Es gibt mehrere StableHLO-Vorgänge, die von MHLO übernommen wurden, die jedoch eingestellt wurden und aus der StableHLO entfernt wurden. Alle Details zu diesen Entfernungen finden Sie im StableHLO v1.0 Cleanup #2283. Das Tracker-Problem für diese Einstellung ist #2340.
Diese Vorgänge lassen sich in einige Kategorien unterteilen:
- Kategorie „Nicht in HLO“ von StableHLO-Vorgängen – sie waren ursprünglich Teil des StableHLO-Opsets, aber später wurde davon ausgegangen, dass sie nicht gut geeignet waren:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(Nr. 3). - Nicht verwendete Vorgänge: Diese Vorgänge waren möglicherweise einmal nützlich, wurden aber entweder nicht ausreichend entwickelt oder die Pipelines, in denen sie verwendet wurden, wurden so umgestaltet, dass sie nicht mehr erforderlich sind. Dazu gehören
map
,tuple
(#598),get_tuple_element
,rng
,complex
-Vergleiche #560 und diewindow_reversal
-Konvolution (#1181).
Einige dieser Vorgänge können einfach entfernt werden, da sie mithilfe vorhandener Vorgänge (broadcast
, create_token
, cross-replica-sum
, dot
, unary_einsum
) ausgedrückt werden können. Sie werden nach Ablauf des bestehenden Kompatibilitätszeitraums (6 Monate) entfernt. Andere werden noch entfernt (einsum
, get_tuple_element
, map
, rng
torch_index_select
, tuple
, complex
Vergleiche, window_reversal
). Ausstehendes Community-Feedback wird entweder entfernt oder mit vollständiger Unterstützung der Spezifikation hinzugefügt. Solange diese Ops Futures nicht bekannt sind, wird die Kompatibilität nur für 6 Monate garantiert.
Ausführung
Sequenzielle Ausführung
Ein StableHLO-Programm wird ausgeführt, indem Eingabewerte für die Funktion main
bereitgestellt und Ausgabewerte berechnet werden. Die Ausgabewerte einer Funktion werden berechnet, indem der Graph der Vorgänge ausgeführt wird, der auf der entsprechenden return
-Operation basiert.
Die Ausführungsreihenfolge ist implementierungsdefiniert, solange sie auf Dataflow ausgerichtet ist, d.h., wenn Vorgänge vor ihrer Verwendung ausgeführt werden. In StableHLO verbrauchen alle Nebeneffekte Operationen ein Token und erzeugen ein Token (mehrere Token können über after_all
in ein Token umgewandelt werden), sodass die Ausführungsreihenfolge der Nebeneffekte auch mit Dataflow übereinstimmt. Im folgenden Programm gibt es beispielsweise zwei mögliche Ausführungsreihenfolgen: %0
→ %1
→ %2
→ return
und %1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Formeller ausgedrückt ist ein StableHLO-Prozess eine Kombination aus:
1) einem StableHLO-Programm, 2) Ausführungsstatus (noch nicht ausgeführt, bereits ausgeführt) und 3) Zwischenwerten, an denen der Prozess gerade arbeitet.
Der Prozess beginnt mit Eingabewerten für die main
-Funktion, geht über den Graphen der Vorgänge mit Aktualisierung des Vorgangsstatus und der Zwischenwerte und endet mit Ausgabewerten. Weitere Formalisierung ist noch nicht festgelegt (#484).
Parallele Ausführung
StableHLO-Programme können parallel ausgeführt werden und in einem 2D-Prozessraster von num_replicas
× num_partitions
organisiert werden, die beide den Typ ui32
haben.
Im StableHLO-Prozess-Raster werden num_replicas * num_partitions
StableHLO-Prozesse gleichzeitig ausgeführt. Jeder Prozess hat eine eindeutige process_id = (replica_id, partition_id)
. replica_id
in replica_ids = range(num_replicas)
und partition_id
in partition_ids = range(num_partitions)
haben beide den Typ ui32
.
Die Größe des Prozessrasters ist für jedes Programm statisch bekannt (in Zukunft planen wir, es zu einem expliziten Teil der StableHLO-Programme #650 zu machen) und die Position innerhalb des Prozessrasters ist für jeden Prozess statisch bekannt. Jeder Prozess hat über die replica_id
- und partition_id
-Vorgänge Zugriff auf seine Position im Prozess-Raster.
Innerhalb des Prozessrasters können die Programme alle gleich sein (im Stil „Einzelnes Programm, mehrere Daten“), alle unterschiedlich sein (im Stil „Mehrere Programme, mehrere Daten“) oder etwas dazwischen. Künftig planen wir, die Unterstützung für andere Definitionen paralleler StableHLO-Programme einzuführen, einschließlich GSPMD (#619).
Innerhalb des Prozess-Rasters sind die Prozesse größtenteils unabhängig voneinander. Sie haben separate Vorgangsstatus, separate Eingabe-/Zwischen-/Ausgabewerte und die meisten Vorgänge werden zwischen den Prozessen separat ausgeführt, mit Ausnahme einer kleinen Anzahl von kollektiven Vorgängen, die unten beschrieben werden.
Da für die Ausführung der meisten Vorgänge nur Werte aus demselben Prozess verwendet werden, ist es in der Regel eindeutig, diese Werte anhand ihrer Namen zu referenzieren.
Bei der Beschreibung der Semantik von Gruppenoperationen ist dies jedoch nicht ausreichend. Daher wird die Schreibweise name@process_id
verwendet, um sich auf den Wert name
innerhalb eines bestimmten Prozesses zu beziehen. (Aus dieser Perspektive kann name
ohne Qualifikation als Kurzform für name@(replica_id(), partition_id())
betrachtet werden.)
Die Ausführungsreihenfolge zwischen Prozessen ist implementierungsdefiniert, mit Ausnahme der Synchronisierung, die durch Punkt-zu-Punkt-Kommunikation und kollektive Operationen eingeführt wird, wie unten beschrieben.
Punkt-zu-Punkt-Kommunikation
StableHLO-Prozesse können über StableHLO-Kanäle miteinander kommunizieren. Ein Kanal wird durch eine positive ID vom Typ si64
dargestellt. Über verschiedene Vorgänge ist es möglich, Werte an Kanäle zu senden und von Kanälen zu empfangen.
Weitere Informationen, z. B. woher diese Kanal-IDs stammen, wie sie in Verarbeitungsprogrammen erkannt werden und welche Art von Synchronisierung dadurch eingeführt wird, sind noch nicht verfügbar (#484).
Streaming-Kommunikation
Jeder StableHLO-Prozess hat Zugriff auf zwei Streaming-Schnittstellen:
- Infeed, aus dem gelesen werden kann.
- Outfeed, in den geschrieben werden kann.
Im Gegensatz zu Channels, die für die Kommunikation zwischen Prozessen verwendet werden und daher Prozesse an beiden Enden haben, wird bei In-Feeds und Out-Feeds das andere Ende implementiert.
Weitere Formalisierungen, z. B. wie die Streamingkommunikation die Ausführungsreihenfolge beeinflusst und welche Art von Synchronisierung dadurch eingeführt wird, sind noch offen (#484).
Collective-Vorgänge
In StableHLO gibt es sechs kollektive Vorgänge: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
und reduce_scatter
. Alle diese Vorgänge teilen die Prozesse im StableHLO-Prozess-Raster in StableHLO-Prozessgruppen auf und führen innerhalb jeder Prozessgruppe unabhängig von anderen Prozessgruppen eine gemeinsame Berechnung aus.
Innerhalb jeder Prozessgruppe können gemeinsame Operationen eine Synchronisierungsbarriere darstellen. Eine weitere Formalisierung, z. B. die Erläuterung, wann genau diese Synchronisierung erfolgt, wie genau die Prozesse diese Barriere erreichen und was passiert, wenn sie dies nicht tun, ist noch nicht abgeschlossen (#484).
Wenn die Prozessgruppe eine partitionenübergreifende Kommunikation umfasst, d. h. es gibt Prozesse in der Prozessgruppe, deren Partitions-IDs sich unterscheiden, ist für die Ausführung der kollektiven Operation ein Kanal erforderlich und die kollektive Operation muss eine positive channel_id
vom Typ si64
angeben. Für die replikübergreifende Kommunikation sind keine Kanäle erforderlich.
Die von den gemeinsamen Vorgängen ausgeführten Berechnungen sind spezifisch für einzelne Operationen und werden oben in den einzelnen Vorgangsabschnitten beschrieben. Die Strategien, mit denen das Prozessraster in Prozessgruppen unterteilt wird, sind jedoch für alle diese Betriebsarten identisch und werden in diesem Abschnitt beschrieben. Genauer gesagt unterstützt StableHLO die folgenden vier Strategien.
cross_replica
Innerhalb jeder Prozessgruppe findet nur die replikübergreifende Kommunikation statt. Bei dieser Strategie wird replica_groups
, eine Liste von Listen mit Replik-IDs, verwendet, um ein kartesisches Produkt von replica_groups
× partition_ids
zu berechnen. replica_groups
muss eindeutige Elemente haben und alle replica_ids
abdecken. Formal mit der Python-Syntax:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Beispiel: Aus replica_groups = [[0, 1], [2, 3]]
und num_partitions = 2
wird cross_replica
zu [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Innerhalb jeder Prozessgruppe findet nur eine abteilungsübergreifende Kommunikation statt. Bei dieser Strategie wird partition_groups
– eine Liste von Listen mit Partitions-IDs – verwendet, um ein kartesisches Produkt von partition_groups
× replica_ids
zu berechnen.
partition_groups
muss eindeutige Elemente haben und alle partition_ids
abdecken.
In Python-Syntax:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Beispiel: Für partition_groups = [[0, 1]]
und num_replicas = 4
wird durch cross_partition
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
erzeugt.
cross_replica_and_partition
Innerhalb jeder Prozessgruppe kann es sowohl zu replikations- als auch zu partitionenübergreifender Kommunikation kommen. Diese Strategie verwendet replica_groups
– eine Liste mit Listen von Replikat-IDs – und berechnet kartesische Produkte jeder replica_group
nach partition_ids
. replica_groups
muss eindeutige Elemente haben und alle replica_ids
abdecken. Formal mit der Python-Syntax:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Beispiel: Aus replica_groups = [[0, 1], [2, 3]]
und num_partitions = 2
wird cross_replica_and_partition
zu [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Bei dieser Strategie wird flattened_id_groups
, eine Liste von Listen mit „zusammengeführten“ Prozess-IDs im Format replica_id * num_partitions + partition_id
, in Prozess-IDs umgewandelt. flattened_id_groups
muss eindeutige Elemente haben und alle process_ids
abdecken. In Python-Syntax:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Beispielsweise erzeugt flattened_ids
für flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
und num_partitions = 2
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Genauigkeit
Derzeit bietet StableHLO keine Garantien für die numerische Genauigkeit. Dies kann sich jedoch in Zukunft ändern (#1156).
Ausführungssemantik eines quantisierten Vorgangs
Die Interpretation quantisierter StableHLO-Vorgänge kann je nach Hardwareanforderungen und -funktionen variieren. Beispielsweise kann bei einigen Hardwaren die Quantisierung mit der Strategie „Quantisierung aufheben, Gleitkommaoperation ausführen und schließlich Quantisierung“ interpretiert werden. Andere können die gesamte Berechnung mit ganzzahliger Arithmetik durchführen. Die Interpretation von quantisierten StableHLO-Vorgängen wird daher ausschließlich von der jeweiligen Implementierung bestimmt. Die Interpretation der hybriden Quantisierung (#1575) sollte auf der in der Spezifikation vorgeschriebenen Semantik basieren (über 1792).
Fehler
StableHLO-Programme werden anhand einer umfangreichen Reihe von Einschränkungen für einzelne Vorgänge validiert, wodurch viele Fehlerklassen vor der Laufzeit ausgeschlossen werden. Fehlerbedingungen sind jedoch weiterhin möglich, z. B. durch Ganzzahlüberläufe oder Zugriffe außerhalb des Gültigkeitsbereichs. Sofern nicht ausdrücklich anders angegeben, führen alle diese Fehler zu einem implementierungsdefinierten Verhalten. Dies kann sich jedoch in Zukunft ändern (#1157).
Gleitkommaausnahmen
Als Ausnahme von dieser Regel haben Gleitkommaausnahmen in StableHLO-Programmen ein klar definiertes Verhalten. Bei Operationen, die gemäß IEEE-754-Standard zu Ausnahmen führen (ungültiger Vorgang, Division durch Null, Überlauf, Unterlauf oder ungenaue Ausnahmen), werden Standardergebnisse (wie im Standard definiert) ausgegeben und die Ausführung wird fortgesetzt, ohne dass das entsprechende Statusflag gesetzt wird. Das entspricht der raiseNoFlag
-Ausnahmebehandlung des Standards. Ausnahmen für nicht standardmäßige Vorgänge (z. B. komplexe Arithmetik und bestimmte transzendentale Funktionen) sind implementierungsabhängig.
Nicht übereinstimmende Formen
StableHLO unterstützt Tensoren mit dynamischer Form. Die Formen müssen jedoch zur Laufzeit übereinstimmen, andernfalls ist das Verhalten nicht definiert. StableHLO bietet nicht explizit eine Operation, die bestätigen kann, dass ein Tensor zur Laufzeit eine bestimmte Form hat. Für die Generierung des korrekten Codes ist der Ersteller verantwortlich.
Als Beispiel ist das folgende Programm gültig. Bei der Laufzeit müssen die genauen Formen von %arg0
und %arg1
jedoch identisch sein, da das Verhalten des Programms sonst nicht definiert ist:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Für die Beschreibung der Syntax wird in diesem Dokument die modifizierte ISO-Variante der EBNF-Syntax (ISO/IEC 14977:1996, Wikipedia) mit zwei Änderungen verwendet: 1) Regeln werden mit ::=
anstelle von =
definiert,
2) Die Konkatenierung wird durch Nebeneinanderstellung und nicht durch ,
ausgedrückt.
Zur Beschreibung der Semantik (d. h. in den Abschnitten „Types“, „Constants“ und „Ops“) verwenden wir Formeln, die auf der Python-Syntax basieren. Diese Erweiterung unterstützt jetzt prägnante Array-Operationen, wie unten beschrieben. Das funktioniert gut für kleine Code-Snippets. In seltenen Fällen, in denen größere Code-Snippets erforderlich sind, verwenden wir die Vanilla-Python-Syntax, die immer explizit eingeführt wird.
Formeln
Sehen wir uns anhand eines Beispiels aus der dot_general
-Spezifikation an, wie Formeln funktionieren. Eine der Einschränkungen für diesen Vorgang sieht so aus:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Die in dieser Formel verwendeten Namen stammen aus zwei Quellen: 1) globale Funktionen, z. B. dim
, 2) Mitgliedsdefinitionen des entsprechenden Programmelements, z. B. lhs
-, lhs_batching_dimensions
-, rhs
- und rhs_batching_dimensions
-Eingaben, die im Abschnitt „Eingaben“ von dot_general
definiert sind.
Wie bereits erwähnt, basiert die Syntax dieser Formel auf Python mit einigen prägnanten Erweiterungen. Um die Formel besser zu verstehen, wandeln wir sie in die Standard-Python-Syntax um.
A): In diesen Formeln wird =
für Gleichheit verwendet. Der erste Schritt zur Python-Syntax besteht also darin, =
durch ==
zu ersetzen:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Außerdem unterstützen diese Formeln Ellipsen (...
), mit denen Skalarausdrücke in Tensorausdrücke umgewandelt werden. Kurz gesagt bedeutet f(xs...)
in etwa: „Für jeden Skalar x
im Tensor xs
einen Skalar f(x)
berechnen und dann alle diese Skalarergebnisse zusammen als Tensorergebnis zurückgeben“. In der Standard-Python-Syntax sieht unsere Beispielformel so aus: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Dank Ellipsen ist es oft möglich, die Arbeit auf der Ebene einzelner Skalare zu vermeiden. In einigen schwierigen Fällen kann jedoch eine semi-informelle Syntax niedrigerer Ebene verwendet werden, wie in der start_indices[bi0, ..., :, ..., biN]
-Formel aus der gather
-Spezifikation. Aus Gründen der Übersichtlichkeit geben wir keinen genauen Formalismus für die Übersetzung dieser Syntax in Vanilla-Python an. Wir hoffen, dass sie trotzdem von Fall zu Fall intuitiv verständlich ist.
Bitte teilen Sie uns mit, wenn bestimmte Formeln unklar sind. Wir werden versuchen, sie zu verbessern.
Außerdem werden Sie feststellen, dass Formeln Ellipsen verwenden, um alle Arten von Listen zu erweitern, einschließlich Tensoren, Listen von Tensoren (die z.B. aus einer variadischen Anzahl von Tensoren stammen können). In diesem Bereich bieten wir keinen genauen Formalismus (z.B. sind Listen nicht einmal Teil des StableHLO-Typsystems) und basieren stattdessen auf der intuitiven Bedienbarkeit.
C) Das letzte bemerkenswerte Notationsmittel, das wir verwenden, ist die implizite Übertragung. Das StableHLO-Opset unterstützt zwar kein implizites Broadcasting, die Formeln jedoch schon, was der Präzision dient. Wenn ein Skalar in einem Kontext verwendet wird, in dem ein Tensor erwartet wird, wird der Skalar auf die erwartete Form gesendet.
Im Beispiel für dot_general
ist hier eine weitere Einschränkung:
0 <= lhs_batching_dimensions < rank(lhs)
. Wie in der dot_general
-Spezifikation definiert, ist lhs_batching_dimensions
ein Tensor. 0
und rank(lhs)
sind jedoch Skalare. Nachdem implizites Broadcasting angewendet wurde, wird die Formel zu [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Wenn diese Formel auf einen bestimmten dot_general
-Vorgang angewendet wird, ergibt sich ein Tensor von Booleschen Werten. Wenn Formeln als Einschränkungen verwendet werden, gilt die Einschränkung, wenn die Formel entweder true
oder einen Tensor ergibt, der nur true
Elemente hat.
Namen
Der lexikalische Bereich in Formeln umfasst: 1) globale Funktionen, 2) Elementdefinitionen,
3) lokale Definitionen. Die Liste der globalen Funktionen finden Sie unten. Die Liste der Elementdefinitionen hängt vom Programmelement ab, auf das die Notation angewendet wird:
- Bei Vorgängen umfassen die Mitgliederdefinitionen Namen, die in den Abschnitten „Eingaben“ und „Ausgaben“ eingeführt wurden.
- Alle anderen Mitgliedsdefinitionen enthalten strukturelle Teile des Programmelements, die nach den entsprechenden EBNF-Nichtterminals benannt sind. In den meisten Fällen werden die Namen dieser strukturellen Teile durch Umwandlung der Namen der Nichtterminals in Snake Case (z. B.
IntegerLiteral
=>integer_literal
) ermittelt. Manchmal werden die Namen dabei jedoch abgekürzt (z. B.QuantizationStorageType
=>storage_type
). In diesem Fall werden die Namen ähnlich wie die Abschnitte „Eingaben“/„Ausgaben“ in den Betriebsspezifikationen explizit eingeführt. - Darüber hinaus enthalten Mitgliederdefinitionen immer
self
, um auf das entsprechende Programmelement zu verweisen.
Werte
Bei der Auswertung von Formeln werden die folgenden Arten von Werten verwendet:
1) Value
(tatsächliche Werte, z. B. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
; ihre Typen sind immer bekannt),
2) Placeholder
(zukünftige Werte, z. B. lhs
, rhs
oder result
; ihre tatsächlichen Werte sind noch nicht bekannt, nur ihre Typen),
3) Type
(Typen wie im Abschnitt „Typen“ definiert),
4) Function
(globale Funktionen wie im Abschnitt „Funktionen“ definiert).
Je nach Kontext können sich Namen auf unterschiedliche Werte beziehen. Im Abschnitt „Semantik“ für Operatoren (und entsprechende Abschnitte für andere Programmelemente) wird die Laufzeitlogik definiert, sodass alle Eingaben als Value
verfügbar sind.
Im Abschnitt „Einschränkungen“ für Operatoren (und entsprechende Funktionen) wird dagegen die Logik für die „Kompilierungszeit“ definiert, d. h. etwas, das in der Regel vor der Laufzeit ausgeführt wird. Daher sind nur konstante Eingaben als Value
und andere Eingaben nur als Placeholder
verfügbar.
Namen | Unter „Semantik“ | Unter „Einschränkungen“ |
---|---|---|
Globale Funktionen | Function |
Function |
Konstante Eingaben | Value |
Value |
Nicht konstante Eingaben | Value |
Placeholder |
Ausgaben | Value |
Placeholder |
Lokale Definitionen | Abhängig von der Definition | Abhängig von der Definition |
Sehen wir uns ein Beispiel für einen transpose
-Vorgang an:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Bei diesem Vorgang ist permutation
eine Konstante und steht daher sowohl in der Semantik als auch in den Einschränkungen als Value
zur Verfügung. Im Gegensatz dazu sind operand
und result
in der Semantik als Value
, aber in Einschränkungen nur als Placeholder
verfügbar.
Funktionen
Konstruktion von Typen
Es gibt keine Funktionen, mit denen Typen erstellt werden können. Stattdessen verwenden wir direkt die Typsyntax, da sie in der Regel prägnanter ist. Beispiel: (tensor<E>, tensor<E>) -> (tensor<E>)
anstelle von function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funktionen für Typen
element_type
ist für Tensortypen und quantisierte Tensortypen definiert und gibt jeweils denTensorElementType
- oderQuantizedTensorElementType
-Teil des entsprechendenTensorType
oderQuantizedTensorType
zurück.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
ist eine Verknüpfung füris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
ist eine Tastenkombination füris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
prüft, ob der Typx
auf den Typy
hochgestuft werden kann. Wennx
undy
QuantizedTensorElementType
s sind, wird das Angebot nur auf denstorage_type
angewendet. Diese spezielle Version des Angebots wird derzeit im Zusammenhang mit der Berechnung von Rabatten verwendet. Weitere Informationen finden Sie im RFC.
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
ist eine Tastenkombination füris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Verfügbar für alle Typen. Beispiel:is_float(x)
gibttrue
zurück, wennx
eineFloatType
ist. Wennx
ein Wert oder Platzhalter ist, ist diese Funktion eine Verknüpfung füris_type_name(type(x))
.max_value(x: Type) -> Value
gibt den Maximalwert einerTensorElementType
zurück. Wennx
keinTensorElementType
ist, wirdNone
zurückgegeben.min_value(x: Type) -> Value
gibt den kleinstmöglichen Wert einerTensorElementType
zurück. Wennx
keinTensorElementType
ist, wirdNone
zurückgegeben.member_name(x: Value | Placeholder | Type) -> Any
. Verfügbar für alle Mitgliedsdefinitionenmember_name
aller Typen. Beispiel:tensor_element_type(x)
gibt denTensorElementType
-Teil einer entsprechendenTensorType
zurück. Wennx
ein Wert oder Platzhalter ist, ist diese Funktion eine Kurzform fürmember_name(type(x))
. Wennx
kein Typ mit einem entsprechenden Mitglied oder einem Wert oder Platzhalter eines solchen Typs ist, wirdNone
zurückgegeben.is_empty_algorithm(*args: Type)
prüft, ob alle Punktalgorithmusfelder aufNone
gesetzt sind. Das ist erforderlich, da Punktalgorithmen standardmäßige Verhaltensweisen haben, die von der Implementierung abhängen. Die Angabe eines Standardwerts wäre daher falsch.
Konstruktion von Werten
operation_name(*xs: Value | Type) -> Value
. Verfügbar für alle Vorgänge.add(lhs, rhs)
nimmt beispielsweise zwei Tensorwertelhs
undrhs
an und gibt die Ausgabe der Auswertung des Vorgangsadd
mit diesen Eingaben zurück. Bei einigen Vorgängen, z. B.broadcast_in_dim
, sind die Ausgabetypen „tragfähig“, d. h. sie sind zur Auswertung eines Vorgangs erforderlich. In diesem Fall nimmt die Funktion diese Typen als Argumente an.
Funktionen für Werte
Alle Python-Operatoren und -Funktionen sind verfügbar. So sind beispielsweise sowohl Abonnements als auch Schnitte aus Python verfügbar, um Tensoren, quantisierte Tensoren und Tupel zu indexieren.
to_destination_type(x: Value, destination_type: Type) -> Value
ist auf Tensoren definiert und gibt den konvertierten Wert vonx
basierend auftype(x)
unddestination_type
so zurück:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Es gibt erste Diskussionen zur Zusammenführung der Vorgänge convert
, uniform_quantize
und uniform_dequantize
(#1576).
Nach dem Zusammenführen benötigen wir die obige Funktion nicht mehr und können stattdessen den Namen der Operation für convert
verwenden.
is_nan(x: Value) -> Value
ist für Tensoren definiert und gibttrue
zurück, wenn alle Elemente vonx
NaN
sind, andernfallsfalse
. Wennx
kein Tensor ist, wirdNone
zurückgegeben.is_sorted(x: Value) -> Value
ist für Tensoren definiert und gibttrue
zurück, wenn die Elemente vonx
in aufsteigender lexikalischer Reihenfolge ihrer Indizes sortiert sind, andernfallsfalse
. Wennx
kein Tensor ist, wirdNone
zurückgegeben.is_unique(x: Value) -> Value
ist für Tensoren definiert und gibttrue
zurück, wennx
keine doppelten Elemente enthält, andernfallsfalse
. Wennx
kein Tensor ist, wirdNone
zurückgegeben.member_name(x: Value) -> Any
ist für alle Mitgliederdefinitionenmember_name
aller Werte definiert.real_part(x)
gibt beispielsweise denRealPart
-Teil einer entsprechendenComplexConstant
zurück. Wennx
kein Wert mit einem entsprechenden Mitglied ist, wirdNone
zurückgegeben.same(x: Value) -> Value
ist für Tensoren definiert und gibttrue
zurück, wenn die Elemente vonx
alle gleich sind, ansonstenfalse
. Wenn der Tensor keine Elemente hat, zählt dies als „alle gleich einander“, d.h. die Funktion gibttrue
zurück. Wennx
kein Tensor ist, wirdNone
zurückgegeben.split(x: Value, num_results: Value, axis: Value) -> Value
ist auf Tensoren definiert und gibtnum_results
-Schichten vonx
entlang der Achseaxis
zurück. Wennx
kein Tensor oderdim(x, axis) % num_results != 0
ist, wirdNone
zurückgegeben.is_defined_in_parent_scope(x: Value) -> Value
ist für Strings definiert und gibttrue
zurück, wennx
der Name einer Funktion ist, die im selben Gültigkeitsbereich wie die übergeordnete Funktion der entsprechenden Operation definiert ist.is_namespaced_op_name(x: Value) -> Value
ist für Strings definiert und gibttrue
zurück, wennx
ein gültiger Vorgangsname ist, d. h. der folgende reguläre Ausdruck berücksichtigt wird:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Formberechnungen
axes(x: Value | Placeholder | Type) -> Value
ist eine Tastenkombination fürrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
ist eine Tastenkombination fürshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
ist eine Tastenkombination fürlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
ist für Tensoren definiert und gibtsize(x)
-Indizes für die entsprechendenTensorType
zurück, sortiert in aufsteigender lexikographischer Reihenfolge, also[0, ..., 0]
,[0, ..., 1]
, …,shape(x) - 1
. Wennx
kein Tensortyp, kein quantisierter Tensortyp, kein Wert oder kein Platzhalter eines dieser Typen ist, wirdNone
zurückgegeben.rank(x: Value | Placeholder | Type) -> Value
ist eine Tastenkombination fürsize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
wird im Abschnitt „Funktionen für Typen“ übermember_name
definiert.size(x: Value | Placeholder | Type) -> Value
ist eine Tastenkombination fürreduce(lambda x, y: x * y, shape(x))
.
Quantisierungsberechnungen
def baseline_element_type(x: Value | Placeholder | Type) -> Type
ist eine Tastenkombination fürelement_type(baseline_type(x))
.baseline_type
wird für Tensortypen und quantisierte Tensortypen definiert und transformiert sie in eine "Baseline", d.h. einen Typ mit derselben Form, aber mit den Quantisierungsparametern des Elementtyps, die auf Standardwerte zurückgesetzt werden. Dies ist ein praktischer Trick, um sowohl Tensor- als auch quantisierte Tensortypen einheitlich zu vergleichen, was häufig erforderlich ist. Bei quantisierten Typen können so Typen verglichen werden, wobei die Quantisierungsparameter ignoriert werden. Das bedeutet, dassshape
,storage_type
,expressed_type
,storage_min
,storage_max
undquantization_dimension
(für den quantisierten Typ pro Achse) alle übereinstimmen müssen, sichscales
undzero points
aber unterscheiden können.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
ist für quantisierte Tensortypen definiert und wandelt sie in Gleitkomma-Tensortypen um. Dazu werden quantisierte Elemente, die Ganzzahlwerte des Speichertyps darstellen, mithilfe des Nullpunkts und der Skala, die mit dem quantisierten Elementtyp verknüpft sind, in entsprechende Gleitkommawerte des Ausdruckstyps umgewandelt.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
wird für Gleitkomma-Tensortypen definiert und wandelt sie in quantisierte Tensortypen um. Dazu werden Gleitkommawerte des Ausdruckstyps mithilfe des Nullpunkts und der Skalierung, die mit dem quantisierten Elementtyp verknüpft sind, in entsprechende Ganzzahlwerte des Speichertyps umgewandelt.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
wird verwendet, um elementweise Berechnungen für quantisierte Tensoren anzugeben. Er führt eine Dequantisierung durch, d. h. wandelt quantisierte Elemente in ihre Ausdruckstypen um, führt dann einen Vorgang aus und quantisiert die Ergebnisse dann wieder, d. h. wandelt sie in ihre Speichertypen um. Derzeit funktioniert diese Funktion nur für die Quantisierung pro Tensor. Die Quantisierung nach Achsen ist noch in der Entwicklungsphase (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
- Mit
hybrid_dequantize_then_op
wird eine Quantisierung nur für Gewichte für eine hybride Operation angegeben, bei der der linke Operand als Gleitkommazahl und der rechte Operand als quantisierter Typ akzeptiert wird. Er dequantisiert quantisierte Eingaben in ihre expliziten Typen und führt eine Berechnung als Gleitkommazahl durch. Der Elementtyp des Gleitkommazahl-Tensors und der ausgedrückte Typ des quantisierten rhs-Tensors sollten identisch sein.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Rasterberechnungen
cross_partition(replica_groups: Value) -> Value
. Weitere Informationen finden Sie oben im Abschnitt zu "cross_Replikat".cross_replica(replica_groups: Value) -> Value
. Weitere Informationen finden Sie oben im Abschnitt „cross_replica“.cross_replica_and_partition(replica_groups: Value) -> Value
. Weitere Informationen finden Sie oben im Abschnitt "cross_replica_and_partition".flattened_ids(replica_groups: Value) -> Value
. Weitere Informationen finden Sie oben im Abschnitt „flattened_ids“.
Dynamismus
StableHLO-Werte können dynamische Dimensionsgrößen haben, z.B. tensor<?xi64>
.
StableHLO-Werte können jedoch keine dynamische Anzahl von Dimensionen haben (nicht sortierter Dynamismus, z.B. tensor<*xi64>
). Operanden und Ergebnisse dürfen dynamische Dimensionsgrößen verwenden, auch wenn es Einschränkungen bei den Größen gibt. Einschränkungen werden nach Möglichkeit statisch überprüft. Andernfalls werden sie auf die Laufzeit zurückgestellt und Abweichungen führen zu nicht definiertem Verhalten. Siehe Beispiele unten.
Formenabweichungen bei unären elementweisen Vorgängen
Betrachten Sie das folgende Beispielprogramm:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Ein solches Programm ist ungewöhnlich, da es nicht üblich ist, die Form des Ergebnisses, aber nicht die Form der Eingabe zu kennen. Dies ist jedoch ein gültiges StableHLO-Programm. Der abs
-Vorgang in diesem Programm kann nicht statisch validiert werden, da die genaue Form des Operanden unbekannt ist. Die Formen sind jedoch kompatibel und das kann statisch geprüft werden: ?
könnte bei Laufzeit 2
sein, ohne dass es zu Problemen kommt. ?
kann sich jedoch auch als eine andere Ganzzahl herausstellen. In diesem Fall ist das Verhalten nicht definiert.
Wenn die Größe einer Dimension im Ergebnis dynamisch ist, kann es nicht zu einem nicht definierten Verhalten kommen. Es gibt keine „erwartete“ Größe, daher kann es auch keine Abweichung geben.
Nicht übereinstimmende Formen bei binären elementweisen Vorgängen
Betrachten Sie das folgende Beispielprogramm:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Bei binären elementweisen Vorgängen müssen die Formen der Eingaben und des Ergebnisses zur Laufzeit übereinstimmen. Zum Zeitpunkt der Kompilierung müssen statische Dimensionen gleich sein. Andernfalls müssen sie lediglich kompatibel sein. Wenn eine Dimension in den Eingaben dynamisch ist, kann es bei der Laufzeit zu undefiniertem Verhalten kommen, da die dynamische Größe möglicherweise nicht mit der entsprechenden Größe im anderen Operanden übereinstimmt (sei es statisch oder dynamisch). Wenn alle Eingaben statisch sind, spielt es keine Rolle, ob das Ergebnis dynamisch ist oder nicht: Statisch bekannte Dimensionen werden statisch geprüft und dynamische Dimensionen erzwingen keine Einschränkungen.
Abweichungen bei der Form für Vorgänge, bei denen die Ausgabeform als Operand verwendet wird
Betrachten Sie das folgende Beispielprogramm:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Die Werte im Formoperanden zur Laufzeit müssen der Form des Ergebnisses entsprechen. Andernfalls ist das Verhalten nicht definiert. Das bedeutet, dass %arg0
während der Laufzeit den Wert dense<[3, 4]> : tensor<2xi32>
haben muss. Wenn der Formoperand konstant ist, kann dies statisch überprüft werden. Wenn die Ergebnisform vollständig dynamisch ist, kann keine Abweichung auftreten.