Stabile HLO ist ein Vorgangssatz für High-Level-Vorgänge (HLO) in Modellen für maschinelles Lernen (ML). StableHLO dient als Übertragbarkeitsebene zwischen verschiedenen ML-Frameworks und ML-Compilern: ML-Frameworks, die StableHLO-Programme produzieren, sind mit ML-Compilern kompatibel, die StableHLO-Programme verwenden.
Unser Ziel ist es, die ML-Entwicklung zu vereinfachen und zu beschleunigen, indem wir mehr Interoperabilität zwischen verschiedenen ML-Frameworks (wie TensorFlow, JAX und PyTorch) und ML-Compilern (wie XLA und IREE) schaffen. Zu diesem Zweck enthält dieses Dokument eine Spezifikation für die Programmiersprache StableHLO.
Diese Spezifikation enthält drei Hauptabschnitte. Zuerst wird im Abschnitt Programme die Struktur von StableHLO-Programmen beschrieben, die aus StableHLO-Funktionen, die wiederum aus StableHLO-Vorgängen bestehen, bestehen. Innerhalb dieser Struktur legt der Abschnitt Ops die Semantik einzelner Vorgänge fest. Der Abschnitt Ausführung enthält die Semantik für alle diese Vorgänge, die zusammen innerhalb eines Programms ausgeführt werden. Schließlich wird im Abschnitt Notation die in der Spezifikation verwendete Notation erläutert.
Programme
Program ::= {Func}
StableHLO-Programme bestehen aus einer beliebigen Anzahl von StableHLO-Funktionen.
Unten sehen Sie ein Beispielprogramm mit der Funktion @main
, die 3 Eingaben (%image
, %weights
und %bias
) und 1 Ausgabe hat. Der Hauptteil der Funktion
hat 6 Operationen.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funktionen
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
StableHLO-Funktionen (auch als benannte Funktionen bezeichnet) haben eine Kennung, Ein-/Ausgaben und einen Textkörper. Für die Zukunft planen wir, zusätzliche Metadaten für Funktionen einzuführen, um die Kompatibilität mit HLO zu verbessern (#425, #626, #740, #744).
IDs
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
StableHLO-Kennungen ähneln Kennungen in vielen Programmiersprachen, mit zwei Besonderheiten: 1) Alle Kennungen haben Siegel, die verschiedene Arten von Kennungen unterscheiden, 2) Wertkennungen können vollständig numerisch sein, um die Generierung von StableHLO-Programmen zu vereinfachen.
Typen
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
StableHLO-Typen werden in Werttypen (auch als erstklassige Typen bezeichnet) kategorisiert, die StableHLO-Werte darstellen und nicht wertbezogene Typen, die andere Programmelemente beschreiben. Stabile HLO-Typen ähneln Typen in vielen Programmiersprachen, wobei die Hauptbesonderheit die domainspezifische Art von StableHLO ist, was zu einigen ungewöhnlichen Ergebnissen führt (z.B. sind skalare Typen keine Werttypen).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}
Tensor-Typen stehen für Tensoren, also mehrdimensionale Arrays. Sie haben eine Form und einen Elementtyp, wobei eine Form nicht negative Dimensionsgrößen in aufsteigender Reihenfolge der entsprechenden Abmessungen darstellt (auch als Achsen bezeichnet), die von 0
bis R-1
nummeriert sind. Die Anzahl der R
-Dimensionen wird als Rang bezeichnet. Zum Beispiel ist tensor<2x3xf32>
ein Tensortyp mit der Form 2x3
und dem Elementtyp f32
. Sie hat zwei Dimensionen (oder mit anderen Worten, zwei Achsen) – 0. Dimension und 1. Dimension – deren Größe 2 und 3 sind. Es hat den Rang 2.
Damit wird die Unterstützung für statische Formen definiert, wenn die Dimensionsgrößen statisch bekannt sind. Für die Zukunft planen wir, auch dynamische Formen zu unterstützen, bei denen die Dimensionsgrößen teilweise oder vollständig unbekannt sind (Nr. 8). Außerdem planen wir, Tensortypen über Dimensionsgrößen und Elementtypen hinaus zu erweitern, um beispielsweise Layouts (#629) und Datendichte (#1078) einzubeziehen.
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Name | Typ | Einschränkung |
---|---|---|
storage_type |
Ganzzahltyp | (C1–C4), (C9) |
storage_min |
Ganzzahlkonstante | (C2), (C4), (C8) |
storage_max |
Ganzzahlkonstante | (C3), (C4), (C8) |
expressed_type |
Gleitkommatyp | (C1), (C5) |
quantization_dimension |
Optionale Ganzzahlkonstante | (C11–C13) |
scales |
variadische Zahl von Gleitkommakonstanten | (C5-C7), (C10), (C11), (C13) |
zero_points |
variadische Zahl ganzzahliger Konstanten | (C8–C10) |
Quantisierte Elementtypen stellen Ganzzahlwerte eines Speichertyps im Bereich von storage_min
bis storage_max
(einschließlich) dar, die Gleitkommawerten eines ausgedrückten Typs entsprechen. Für einen gegebenen ganzzahligen Wert i
kann der entsprechende Gleitkommawert f
als f = (i - zero_point) * scale
berechnet werden, wobei scale
und zero_point
als Quantisierungsparameter bezeichnet werden. storage_min
und storage_max
sind in der Grammatik optional, haben aber die Standardwerte min_value(storage_type)
bzw. max_value(storage_type)
. Quantisierte Elementtypen unterliegen den folgenden Einschränkungen:
- (C1)
num_bits(storage_type) < num_bits(expressed_type)
. - (C2)
type(storage_min) = storage_type
. - (C3)
type(storage_max) = storage_type
. - (C4)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C5)
type(scales...) = expressed_type
. - (C6)
0 < scales
. - (C7)
is_finite(scales...)
. - (C8)
storage_min <= zero_points <= storage_max
. - (C9)
type(zero_points...) = storage_type
. - (C10)
size(scales) = size(zero_points)
. - (C11) Wenn
is_empty(quantization_dimension)
, dannsize(scales) = 1
. - (C12)
0 <= quantization_dimension
.
Derzeit ist QuantizationScale
eine Gleitkommakonstante. Es besteht jedoch ein großes Interesse an ganzzahligen Skalen, die durch Multiplikatoren und Verschiebungen dargestellt werden. Dies ist jedoch für die Zukunft geplant (#1404).
Es wird derzeit über die Semantik von QuantizationZeroPoint
erörtert, einschließlich Typ, Werte und ob es nur einen oder potenziell mehrere Nullpunkte in einem quantisierten Tensortyp geben kann. Basierend auf den Ergebnissen dieser Diskussion kann sich die Angabe um die Nullpunkte in Zukunft ändern (#1405).
In einer weiteren laufenden Diskussion geht es um die Semantik von QuantizationStorageMin
und QuantizationStorageMax
, um zu bestimmen, ob diese Werte und die Werte quantisierter Tensoren Einschränkungen unterliegen (#1406).
Außerdem planen wir, die Darstellung unbekannter Skalen und Nullpunkte zu untersuchen, ähnlich wie die Darstellung unbekannter Dimensionsgrößen (#1407).
Quantisierte Tensortypen stellen Tensoren mit quantisierten Elementen dar. Diese Tensoren sind genau die gleichen wie reguläre Tensoren, mit der Ausnahme, dass ihre Elemente quantisierte Elementtypen anstelle von regulären Elementtypen haben.
Bei quantisierten Tensoren kann die Quantisierung pro Tensor erfolgen, d. h., sie kann einen scale
und zero_point
für den gesamten Tensor oder pro Achse haben, d. h. mit mehreren scales
und zero_points
, einem Paar pro Segment einer bestimmten Dimension quantization_dimension
. Formaler gibt es in einem Tensor-t
mit Quantisierung pro Achse dim(t, quantization_dimension)
-Slices von quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
usw. Alle Elemente im i
. Slice verwenden scales[i]
und zero_points[i]
als Quantisierungsparameter. Quantisierte Tensortypen haben die folgenden Einschränkungen:
- Für die Quantisierung pro Tensor:
- Keine zusätzlichen Einschränkungen.
- Für die Quantisierung pro Achse:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Tokentypen stellen Tokens dar, also intransparente Werte, die von einigen Vorgängen erzeugt und konsumiert werden. Tokens werden verwendet, um die Ausführungsreihenfolge für Vorgänge festzulegen, wie im Abschnitt Ausführung beschrieben.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Tupeltypen stellen Tupel dar, d.h. heterogene Listen. Tupel sind ein Legacy-Feature, das nur für die Kompatibilität mit HLO existiert. In HLO werden Tupel verwendet, um variadische Eingaben und Ausgaben darzustellen. In StableHLO werden variadische Eingaben und Ausgaben nativ unterstützt. In StableHLO werden Tupel nur vollständig verwendet, wobei z. B. T
, tuple<T>
und tuple<tuple<T>>
je nach Implementierung erheblich voneinander abweichen können. Wir planen, in Zukunft Änderungen an HLO ABI vorzunehmen, die es uns ermöglichen könnten, Tupeltypen aus StableHLO zu entfernen (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Elementtypen stellen Elemente von Tensortypen dar. Im Gegensatz zu vielen Programmiersprachen sind diese Typen in StableHLO nicht die erste Klasse. Das bedeutet, dass StableHLO-Programme Werte dieser Typen nicht direkt darstellen können. Daher ist es idiomatisch, skalare Werte vom Typ T
mit 0-dimensionalen Tensorwerten vom Typ tensor<T>
darzustellen.
- Boolescher Typ stellt die booleschen Werte
true
undfalse
dar. - Ganzzahltypen können entweder vorzeichenbehaftet (
si
) oder vorzeichenlos (ui
) sein und eine der unterstützten Bitbreiten (4
,8
,16
,32
oder64
) haben. Vorzeichenbehaftete Typen (siN
) stellen Ganzzahlwerte von-2^(N-1)
bis einschließlich2^(N-1)-1
dar. Typen ohne Vorzeichen (uiN
) stellen Ganzzahlwerte von0
bis einschließlich2^N-1
dar. - Gleitkommatypen können einer der folgenden sein:
f8E4M3FN
- undf8E5M2
-Typen, die denE4M3
- undE5M2
-Codierungen des FP8-Formats entsprechen, das unter FP8-Formate für Deep Learning beschrieben wird.f8E4M3FNUZ
- undf8E5M2FNUZ
-Typen, die denE4M3
- undE5M2
-Codierungen der FP8-Formate entsprechen, die unter Numerische 8-Bit-Formate für neuronale Deep-Learning-Netzwerke beschrieben sind.f8E4M3B11FNUZ
, der derE4M3
-Codierung der FP8-Formate entspricht, die unter HFP8-Training und -Inferenz für tiefe neuronale Deep-Learning-Netzwerke beschrieben werden.bf16
-Typ, der dembfloat16
-Format entspricht, das in BFloat16: Das Geheimnis für hohe Leistung auf Cloud TPUs beschrieben wird.f16
-,f32
- undf64
-Typen, die den in IEEE 754-Standard beschriebenen Formatenbinary16
("Halbgenauigkeit"),binary32
("Einfache Genauigkeit") bzw.binary64
("Doppelte Genauigkeit") entsprechen.
- Komplexe Typen stellen komplexe Werte dar, die einen reellen Teil und einen imaginären Teil desselben Elementtyps haben. Unterstützte komplexe Typen sind
complex<f32>
(beide Teile sind vom Typf32
) undcomplex<f64>
(beide Teile sind vom Typf64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Funktionstypen stellen sowohl benannte als auch anonyme Funktionen dar. Sie haben Eingabetypen (die Liste der Typen auf der linken Seite von ->
) und Ausgabetypen (die Liste der Typen auf der rechten Seite von ->
). In vielen Programmiersprachen sind Funktionstypen erste Klasse, aber nicht in StableHLO.
StringType ::= 'string'
Der Stringtyp stellt Bytesequenzen dar. Im Gegensatz zu vielen Programmiersprachen ist der Stringtyp in StableHLO nicht die erste Klasse, sondern wird nur verwendet, um statische Metadaten für Programmelemente anzugeben.
Operations
StableHLO-Vorgänge (auch Vorgänge genannt) stellen eine geschlossene Gruppe von übergeordneten Vorgängen in Modellen für maschinelles Lernen dar. Wie oben erwähnt, orientiert sich die StableHLO-Syntax stark an MLIR, die nicht unbedingt die ergonomische Alternative ist, aber wohl am besten zum Ziel von StableHLO passt, das Ziel von StableHLO, mehr Interoperabilität zwischen ML-Frameworks und ML-Compilern zu erreichen.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
StableHLO-Vorgänge (die auch als Vorgänge bezeichnet werden) haben einen Namen, Ein-/Ausgaben und eine Signatur. Der Name besteht aus dem Präfix stablehlo.
und einer Mnemonic, die eine der unterstützten Vorgänge eindeutig identifiziert. Unten finden Sie eine umfassende Liste aller unterstützten Vorgänge.
Derzeit enthalten StableHLO-Programme in der Wildnis manchmal Vorgänge, die in diesem Dokument nicht beschrieben werden. Für die Zukunft planen wir, diese Vorgänge entweder in das StableHLO-Opset aufzunehmen oder zu verhindern, dass sie in StableHLO-Programmen erscheinen. In der Zwischenzeit finden Sie hier eine Liste dieser Vorgänge:
builtin.module
,func.func
,func.call
undfunc.return
(#425).chlo
-Vorgänge (#602)- „Nicht in HLO“-Kategorie von StableHLO-Vorgängen – sie waren ursprünglich Teil des StableHLO-Vorgangs, wurden aber später als nicht richtig erwiesen:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(Nr. 3). - Kategorie „Dynamism“ von StableHLO-Vorgängen – sie wurden von MHLO gestartet, aber wir haben sie noch nicht angegeben:
compute_reshape_shape
,cstr_reshapable
,dynamic_broadcast_in_dim
,dynamic_conv
,dynamic_gather
,dynamic_iota
,dynamic_pad
,dynamic_reshape
,real_dynamic_slice
,set_dimension_size
(Nr. 8). - Formberechnungen, einschließlich
arith
-,shape
- undtensor
-Vorgängen (Nr. 8).
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Ops verarbeiten Eingaben und generieren Ausgaben. Eingaben werden in Eingabewerte (berechnet während der Ausführung), Eingabefunktionen (statisch bereitgestellt, weil in StableHLO bei Funktionen keine Werte der ersten Klasse sind) und Eingabeattribute (ebenfalls statisch bereitgestellt) kategorisiert. Die Art der Ein- und Ausgaben, die von einem Vorgang verbraucht und erzeugt werden, hängt von seiner Gedächtnisstütze ab. Beispielsweise verbraucht der Vorgang add
zwei Eingabewerte und erzeugt einen Ausgabewert. Im Gegensatz dazu verbraucht die select_and_scatter
-Operation 3 Eingabewerte, 2 Eingabefunktionen und 3 Eingabeattribute.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Eingabefunktionen (auch als anonyme Funktionen bezeichnet) sind den benannten Funktionen sehr ähnlich, mit folgenden Unterschieden: 1) Sie haben keine Kennung (also der Name „anonymous“), 2) sie deklarieren keine Ausgabetypen (Ausgabetypen werden aus dem return
-Vorgang in der Funktion abgeleitet).
Die Syntax für Eingabefunktionen enthält einen aktuell nicht verwendeten Teil (siehe Unused
-Produktion oben), der aus Gründen der Kompatibilität mit MLIR dient. Beim MLIR gibt es ein allgemeineres Konzept von "Regionen", in denen mehrere "Blocks" von Operationen enthalten sein können, die über Jump-Operationen miteinander verbunden sind. Diese Blöcke haben IDs, die der Unused
-Produktion entsprechen, damit sie voneinander unterschieden werden können.
StableHLO hat keine Jump-Ops, sodass der entsprechende Teil der MLIR-Syntax nicht verwendet wird (aber immer noch vorhanden ist).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Eingabeattribute haben einen Namen und einen Wert, der eine der unterstützten Konstanten ist. Sie sind die primäre Möglichkeit, statische Metadaten für Programmelemente anzugeben. Der Vorgang concatenate
verwendet beispielsweise das Attribut dimension
, um die Dimension anzugeben, mit der die Eingabewerte verkettet werden. In ähnlicher Weise verwendet der slice
-Vorgang mehrere Attribute wie start_indices
und limit_indices
, um die Grenzen anzugeben, die zum Aufteilen des Eingabewerts verwendet werden.
Derzeit enthalten StableHLO-Programme in der Wildnis manchmal Attribute, die in diesem Dokument nicht beschrieben werden. Wir planen, diese Attribute in Zukunft entweder in das StableHLO-Opset aufzunehmen oder zu verhindern, dass sie in StableHLO-Programmen erscheinen. In der Zwischenzeit finden Sie hier die Liste dieser Attribute:
layout
(#629)mhlo.frontend_attributes
(#628).mhlo.sharding
(#619)output_operand_aliases
(#740)- Standortmetadaten (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Die Vorgangssignatur besteht aus den Typen aller Eingabewerte (die Liste der Typen auf der linken Seite von ->
) und den Typen aller Ausgabewerte (die Liste der Typen auf der rechten Seite von ->
). Streng genommen sind Eingabetypen redundant und Ausgabetypen fast immer redundant (da Ausgabetypen bei den meisten StableHLO-Vorgängen aus Eingaben abgeleitet werden können). Dennoch ist die Vorgangssignatur bewusst Teil der StableHLO-Syntax, da sie mit MLIR kompatibel ist.
Unten sehen Sie ein Beispiel für eine Operation, deren Gedächtnisstütze select_and_scatter
ist. Er verbraucht 3 Eingabewerte (%operand
, %source
und %init_value
), 2 Eingabefunktionen und 3 Eingabeattribute (window_dimensions
, window_strides
und padding
). Beachten Sie, dass die Signatur des Vorgangs nur die Typen seiner Eingabewerte enthält (aber nicht die Typen von Eingabefunktionen und -attributen, die inline bereitgestellt werden).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Konstanten
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
StableHLO-Konstanten haben ein Literal und einen Typ, die zusammen einen StableHLO-Wert darstellen. Im Allgemeinen ist der Typ Teil der konstanten Syntax, es sei denn, er ist eindeutig (z.B. hat eine boolesche Konstante eindeutig den Typ i1
, während eine Ganzzahlkonstante mehrere mögliche Typen haben kann).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Boolesche Konstanten stellen die booleschen Werte true
und false
dar. Boolesche Konstanten haben den Typ i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Ganzzahlkonstanten stellen Ganzzahlwerte über Strings in Dezimal- oder Hexadezimalschreibweise dar. Andere Basen, z.B. binäre oder Oktalwerte, werden nicht unterstützt. Ganzzahlkonstanten haben die folgenden Einschränkungen:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Gleitkommakonstanten stellen Gleitkommawerte über Strings in Dezimal oder wissenschaftlicher Notation dar. Außerdem kann die Hexadezimalschreibweise verwendet werden, um die zugrunde liegenden Bits direkt im Gleitkommaformat des entsprechenden Typs anzugeben. Für Gleitkommakonstanten gelten folgende Einschränkungen:
- (C1) Wenn die nicht-hexadezimale Notation verwendet wird,
is_wellformed(float_literal, float_type)
. - (C2) Wenn die Hexadezimalschreibweise verwendet wird:
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Komplexe Konstanten stellen komplexe Werte mithilfe von Listen aus einem reellen Teil (kommt zuerst) und eines imaginären Teils (zweiter) dar. Beispiel: (1.0, 0.0) : complex<f32>
steht für 1.0 + 0.0i
und (0.0, 1.0) : complex<f32>
für 0.0 + 1.0i
. Die Reihenfolge, in der diese Teile dann im Speicher gespeichert werden, hängt von der Implementierung ab. Für komplexe Konstanten gelten die folgenden Einschränkungen:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Tensor-Konstanten stellen Tensorwerte mithilfe verschachtelter Listen dar, die über die NumPy-Notation angegeben werden. dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
stellt beispielsweise einen Tensorwert mit der folgenden Zuordnung von Indizes zu Elementen dar: {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
, {1, 2} => 6
. Die Reihenfolge, in der diese Elemente dann im Arbeitsspeicher gespeichert werden, ist durch die Implementierung definiert. Für Tensorkonstanten gelten folgende Einschränkungen:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, wobei:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, wobei:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- Andernfalls
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Quantisierte Tensorkonstanten stellen quantisierte Tensorwerte mit derselben Notation wie Tensorkonstanten dar, wobei Elemente als Konstanten ihres Speichertyps angegeben werden. Quantisierte Tensorkonstanten haben die folgenden Einschränkungen:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Stringliterale bestehen aus Byte, die mit ASCII-Zeichen und Escape-Sequenzen angegeben werden. Sie sind codierungsunabhängig, sodass die Interpretation dieser Byte implementierungsspezifisch definiert ist. Stringliterale haben den Typ string
.
Operativer Betrieb
abs
Semantik
Führt die elementweise Abs-Operationen für den operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für vorzeichenbehaftete Ganzzahlen: ganzzahliger Modulus.
- Für Gleitkommazahlen:
abs
aus IEEE-754. - Bei komplexen Zahlen: komplexer Modulus.
- Für quantisierte Typen:
dequantize_op_quantize(abs, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor einer vorzeichenbehafteten Ganzzahl, Gleitkommazahl oder eines komplexen Typs oder eines pro Tensor quantisierten Tensors | (C1–C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor einer vorzeichenbehafteten Ganzzahl oder Gleitkommazahl oder pro Tensor quantisierter Tensor | (C1–C2) |
Einschränkung
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
ist so definiert:complex_element_type(element_type(operand))
, wennis_complex(operand)
.- Andernfalls
baseline_element_type(operand)
.
Beispiele
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
Hinzufügen
Semantik
Führt die elementweise Addition der beiden Tensoren lhs
und rhs
durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Bei booleschen Werten: logisches ODER
- Für Ganzzahlen: ganzzahlige Addition
- Für Gleitkommazahlen:
addition
aus IEEE-754. - Bei komplexen Zahlen: komplexe Addition.
- Für quantisierte Typen:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
(I2) | rhs |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantik
Damit wird sichergestellt, dass die Vorgänge, die inputs
generieren, vor Vorgängen ausgeführt werden, die von result
abhängen. Die Ausführung dieses Vorgangs bewirkt nichts, sondern dient nur zum Einrichten von Datenabhängigkeiten von result
bis inputs
.
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | inputs |
variadische Anzahl von token |
Ausgaben
Name | Typ |
---|---|
result |
token |
Beispiele
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantik
Verkettet in jeder Prozessgruppe im StableHLO-Prozessraster die Werte des Tensors operand
von jedem Prozess entlang von all_gather_dim
und erzeugt einen result
-Tensor.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, die so definiert ist:
cross_replica(replica_groups)
wennchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
wennchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
wennchannel_id > 0 and use_global_device_ids = true
.
Danach geschieht in jedem process_group
Folgendes:
operands@receiver = [operand@sender for sender in process_group]
für allereceiver
inprocess_group
.result@process = concatenate(operands@process, all_gather_dim)
für alleprocess
inprocess_group
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C6) |
(I2) | all_gather_dim |
Konstante vom Typ si64 |
(C1), (C6) |
(I3) | replica_groups |
2-dimensionale Tensorkonstante vom Typ si64 |
(C2–C4) |
(I4) | channel_id |
Konstante vom Typ si64 |
(C5) |
(I5) | use_global_device_ids |
Konstante vom Typ i1 |
(C5) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C6) |
Einschränkung
- (C1)
0 <= all_gather_dim < rank(operand)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
ist so definiert:num_replicas
, wenncross_replica
verwendet wird.num_replicas
, wenncross_replica_and_partition
verwendet wird.num_processes
, wennflattened_ids
verwendet wird.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Wenn
use_global_device_ids = true
, dannchannel_id > 0
. - (C6)
type(result) = type(operand)
außer:dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1)
.
Beispiele
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
all_reduce
Semantik
Wendet in jeder Prozessgruppe im StableHLO-Prozessraster die Reduktionsfunktion computation
auf die Werte des Tensors operand
aus jedem Prozess an und erzeugt einen result
-Tensor.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, die so definiert ist:
cross_replica(replica_groups)
wennchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
wennchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
wennchannel_id > 0 and use_global_device_ids = true
.
Danach geschieht in jedem process_group
Folgendes:
result@process[result_index] = exec(schedule)
für eine binäre Baumstrukturschedule
, wobei Folgendes gilt:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
ist ein implementierungsdefinierter Binärbaum, dessen Durchlauf in der Reihenfolgeto_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0]))
ist.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C5), (C6) |
(I2) | replica_groups |
variadische Anzahl eindimensionaler Tensorkonstanten vom Typ si64 |
(C1–C3) |
(I3) | channel_id |
Konstante vom Typ si64 |
(C4) |
(I4) | use_global_device_ids |
Konstante vom Typ i1 |
(C4) |
(I5) | computation |
Funktion | (C5) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C6–C7) |
Einschränkung
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
ist so definiert:num_replicas
, wenncross_replica
verwendet wird.num_replicas
, wenncross_replica_and_partition
verwendet wird.num_processes
, wennflattened_ids
verwendet wird.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Wenn
use_global_device_ids = true
, dannchannel_id > 0
. - (C5)
computation
hat den Typ(tensor<E>, tensor<E>) -> (tensor<E>)
, wobeiis_promotable(element_type(operand), E)
ist. - (C6)
shape(result) = shape(operand)
. - (C7)
element_type(result) = E
.
Beispiele
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]
all_to_all
Semantik
Teilt in jeder Prozessgruppe im StableHLO-Prozessraster die Werte des Tensors operand
entlang von split_dimension
in Teile auf, verteilt sie auf die Prozesse, verkettet die verteilten Teile entlang von concat_dimension
und erzeugt einen result
-Tensor.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, die so definiert ist:
cross_replica(replica_groups)
, wennchannel_id <= 0
.cross_partition(replica_groups)
, wennchannel_id > 0
.
Danach geschieht in jedem process_group
Folgendes:
split_parts@sender = split(operand@sender, split_count, split_dimension)
für allesender
inprocess_group
.scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group]
, wobeireceiver_index = process_group.index(receiver)
.result@process = concatenate(scattered_parts@process, concat_dimension)
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1–C3), (C9) |
(I2) | split_dimension |
Konstante vom Typ si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
Konstante vom Typ si64 |
(C3), (C9) |
(I4) | split_count |
Konstante vom Typ si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
2-dimensionale Tensorkonstante vom Typ si64 |
(C5–C8) |
(I6) | channel_id |
Konstante vom Typ si64 |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C9) |
Einschränkung
- (C1)
0 <= split_dimension < rank(operand)
. - (C2)
dim(operand, split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operand)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
ist so definiert:num_replicas
, wenncross_replica
verwendet wird.num_partitions
, wenncross_partition
verwendet wird.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(result) = type(operand)
außer:dim(result, split_dimension) = dim(operand, split_dimension) / split_count
.dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count
.
Beispiele
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
// [5, 6],
// [9, 10],
// [13, 14]]
// %result@(1, 0): [[3, 4],
// [7, 8],
// [11, 12],
// [15, 16]]
sowie
Semantik
Führt ein elementweises AND von zwei Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Bei booleschen Werten: logisches UND.
- Für Ganzzahlen: bitweises UND.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor vom Typ „Boolescher Wert“ oder „Integer“ | (C1) |
(I2) | rhs |
Tensor vom Typ „Boolescher Wert“ oder „Integer“ | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor vom Typ „Boolescher Wert“ oder „Integer“ | (C1) |
Einschränkung
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantik
Führt eine elementweise atan2-Operation für den lhs
- und rhs
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
atan2
aus IEEE-754. - Bei komplexen Zahlen: komplex atan2.
- Für quantisierte Typen:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
(I2) | rhs |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantik
Berechnet Gradienten mehrerer Eingaben von batch_norm_training
, die von grad_output
zurückpropagiert werden, und erzeugt die Tensoren grad_operand
, grad_scale
und grad_offset
. Formaler kann dieser Vorgang als Zerlegung vorhandener StableHLO-Vorgänge mithilfe der Python-Syntax wie folgt ausgedrückt werden:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Führt für quantisierte Typen den Befehl dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1–C3), (C5) |
(I2) | scale |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C2), (C4), (C5) |
(I3) | mean |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C2), (C4) |
(I4) | variance |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C2), (C4) |
(I5) | grad_output |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C2), (C3) |
(I6) | epsilon |
Konstante vom Typ f32 |
|
(I7) | feature_index |
Konstante vom Typ si64 |
(C1), (C5) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
grad_operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C2), (C3) |
grad_scale |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C2), (C4) |
grad_offset |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C2), (C4) |
Einschränkung
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
undgrad_offset
haben dasselbebaseline_element_type
. - (C3)
operand
,grad_output
undgrad_operand
haben die gleiche Form. - (C4)
scale
,mean
,variance
,grad_scale
undgrad_offset
haben die gleiche Form. - (C5)
size(scale) = dim(operand, feature_index)
.
Beispiele
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantik
Normalisiert den operand
-Tensor in allen Dimensionen mit Ausnahme der feature_index
-Dimension und erzeugt einen result
-Tensor. Formaler kann dieser Vorgang als Zerlegung vorhandener StableHLO-Vorgänge mithilfe der Python-Syntax ausgedrückt werden:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Führt für quantisierte Typen den Befehl dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1–C7) |
(I2) | scale |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C2), (C3) |
(I3) | offset |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C2), (C4) |
(I4) | mean |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C5) |
(I5) | variance |
Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ | (C2), (C6) |
(I6) | epsilon |
Konstante vom Typ f32 |
|
(I7) | feature_index |
Konstante vom Typ si64 |
(C1), (C3–C6) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C2), (C7) |
Einschränkung
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
undresult
haben denselbenbaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantik
Berechnet den Mittelwert und die Varianz über alle Dimensionen mit Ausnahme der Dimension feature_index
und normalisiert den Tensor operand
, der die Tensoren output
, batch_mean
und batch_var
erzeugt. Formaler kann dieser Vorgang als Zerlegung vorhandener StableHLO-Vorgänge mithilfe der Python-Syntax so ausgedrückt werden:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Führt für quantisierte Typen den Befehl dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
(I2) | scale |
Eindimensionaler Tensor von Gleitkomma oder pro Tensor quantisiertem Tensor | (C2), (C3) |
(I3) | offset |
Eindimensionaler Tensor von Gleitkomma oder pro Tensor quantisiertem Tensor | (C2), (C4) |
(I4) | epsilon |
Konstante vom Typ f32 |
(C1), (C3–C6) |
(I5) | feature_index |
Konstante vom Typ si64 |
(C1), (C3–C6) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
output |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C7) |
batch_mean |
Eindimensionaler Tensor von Gleitkomma oder pro Tensor quantisiertem Tensor | (C2), (C5) |
batch_var |
Eindimensionaler Tensor von Gleitkomma oder pro Tensor quantisiertem Tensor | (C2), (C6) |
Einschränkung
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
undoutput
haben denselbenbaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Beispiele
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantik
Führt eine Bitcast-Operation auf dem operand
-Tensor aus und erzeugt einen result
-Tensor, bei dem die Bits des gesamten operand
-Tensors mit dem Typ des result
-Tensors neu interpretiert werden.
Formeller anhand von E = element_type(operand)
, E' = element_type(result)
und R = rank(operand)
:
- Wenn
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Wenn
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Wenn
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
gibt die speicherinterne Darstellung eines bestimmten Werts zurück. Sein Verhalten ist implementierungsabhängig, da die genaue Darstellung der Tensoren durch die Implementierung sowie die genaue Darstellung von Elementtypen ebenfalls durch die Implementierung definiert wird.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1–C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1–C2) |
Einschränkung
- (C1) Für
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
undR = rank(operand)
gilt:- Wenn
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Wenn
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
für alle0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Wenn
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
für alle0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Wenn
- (C2) Wenn
is_complex(operand) or is_complex(result)
, dannis_complex(operand) and is_complex(result)
.
Beispiele
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantik
Erweitert die Dimensionen und/oder den Rang eines Eingabetensors durch Duplizieren der Daten im operand
-Tensor und Erzeugt einen result
-Tensor. Formaler gesagt: result[result_index] = operand[operand_index]
, wobei für alle d
in axes(operand)
Folgendes gilt:
operand_index[d] = 0
, wenndim(operand, d) = 1
.- Andernfalls
operand_index[d] = result_index[broadcast_dimensions[d]]
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1–C2), (C5–C6) |
(I2) | broadcast_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2–C6) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1), (C3), (C5-C6) |
Einschränkung
- (C1)
element_type(result)
wird gegeben durch:element_type(operand)
, wenn!is_per_axis_quantized(operand)
.element_type(operand)
, allerdings können sichquantization_dimension(operand)
,scales(operand)
undzero_points(operand)
vonquantization_dimension(result)
,scales(result)
undzero_points(result)
unterscheiden.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Für alle
d
inaxes(operand)
:dim(operand, d) = 1
oderdim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Wenn
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Wenn
dim(operand, quantization_dimension(operand)) = 1
, dannscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Beispiele
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
Supportanfrage
Semantik
Erzeugt die Ausgabe aus der Ausführung genau einer Funktion in branches
in Abhängigkeit vom Wert von index
. Formell. Für result = selected_branch()
gilt:
selected_branch = branches[index]
, wenn0 <= index < size(branches)
.- Andernfalls
selected_branch = branches[-1]
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | index |
0-dimensionaler Tensor vom Typ si32 |
|
(I2) | branches |
variadische Anzahl von Funktionen | (C1–C4) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens | (C4) |
Einschränkung
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Beispiele
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CBRT
Semantik
Führt eine elementweise kubische Wurzeloperation für den operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
rootn(x, 3)
aus IEEE-754. - Bei komplexen Zahlen: komplexe Kubikwurzel.
- Für quantisierte Typen:
dequantize_op_quantize(cbrt, operand, type(result))
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
Ceil
Semantik
Führt einen elementweisen Ceil des operand
-Tensors aus und erzeugt einen result
-Tensor.
Implementiert den Vorgang roundToIntegralTowardPositive
aus der IEEE-754-Spezifikation. Führt für quantisierte Typen den Befehl dequantize_op_quantize(ceil, operand, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Semantik
Berechnet die Cholesky-Zersetzung eines Batches von Matrizen.
Formaler gesagt ist result[i0, ..., iR-3, :, :]
für alle i
in index_space(result)
eine Cholesky-Zerlegung von a[i0, ..., iR-3, :, :]
in Form einer Matrix für das untere Dreieck (wenn lower
true
ist) oder eine obere Dreiecksmatrix (wenn lower
false
ist).
Die Ausgabewerte im gegenüberliegenden Dreieck, d. h. das strenge obere Dreieck bzw. das strikte untere Dreieck entsprechend, sind implementierungsdefiniert.
Wenn i
vorhanden ist, bei der die Eingabematrix keine positive Definite Matrix des Hermitian ist, ist das Verhalten nicht definiert.
Führt für quantisierte Typen den Befehl dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | a |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1–C3) |
(I2) | lower |
0-dimensionale Tensorkonstante vom Typ i1 |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Beispiele
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
einschränken
Semantik
Bindet jedes Element des Tensors operand
zwischen einem Mindest- und Höchstwert ein und erzeugt einen result
-Tensor. Formaler: result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
, wobei min_element = rank(min) = 0 ? min[] : min[result_index]
, max_element = rank(max) = 0 ? max[] : max[result_index]
. Führt für quantisierte Typen den Befehl dequantize_op_quantize(clamp, min, operand, max, type(result))
aus.
Die Anordnung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, die Unterstützung komplexer Zahlen für diese Operation (#560) in Zukunft einzustellen.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | min |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C3) |
(I2) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1–C4) |
(I3) | max |
Tensor oder quantisierter Tensor pro Tensor | (C2), (C3) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C4) |
Einschränkung
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantik
Senden Sie in jeder Prozessgruppe im StableHLO-Prozessraster den Wert des Tensors operand
vom Quellprozess an die Zielprozesse und erzeugen Sie einen result
-Tensor.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, die so definiert ist:
cross_replica(replica_groups)
, wennchannel_id <= 0
.cross_partition(replica_groups)
, wennchannel_id > 0
.
Danach wird result@process
angegeben durch:
operand@process_groups[i, 0]
, wenn eini
vorhanden ist, sodass sich der Prozess inprocess_groups[i]
befindet.- Andernfalls
broadcast_in_dim(constant(0, element_type(result)), [], type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor | (C3) |
(I2) | replica_groups |
variadische Anzahl eindimensionaler Tensorkonstanten vom Typ si64 |
(C1), (C2) |
(I3) | channel_id |
Konstante vom Typ si64 |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor | (C3) |
Einschränkung
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, wobeiN
so definiert ist:num_replicas
, wenncross_replica
verwendet wird.num_partitions
, wenncross_partition
verwendet wird.
- (C3)
type(result) = type(operand)
.
Beispiele
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantik
Sendet in jeder Prozessgruppe im StableHLO-Prozessraster den Wert des Tensors operand
vom Quellprozess an den Zielprozess und erzeugt einen result
-Tensor.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, die so definiert ist:
cross_replica(source_target_pairs)
, wennchannel_id <= 0
.cross_partition(source_target_pairs)
, wennchannel_id > 0
.
Danach wird result@process
angegeben durch:
operand@process_groups[i, 0]
, wenn einei
mitprocess_groups[i, 1] = process
vorhanden ist.- Andernfalls
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C5) |
(I2) | source_target_pairs |
2-dimensionale Tensorkonstante vom Typ si64 |
(C1–C4) |
(I3) | channel_id |
Konstante vom Typ si64 |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, wobeiN
so definiert ist:num_replicas
, wenncross_replica
verwendet wird.num_partitions
, wenncross_partition
verwendet wird.
- (C5)
type(result) = type(operand)
.
Beispiele
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
Nutzer*innen
Semantik
Führt einen elementweisen Vergleich der Tensoren lhs
und rhs
gemäß comparison_direction
und compare_type
durch und erzeugt einen result
-Tensor.
Die Werte von comparison_direction
und compare_type
haben die folgende Semantik:
Für boolesche und ganzzahlige Elementtypen:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Für Gleitkomma-Elementtypen mit compare_type = FLOAT
implementiert der Vorgang die folgenden IEEE-754-Vorgänge:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Bei Gleitkomma-Elementtypen mit compare_type = TOTALORDER
verwendet der Vorgang die Kombination aus totalOrder
- und compareQuietEqual
-Vorgängen aus IEEE-754. Diese Funktion scheint ungenutzt zu sein und wird demnächst entfernt (#584).
Bei komplexen Elementtypen wird der lexikografische Vergleich von (real, imag)
-Paaren mit den bereitgestellten comparison_direction
und compare_type
durchgeführt.
Die Anordnung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, die Unterstützung für komplexe Zahlen in Zukunft einzustellen, wenn comparison_direction
GE
, GT
, LE
oder LT
ist (#560).
Für quantisierte Typen. Führt dequantize_compare(lhs, rhs,
comparison_direction)
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor pro Tensor | (C1–C3) |
(I2) | rhs |
Tensor oder quantisierter Tensor pro Tensor | (C1–C2) |
(I3) | comparison_direction |
Aufzählung von EQ , NE , GE , GT , LE und LT |
|
(I4) | compare_type |
Aufzählung von FLOAT , TOTALORDER , SIGNED und UNSIGNED |
(C3) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des booleschen Typs | (C2) |
Einschränkung
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
ist so definiert:SIGNED
, wennis_signed_integer(element_type(lhs))
.UNSIGNED
, wennis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
oderTOTALORDER
, wennis_float(element_type(lhs))
.FLOAT
, wennis_complex(element_type(lhs))
.
Beispiele
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
komplex
Semantik
Führt eine elementweise Umwandlung in einen komplexen Wert aus einem Paar reeller und imaginärer Werte, lhs
und rhs
, durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor vom Typ f32 oder f64 |
(C1–C3) |
(I2) | rhs |
Tensor vom Typ f32 oder f64 |
(C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des komplexen Typs | (C2), (C3) |
Einschränkung
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
hat den Typcomplex<E>
, wobeiE = element_type(lhs)
ist.
Beispiele
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
concatenate
Semantik
Verkettet inputs
entlang der Dimension dimension
in der gleichen Reihenfolge wie die angegebenen Argumente und erzeugt einen result
-Tensor. Formeller gesagt, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, wobei:
id = d0 + ... + dk-1 + kd
.d
ist gleichdimension
undd0
, ... sind died
. Dimensionsgrößen voninputs
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1–C6) |
(I2) | dimension |
Konstante vom Typ si64 |
(C2), (C4), (C6) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C5–C6) |
Einschränkung
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
mit Ausnahme vondim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
mit Ausnahme von:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Beispiele
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
Konstante
Semantik
Erzeugt einen output
-Tensor aus einer konstanten value
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | value |
Konstante | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
output |
Tensor oder quantisierter Tensor | (C1) |
Einschränkung
- (C1)
type(value) = type(output)
.
Beispiele
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
eine Conversion ausführen
Semantik
Führt eine elementweise Umwandlung von einem Elementtyp in einen anderen auf dem operand
-Tensor durch und erzeugt einen result
-Tensor.
Bei Conversions vom Typ boolean-to-any-supported-type wird der Wert false
in null und der Wert true
in eins umgewandelt. Bei any-supported-type-to-boolean wird ein Nullwert in false
und Werte ungleich null in true
umgewandelt. Im Folgenden erfahren Sie, wie dies bei komplexen Typen funktioniert.
Bei Konvertierungen mit integer-to-integer, integer-to-floating-point oder floating-point-to-floating-point und wenn der Quellwert im Zieltyp genau dargestellt werden kann, ist der Ergebniswert genau diese Darstellung. Andernfalls ist das Verhalten noch nicht festgelegt (#180).
Bei Konvertierungen mit floating-point-to-integer wird der Bruchteil abgeschnitten. Wenn der abgeschnittene Wert im Zieltyp nicht dargestellt werden kann, gilt das Verhalten noch offen (#180).
Umwandlungen von komplexen bis komplex folgen demselben Verhalten wie Gleitkomma-zu-Gleitkomma-Konvertierungen zur Konvertierung von Real- und imaginären Teilen.
Bei Konvertierungen vom Typ complex-to-any-other-type und complex-to-any-other-type wird der imaginäre Quellwert ignoriert bzw. der imaginäre Zielwert auf null gesetzt. Die Umwandlung des tatsächlichen Teils folgt der Gleitkommawertkonvertierung.
Im Prinzip könnte dieser Vorgang Dequantisierung (Umwandlung von quantisierten Tensoren in reguläre Tensoren), Quantisierung (Umwandlung von regulären Tensoren in quantisierte Tensoren) und Requantisierung (Umwandlung zwischen quantisierten Tensoren) ausdrücken, aber derzeit haben wir spezielle Operationen dafür: uniform_dequantize
für den ersten Anwendungsfall und uniform_quantize
für den zweiten und dritten Anwendungsfall. In Zukunft werden diese beiden Vorgänge möglicherweise zu convert
zusammengeführt (#1576).
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor | (C1) |
Einschränkung
- (C1)
shape(operand) = shape(result)
.
Beispiele
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
Faltung
Semantik
Berechnet Punktprodukte zwischen Fenstern von lhs
und Segmenten von rhs
und erzeugt result
. Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result
aus lhs
und rhs
berechnet werden.
Formeller können Sie die folgende Neuformulierung der Eingaben in Bezug auf lhs
in Betracht ziehen, um Fenster von lhs
auszudrücken:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Dabei werden die folgenden Hilfsfunktionen verwendet:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
, wobeij[d] = i[permutation[d]]
.
Wenn feature_group_count = 1
und batch_group_count = 1
, dann für alle output_spatial_index
in index_space(dim(result, output_spatial_dimensions...))
, result[result_shape(:, output_spatial_index, :)] = dot_product
, wobei:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
: Diese Funktion scheint ungenutzt zu sein und wird demnächst entfernt (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Wenn feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Wenn batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Führt für quantisierte Typen den Befehl dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C10-C11), (C14) (C25), (C27-C30) |
(I2) | rhs |
Tensor oder quantisierter Tensor | (C1), (C14-C16), (C25), (C27-C32) |
(I3) | window_strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2–C3), (C25) |
(I4) | padding |
2-dimensionale Tensorkonstante vom Typ si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Eindimensionale Tensorkonstante vom Typ si64 |
(C5–C6), (C25) |
(I6) | rhs_dilation |
Eindimensionale Tensorkonstante vom Typ si64 |
(C7–C8), (C25) |
(I7) | window_reversal |
Eindimensionale Tensorkonstante vom Typ i1 |
(C9) |
(I8) | input_batch_dimension |
Konstante vom Typ si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
Konstante vom Typ si64 |
(C11), (C13–C14) |
(I10) | input_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
Konstante vom Typ si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
Konstante vom Typ si64 |
(C15-C16), (C18), (C25), (C32) |
(I13) | kernel_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C17–C18), (C25) |
(I14) | output_batch_dimension |
Konstante vom Typ si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
Konstante vom Typ si64 |
(C20), (C25), (C33) |
(I16) | output_spatial_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C19–C20), (C25) |
(I17) | feature_group_count |
Konstante vom Typ si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
Konstante vom Typ si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
variadische Anzahl von Enums von DEFAULT , HIGH und HIGHEST |
(C24) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C25–C28), (C30–C31), (C33) |
Einschränkung
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Angegebene
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Angegebene
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Angegebene
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
ist so definiert:dim(lhs, input_batch_dimension) / batch_group_count
, wennresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
, wennresult_dim = output_feature_dimension
.- Andernfalls
num_windows
. Dabei gilt: output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Wenn der Vorgang nicht quantisierte Tensoren verwendet:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Wenn der Vorgang quantisierte Tensoren verwendet:
- (C28)
is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result)
. - (C29)
storage_type(lhs) = storage_type(rhs)
. - (C30)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C31) Wenn
is_per_tensor_quantized(rhs)
, dannis_per_tensor_quantized(result)
. - (C32) Wenn
is_per_axis_quantized(rhs)
, dannquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C33) Wenn
is_per_axis_quantized(result)
, dannquantization_dimension(result) = output_feature_dimension
.
- (C28)
Beispiele
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs : [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = dense<4> : tensor<2xi64>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = dense<2> : tensor<2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_reversal = dense<false> : tensor<2xi1>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
Kosinus
Semantik
Führt eine elementweise Kosinusoperation für den operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
cos
aus IEEE-754. - Bei komplexen Zahlen: komplexer Kosinus.
- Für quantisierte Typen:
dequantize_op_quantize(cosine, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantik
Führt eine elementweise Zählung der Anzahl führender Null-Bits im operand
-Tensor durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Ganzzahltyps | (C1) |
Einschränkung
- (C1)
type(operand) = type(result)
.
Beispiele
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantik
Kapselt einen implementierungsdefinierten Vorgang call_target_name
ein, der inputs
und called_computations
annimmt und results
erzeugt. Mit has_side_effect
, backend_config
und api_version
können zusätzliche durch die Implementierung definierte Metadaten bereitgestellt werden.
Derzeit enthält dieser Vorgang eine ziemlich unorganisierte Sammlung von Metadaten, die die organische Entwicklung des entsprechenden Vorgangs im XLA-Compiler widerspiegelt. Wir planen, diese Metadaten künftig zu vereinheitlichen (#741).
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | inputs |
variadische Anzahl von Werten |
(I2) | call_target_name |
Konstante vom Typ string |
(I3) | has_side_effect |
Konstante vom Typ i1 |
(I4) | backend_config |
Konstante vom Typ string |
(I5) | api_version |
Konstante vom Typ si32 |
(I6) | called_computations |
variadische Anzahl von Konstanten vom Typ string |
Ausgaben
Name | Typ |
---|---|
results |
variadische Anzahl von Werten |
Beispiele
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
Dividieren
Semantik
Führt die elementweise Division der Tensoren lhs
und Divisor rhs
durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Ganzzahlen: Ganzzahldivision, die den algebraischen Quotienten erzeugt, wobei jeder Bruchteil verworfen wird.
- Für Gleitkommazahlen:
division
aus IEEE-754. - Bei komplexen Zahlen: komplexe Division.
- Für quantisierte Typen:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantik
Berechnet Punktprodukte zwischen Segmenten von lhs
und Segmenten von rhs
und erzeugt einen result
-Tensor.
Formell: result[result_index] = dot_product
, wobei:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
, wobeisize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
undsize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Führt für quantisierte Typen den Befehl dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
aus.
Gibt nur die Semantik für die Quantisierung pro Tensor an. Die Quantisierung pro Achse wird ausgeführt (#1574). Zukünftig werden wir auch erwägen, die Unterstützung für die Hybridquantisierung hinzuzufügen (#1575).
precision_config
steuert den Kompromiss zwischen Geschwindigkeit und Genauigkeit bei Berechnungen auf Beschleuniger-Back-Ends. Dabei kann es sich um einen der folgenden Werte handeln (derzeit ist die Semantik dieser enum-Werte zu niedrig, wir planen jedoch, dies in #755 zu beheben):
DEFAULT
: schnellste Berechnung, aber am wenigsten genaue Näherung an die ursprüngliche Zahl.HIGH
: Langsamere Berechnung, aber genauere Annäherung an die ursprüngliche Zahl.HIGHEST
: langsamste Berechnung, aber genaueste Näherung an die ursprüngliche Zahl.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor pro Tensor | (C5-C6), (C9-C10), (C12-C16) |
(I2) | rhs |
Tensor oder quantisierter Tensor pro Tensor | (C7–C10), (C12) |
(I3) | lhs_batching_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C8), (C10) |
(I7) | precision_config |
variadische Anzahl von Enums von DEFAULT , HIGH und HIGHEST |
(C11) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C12), (C14), (C16) |
Einschränkung
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Wenn der Vorgang nicht quantisierte Tensoren verwendet:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Wenn der Vorgang quantisierte Tensoren verwendet:
- (C14)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C15)
storage_type(lhs) = storage_type(rhs)
. - (C16)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C17)
zero_points(rhs) = 0
.
- (C14)
Beispiele
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_slice
Semantik
Extrahiert mithilfe von dynamisch berechneten Startindexen ein Slice aus dem operand
und erzeugt einen result
-Tensor. start_indices
enthält die Startindexe des Segments für jede Dimension, für die eine Anpassung möglich ist, und slice_sizes
die Größe des Segments für jede Dimension. Formaler gesagt, result[result_index] = operand[operand_index]
, wobei:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C2), (C4) |
(I2) | start_indices |
variadische Anzahl von 0-dimensionalen Tensoren vom Typ „Ganzzahl“ | (C2), (C3) |
(I3) | slice_sizes |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C5) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C5) |
Einschränkung
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Beispiele
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantik
Erzeugt einen result
-Tensor, der dem operand
-Tensor entspricht, mit der Ausnahme, dass das Slice, das bei start_indices
beginnt, mit den Werten in update
aktualisiert wird.
Formell ist result[result_index]
so definiert:
update[update_index]
, wenn0 <= update_index < shape(update)
, wobei:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- Andernfalls
operand[result_index]
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1–C4), (C6) |
(I2) | update |
Tensor oder quantisierter Tensor pro Tensor | (C2), (C3), (C6) |
(I3) | start_indices |
variadische Anzahl von 0-dimensionalen Tensoren vom Typ „Ganzzahl“ | (C4), (C5) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Beispiele
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
Exponentialfunktionen
Semantik
Führt eine elementweise exponentielle Operation auf dem operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
exp
aus IEEE-754. - Bei komplexen Zahlen: komplex exponentiell.
- Für quantisierte Typen:
dequantize_op_quantize(exponential, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantik
Führt eine elementweise exponentielle minus eins Operation auf dem operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
expm1
aus IEEE-754. - Bei komplexen Zahlen: komplexes Exponential minus eins.
- Für quantisierte Typen:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Semantik
Führt die Vorwärts- und Inverse Fourier-Transformationen für reale und komplexe Ein-/Ausgaben aus.
fft_type
ist einer der folgenden Werte:
FFT
: Komplexe zu komplexen FFT weiterleiten.IFFT
: Umgekehrte komplex-komplexe FFT.RFFT
: Eine reelle zu komplexe FFT weiterleiten.IRFFT
: Umgekehrte reell-komplexe FFT (nimmt komplex an, gibt reelle Zahlen zurück).
Formeller ausgedrückt, erzeugt die Funktion fft
, die eindimensionale Tensoren komplexer Typen als Eingabe verwendet, eindimensionale Tensoren derselben Typen wie die Ausgabe und berechnet die diskrete Fourier-Transformation:
Für fft_type = FFT
ist result
als Endergebnis einer Reihe von L-Berechnungen definiert, wobei L = size(fft_length)
ist. Beispiel für L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Wenn die Funktion ifft
die gleiche Typsignatur hat und den Kehrwert von fft
berechnet, gilt außerdem Folgendes:
Für fft_type = IFFT
ist result
als Kehrwert der Berechnungen für fft_type = FFT
definiert. Beispiel für L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Wenn die Funktion rfft
, die eindimensionale Tensoren von Gleitkommatypen annimmt, eindimensionale Tensoren komplexer Typen mit derselben Gleitkommasemantik erzeugt und so funktioniert:
rfft(real_operand) = truncated_result
, wobeicomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
Wenn die diskrete Fourier-Transformation für reelle Operanden berechnet wird, definieren die ersten N/2 + 1
-Elemente des Ergebnisses den Rest des Ergebnisses eindeutig, sodass das Ergebnis von rfft
abgeschnitten wird, um die Berechnung redundanter Elemente zu vermeiden.
Für fft_type = RFFT
ist result
als Endergebnis einer Reihe von L-Berechnungen definiert, wobei L = size(fft_length)
ist. Beispiel für L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Wenn die Funktion irfft
die gleiche Typsignatur hat und den Kehrwert von rfft
berechnet, gilt Folgendes:
Für fft_type = IRFFT
ist result
als Kehrwert der Berechnungen für fft_type = RFFT
definiert. Beispiel für L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
Aufzählung von FFT , IFFT , RFFT und IRFFT |
(C2), (C5) |
(I3) | fft_length |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C3), (C4) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkomma- oder komplexen Typs | (C2), (C4), (C5) |
Einschränkung
- (C1)
size(fft_length) <= rank(operand)
. - (C2) Die Beziehung zwischen den Elementtypen
operand
undresult
variiert:- Wenn
fft_type = FFT
,element_type(operand)
undelement_type(result)
denselben komplexen Typ haben. - Wenn
fft_type = IFFT
,element_type(operand)
undelement_type(result)
denselben komplexen Typ haben. - Bei
fft_type = RFFT
istelement_type(operand)
ein Gleitkommatyp undelement_type(result)
ein komplexer Typ derselben Gleitkommasemantik. - Bei
fft_type = IRFFT
istelement_type(operand)
ein komplexer Typ undelement_type(result)
ein Gleitkommatyp mit derselben Gleitkommasemantik.
- Wenn
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Wenn zwischen
operand
undresult
ein Tensorreal
eines Gleitkommatyps vorhanden ist, dannshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
mit Ausnahme von:- Wenn
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Wenn
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Wenn
Beispiele
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
Etage
Semantik
Führt das elementweise Stockwerk des operand
-Tensors aus und erzeugt einen result
-Tensor.
Implementiert den Vorgang roundToIntegralTowardNegative
aus der IEEE-754-Spezifikation. Führt für quantisierte Typen den Befehl dequantize_op_quantize(floor, operand, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
sammeln
Semantik
Erfasst Segmente des Tensors operand
aus den in start_indices
angegebenen Offsets und erzeugt einen result
-Tensor.
Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result
den Elementen in operand
zugeordnet werden. Im Diagramm werden einige Beispiele für result
-Indexe ausgewählt und erklärt, welchen operand
-Indizes sie entsprechen.
Formeller gesagt, result[result_index] = operand[operand_index]
, wobei:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
ist so definiert:start_indices[bi0, ..., :, ..., biN]
, wobeibi
einzelne Elemente inbatch_index
sind und:
in den Indexindex_vector_dim
eingefügt wird, wennindex_vector_dim
<rank(start_indices)
ist.- Andernfalls
[start_indices[batch_index]]
.
- Für
d_operand
inaxes(operand)
:full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
wennd_operand = start_index_map[d_start]
.- Andernfalls
full_start_index[d_operand] = 0
.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, wobeioi
einzelne Elemente inoffset_index
sind und0
an Indizes voncollapsed_slice_dims
eingefügt wird.operand_index = full_start_index + full_offset_index
.
Wenn indices_are_sorted
den Wert true
hat, kann die Implementierung davon ausgehen, dass start_indices
in Bezug auf start_index_map
sortiert sind. Andernfalls ist das Verhalten nicht definiert. Formell für alle i1 < i2
aus indices(result)
, full_start_index(i1) <= full_start_index(i2)
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
Tensor des Ganzzahltyps | (C2), (C3), (C13) |
(I3) | offset_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4-C5), (C13) |
(I4) | collapsed_slice_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C6-C8), (C13) |
(I5) | start_index_map |
Eindimensionale Tensorkonstante vom Typ si64 |
(C3), (C9), (C10) |
(I6) | index_vector_dim |
Konstante vom Typ si64 |
(C2), (C3), (C13) |
(I7) | slice_sizes |
Eindimensionale Tensorkonstante vom Typ si64 |
(C8), (C11–C13) |
(I8) | indices_are_sorted |
Konstante vom Typ i1 |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C5), (C13–C14) |
Einschränkung
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
, wobei:batch_dim_sizes = shape(start_indices)
, allerdings ist die Dimensionsgröße vonstart_indices
, dieindex_vector_dim
entspricht, nicht enthalten.offset_dim_sizes = shape(slice_sizes)
, außer dass die Dimensionsgrößen inslice_sizes
, diecollapsed_slice_dims
entsprechen, nicht enthalten sind.combine
platziertbatch_dim_sizes
an den Achsen, diebatch_dims
entsprechen, undoffset_dim_sizes
an den Achsen, dieoffset_dims
entsprechen.
- (C14)
element_type(operand) = element_type(result)
.
Beispiele
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
get_dimension_size
Semantik
Erzeugt die Größe der angegebenen dimension
von operand
. Formeller gesagt: result = dim(operand, dimension)
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor | (C1) |
(I2) | dimension |
Konstante vom Typ si64 |
(C1) |
Ausgaben
Name | Typ |
---|---|
result |
0-dimensionaler Tensor vom Typ si32 |
Einschränkung
- (C1)
0 <= dimension < rank(operand)
.
Beispiele
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantik
Extrahiert ein Element an der Position index
des Tupels operand
und erzeugt eine result
. Formeller gesagt: result = operand[index]
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
tuple | (C1), (C2) |
(I2) | index |
Konstante vom Typ si32 |
(C1), (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
alle unterstützten Typen | (C2) |
Einschränkung
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Beispiele
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
if
Semantik
Erzeugt die Ausgabe aus der Ausführung genau einer Funktion aus true_branch
oder false_branch
, abhängig vom Wert von pred
. Formeller gesagt: result =
pred ? true_branch() : false_branch()
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | pred |
0-dimensionaler Tensor vom Typ i1 |
|
(I2) | true_branch |
Funktion | (C1–C3) |
(I3) | false_branch |
Funktion | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens | (C3) |
Einschränkung
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Beispiele
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
Bild
Semantik
Extrahiert den Imaginärteil elementweise aus dem operand
und erzeugt einen result
-Tensor. Formeller für jedes Element x
: imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkommatyps | (C1), (C2) |
Einschränkung
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
ist so definiert:complex_element_type(element_type(operand))
, wennis_complex(operand)
.- Andernfalls
element_type(operand)
.
Beispiele
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
Einspeisung
Semantik
Liest Daten aus dem Feed und erzeugt results
.
Die Semantik von infeed_config
ist implementierungsdefiniert.
results
bestehen aus Nutzlastwerten, die an erster Stelle stehen, und einem Token, das zuletzt steht. Zur besseren Verständlichkeit planen wir, die Nutzlast und das Token in Zukunft auf zwei separate Ausgaben aufzuteilen (#670).
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
Konstante vom Typ string |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens | (C1–C3) |
Einschränkung
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
oderis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Beispiele
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Iota
Semantik
Füllt einen output
-Tensor mit Werten in aufsteigender Reihenfolge, beginnend bei null, entlang der iota_dimension
-Dimension. Formell
output[result_index] = constant(is_quantized(output) ?
quantize(result_index[iota_dimension], element_type(output)) :
result_index[iota_dimension], element_type(output))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
output |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Einschränkung
- (C1)
0 <= iota_dimension < rank(output)
.
Beispiele
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantik
Führt eine elementweise Prüfung des Werts in x
durch, d.h. ist weder +Inf, -Inf noch NaN, und erzeugt einen y
-Tensor. Implementiert den Vorgang isFinite
aus der IEEE-754-Spezifikation. Bei quantisierten Typen ist das Ergebnis immer true
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | x |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
y |
Tensor des booleschen Typs | (C1) |
Einschränkung
- (C1)
shape(x) = shape(y)
.
Beispiele
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantik
Führt eine elementweise Logarithmusoperation auf dem Tensor operand
durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
log
aus IEEE-754. - Bei komplexen Zahlen: komplexer Logarithmus.
- Für quantisierte Typen:
dequantize_op_quantize(log, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantik
Führt einen elementweisen Logarithmus plus eine Operation auf dem operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
logp1
aus IEEE-754. - Bei komplexen Zahlen: komplexer Logarithmus plus eins.
- Für quantisierte Typen:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
Logistik
Semantik
Führt eine elementweise logistische Operation für den operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
division(1, addition(1, exp(-x)))
aus IEEE-754. - Bei komplexen Zahlen: komplexe logistische Daten.
- Für quantisierte Typen:
dequantize_op_quantize(logistic, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
Karte
Semantik
Wendet eine Kartenfunktion computation
auf inputs
entlang der dimensions
an und erzeugt einen result
-Tensor.
Formeller gesagt: result[result_index] = computation(inputs...[result_index])
.
dimensions
werden derzeit nicht verwendet und wahrscheinlich in Zukunft entfernt (#487).
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1–C4) |
(I2) | dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C3) |
(I3) | computation |
Funktion | (C4) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C4) |
Einschränkung
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
hat den Typ(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, wobeiEi = element_type(inputs[i])
undE' = element_type(result)
verwendet werden.
Beispiele
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
Maximum
Semantik
Führt eine elementweise maximale Operation für die Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Bei booleschen Werten: logisches ODER
- Für Ganzzahlen: ganzzahliges Maximum.
- Für Gleitkommazahlen:
maximum
aus IEEE-754. - Bei komplexen Zahlen: das lexikografische Maximum für das
(real, imaginary)
-Paar. Die Anordnung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, die Unterstützung komplexer Zahlen für diese Operation (#560) in Zukunft einzustellen. - Für quantisierte Typen:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
(I2) | rhs |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
Minimum
Semantik
Führt eine elementweise minimale Operation für die Tensoren lhs
und rhs
durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Bei booleschen Werten: logisches UND.
- Für Ganzzahlen: minimale Ganzzahl.
- Für Gleitkommazahlen:
minimum
aus IEEE-754. - Bei komplexen Zahlen: das lexikografische Minimum für das
(real, imaginary)
-Paar. Die Anordnung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, die Unterstützung komplexer Zahlen für diese Operation (#560) in Zukunft einzustellen. - Für quantisierte Typen:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
(I2) | rhs |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
Multiplizieren
Semantik
Führt ein elementweises Produkt der beiden Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Bei booleschen Werten: logisches UND.
- Für Ganzzahlen: Ganzzahlmultiplikation.
- Für Gleitkommazahlen:
multiplication
aus IEEE-754. - Bei komplexen Zahlen: komplexe Multiplikation.
- Für quantisierte Typen:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
(I2) | rhs |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
Semantik
Führt die elementweise Negation des Tensors operand
durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Bei vorzeichenbehafteten Ganzzahlen: Ganzzahlennegation.
- Für vorzeichenlose Ganzzahlen: Bitcast zu vorzeichenbehaftete Ganzzahl, Ganzzahl-Negation, Bitcast zurück in vorzeichenlose Ganzzahl.
- Für Gleitkommazahlen:
negate
aus IEEE-754. - Bei komplexen Zahlen: komplexe Negation.
- Für quantisierte Typen:
dequantize_op_quantize(negate, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
nicht
Semantik
Führt ein elementweises NOT des Tensors operand
aus und erzeugt einen result
-Tensor.
Folgendes wird je nach Elementtyp ausgeführt:
- Bei booleschen Werten: logisches NOT.
- Für Ganzzahlen: Das bitweise NOT.
Argumente
Name | Typ | Einschränkung |
---|---|---|
operand |
Tensor vom Typ „Boolescher Wert“ oder „Integer“ | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor vom Typ „Boolescher Wert“ oder „Integer“ | (C1) |
Einschränkung
- (C1)
type(operand) = type(result)
.
Beispiele
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantik
Sorgt dafür, dass die Vorgänge, die das operand
erzeugen, vor Vorgängen ausgeführt werden, die von result
abhängig sind, und verhindert, dass Compiler-Transformationen Vorgänge über die Barriere verschieben. Ansonsten ist der Vorgang eine Identität, z.B. result = operand
.
Argumente
Name | Typ | Einschränkung |
---|---|---|
operand |
variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens pro Tensor | (C1) |
Einschränkung
- (C1)
type(operand...) = type(result...)
.
Beispiele
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
oder
Semantik
Führt ein elementweises OR von zwei Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Bei booleschen Werten: logisches ODER
- Für Ganzzahlen: bitweises OR.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor des Typs „Ganzzahl“ oder „Boolescher Wert“ | (C1) |
(I2) | rhs |
Tensor des Typs „Ganzzahl“ oder „Boolescher Wert“ | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Typs „Ganzzahl“ oder „Boolescher Wert“ | (C1) |
Einschränkung
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
Outfeed
Semantik
Schreibt inputs
in den Outfeed und generiert ein result
-Token.
Die Semantik von outfeed_config
ist implementierungsdefiniert.
Eingaben
Label | Name | Typ |
---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierten Tensoren |
(I2) | token |
token |
(I3) | outfeed_config |
Konstante vom Typ string |
Ausgaben
Name | Typ |
---|---|
result |
token |
Beispiele
%result = "stablehlo.outfeed"(%inputs0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
Feld
Semantik
Maximiert operand
durch Auffüllung um den Tensor sowie zwischen die Elemente des Tensors mit dem angegebenen padding_value
.
edge_padding_low
und edge_padding_high
geben den Abstand an, der am unteren Ende (neben Index 0) bzw. am oberen Rand (neben dem höchsten Index) jeder Dimension hinzugefügt wird. Der Grad des Innenabstands kann negativ sein, wobei der absolute Wert des negativen Werts die Anzahl der Elemente angibt, die aus der angegebenen Dimension entfernt werden sollen.
interior_padding
gibt den Abstand zwischen zwei Elementen in jeder Dimension an, der nicht negativ sein darf. Das Innen-Padding erfolgt vor dem Rand-Padding. Bei einem negativen Rand-Padding werden Elemente aus dem Operanden mit Auffüllung entfernt.
Formell ist result[result_index]
so definiert:
operand[operand_index]
, wennresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.- Andernfalls
padding_value
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C2), (C4) |
(I2) | padding_value |
0-dimensionaler Tensor oder quantisierter Tensor pro Tensor | (C1) |
(I3) | edge_padding_low |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Eindimensionale Tensorkonstante vom Typ si64 |
(C1), (C4) |
(I5) | interior_padding |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2–C4) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C3–C6) |
Einschränkung
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Beispiele
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantik
Erzeugt partition_id
des aktuellen Prozesses.
Ausgaben
Name | Typ |
---|---|
result |
0-dimensionaler Tensor vom Typ ui32 |
Beispiele
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Semantik
Führt eine elementweise Zählung der im operand
-Tensor festgelegten Anzahl von Bits durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Ganzzahltyps | (C1) |
Einschränkung
- (C1)
type(operand) = type(result)
.
Beispiele
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
Leistung
Semantik
Führt die elementweise Exponentiierung des lhs
-Tensors mit dem rhs
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Ganzzahlen: ganzzahlige Exponentiation.
- Für Gleitkommazahlen:
pow
aus IEEE-754. - Bei komplexen Zahlen: komplexe Exponentiation.
- Für quantisierte Typen:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semantik
Extrahiert den reellen Teil elementweise aus operand
und erzeugt einen result
-Tensor. Formeller für jedes Element x
: real(x) = is_complex(x) ? real_part(x) : x
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder komplexen Typs | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkommatyps | (C1), (C2) |
Einschränkung
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
ist so definiert:complex_element_type(element_type(operand))
, wennis_complex(operand)
.- Andernfalls
element_type(operand)
.
Beispiele
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
Recv
Semantik
Erhält Daten von einem Kanal mit channel_id
und erzeugt results
.
Wenn is_host_transfer
den Wert true
hat, überträgt der Vorgang Daten vom Host. Andernfalls werden Daten von einem anderen Gerät übertragen. Was das bedeutet, ist
implementierungsdefiniert. Dieses Flag dupliziert die in channel_type
bereitgestellten Informationen. Daher planen wir, in Zukunft nur eines davon zu behalten (#666).
results
bestehen aus Nutzlastwerten, die an erster Stelle stehen, und einem Token, das zuletzt steht. Zur besseren Verständlichkeit planen wir, die Nutzlast und das Token in Zukunft auf zwei separate Ausgaben aufzuteilen (#670).
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
Konstante vom Typ si64 |
|
(I3) | channel_type |
Aufzählung von DEVICE_TO_DEVICE und HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
Konstante vom Typ i1 |
(C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens | (C2–C4) |
Einschränkung
- (C1)
channel_type
ist so definiert:HOST_TO_DEVICE
, wennis_host_transfer = true
,- Andernfalls
DEVICE_TO_DEVICE
.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
oderis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Beispiele
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
reduce
Semantik
Wendet eine Reduktionsfunktion body
auf inputs
und init_values
entlang der dimensions
an und erzeugt results
-Tensoren.
Die Reihenfolge der Reduzierungen ist implementierungsdefiniert. Das bedeutet, dass body
und init_values
ein Monoid bilden müssen, um zu garantieren, dass der Vorgang für alle Eingaben und Implementierungen dieselben Ergebnisse liefert. Diese Bedingung gilt jedoch nicht für viele gängige Reduzierungen. Beispielsweise bilden die Gleitkommazahlen für body
und Null für init_values
kein Monoid, da das Addieren von Gleitkommazahlen nicht assoziativ ist.
Formeller gesagt, results...[j0, ..., jR-1] = reduce(input_slices_converted)
, wobei:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, wobei:
beidimensions
eingefügt wird.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
für eine binäre Baumstrukturschedule
, wobei Folgendes gilt:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
ist ein implementierungsdefinierter vollständiger binärer Baum, dessen Durchlauf in einer bestimmten Reihenfolge aus folgenden Elementen besteht:input_slices_converted...[index]
-Werte für alleindex
inindex_space(input_slices_converted)
in aufsteigender lexikografischer Reihenfolge vonindex
.- Wird mit einer implementierungsdefinierten Menge von
init_values_converted
an implementierungsdefinierten Positionen eingefügt.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1–C4), (C6), (C7) |
(I2) | init_values |
variadische Anzahl von 0-dimensionalen Tensoren oder quantisierten Tensoren pro Tensor | (C2), (C3) |
(I3) | dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C4), (C5), (C7) |
(I4) | body |
Funktion | (C6) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C3), (C7), (C8) |
Einschränkung
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
hat den Typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, wobeiis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, außer dass die Dimensionsgrößen voninputs...
, diedimensions
entsprechen, nicht enthalten sind. - (C8)
element_type(results[i]) = Ei
für allei
in[0,N)
.
Beispiele
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantik
Führt die elementweise Umwandlung von operand
in einen anderen Gleitkommatyp, der exponent_bits
und mantissa_bits
verwendet, sowie zurück zum ursprünglichen Gleitkommatyp durch und erzeugt einen output
-Tensor.
Formeller:
- Die Mantissen-Bits des ursprünglichen Werts werden aktualisiert, um den ursprünglichen Wert auf den nächsten Wert zu runden, der mit
mantissa_bits
unter Verwendung derroundToIntegralTiesToEven
-Semantik dargestellt werden kann. - Wenn
mantissa_bits
dann kleiner als die Anzahl der Mantissen-Bits des ursprünglichen Werts ist, werden die Mantissen-Bits aufmantissa_bits
gekürzt. - Wenn dann die Exponentenbits des Zwischenergebnisses nicht in den von
exponent_bits
bereitgestellten Bereich passen, läuft das Zwischenergebnis mit dem ursprünglichen Vorzeichen ins Unendlichkeit über oder geht mit dem ursprünglichen Vorzeichen auf null über. - Führt für quantisierte Typen den Befehl
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
(I2) | exponent_bits |
Konstante vom Typ si32 |
(C2) |
(I3) | mantissa_bits |
Konstante vom Typ si32 |
(C3) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
output |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Beispiele
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantik
Führt in jeder Prozessgruppe im StableHLO-Prozessraster die Reduktion mit computations
über die Werte des Tensors operand
der einzelnen Prozesse durch, teilt das Reduktionsergebnis entlang von scatter_dimension
in Teile auf und verteilt die Teile auf die Prozesse, um die result
zu erzeugen.
Der Vorgang teilt das StableHLO-Prozessraster in process_groups
auf, die so definiert ist:
cross_replica(replica_groups)
wennchannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
wennchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
wennchannel_id > 0 and use_global_device_ids = true
.
Danach geschieht in jedem process_group
Folgendes:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
für allesender
inprocess_group
, wobeireceiver_index = process_group.index(receiver)
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
Konstante vom Typ si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
2-dimensionale Tensorkonstante vom Typ si64 |
(C3–C5) |
(I4) | channel_id |
Konstante vom Typ si64 |
(C6) |
(I5) | use_global_device_ids |
Konstante vom Typ i1 |
(C6) |
(I6) | computation |
Funktion | (C7) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C8–C9) |
Einschränkung
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
ist so definiert:num_replicas
, wenncross_replica
verwendet wird.num_replicas
, wenncross_replica_and_partition
verwendet wird.num_processes
, wennflattened_ids
verwendet wird.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Wenn
use_global_device_ids = true
, dannchannel_id > 0
. - (C7)
computation
hat den Typ(tensor<E>, tensor<E>) -> (tensor<E>)
, wobeiis_promotable(element_type(operand), E)
ist. - (C8)
shape(result) = shape(operand)
außer:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Beispiele
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantik
Wendet die Reduzierungsfunktion body
auf die Fenster von inputs
und init_values
an und erzeugt results
.
Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in results...
aus inputs...
berechnet werden.
Formaler gilt für results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(siehe reduzieren). Dabei gilt:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
variadische Anzahl von 0-dimensionalen Tensoren oder quantisierten Tensoren pro Tensor | (C1), (C13) |
(I3) | window_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Eindimensionale Tensorkonstante vom Typ si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Eindimensionale Tensorkonstante vom Typ si64 |
(C10), (C11), (C15) |
(I7) | padding |
2-dimensionale Tensorkonstante vom Typ si64 |
(C12), (C15) |
(I8) | body |
Funktion | (C13) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1), (C14–C16) |
Einschränkung
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
hat den Typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, wobeiis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
, wobei:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
für allei
in[0,N)
.
Beispiele
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<[4, 1]> : tensor<2xi64>,
base_dilations = dense<[2, 1]> : tensor<2xi64>,
window_dilations = dense<[3, 1]> : tensor<2xi64>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
Rest
Semantik
Führt den elementweisen Rest der Tensoren lhs
und Divisor rhs
aus und erzeugt einen result
-Tensor.
Formaler wird das Vorzeichen des Ergebnisses vom Dividenden entnommen, und der absolute Wert des Ergebnisses ist immer kleiner als der absolute Wert des Divisors.
Der Rest wird als lhs - d * rhs
berechnet, wobei d
durch Folgendes gegeben ist:
- Für Ganzzahlen:
stablehlo.divide(lhs, rhs)
. - Für Gleitkommazahlen:
division(lhs, rhs)
aus IEEE-754 mit RundungsattributroundTowardZero
. - Für komplexe Zahlen: TBD (#997).
- Für quantisierte Typen:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Bei Gleitkomma-Elementtypen steht dieser Vorgang im Gegensatz zum remainder
-Vorgang aus der IEEE-754-Spezifikation, bei dem d
ein ganzzahliger Wert ist, der dem exakten Wert von lhs/rhs
am nächsten ist, aber gleichbedeutend mit einem geraden Wert ist.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantik
Erzeugt replica_id
des aktuellen Prozesses.
Ausgaben
Name | Typ |
---|---|
result |
0-dimensionaler Tensor vom Typ ui32 |
Beispiele
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
Form ändern
Semantik
Führt die Umformung des operand
-Tensors in einen result
-Tensor durch. Konzeptionell geht es darum, dieselbe kanonische Darstellung beizubehalten, aber möglicherweise die Form zu ändern, z.B. von tensor<2x3xf32>
in tensor<3x2xf32>
oder tensor<6xf32>
.
Formell weiter ist result[result_index] = operand[operand_index]
, wobei result_index
und operand_index
dieselbe Position in der lexikografischen Reihenfolge von index_space(result)
und index_space(operand)
haben.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1–C3) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1–C3) |
Einschränkung
- (C1)
element_type(result)
wird gegeben durch:element_type(operand)
, wenn!is_per_axis_quantized(operand)
.element_type(operand)
, allerdings können sichquantization_dimension(operand)
undquantization_dimension(result)
unterscheiden.
- (C2)
size(operand) = size(result)
. - (C3) Wenn
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Beispiele
// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
reverse
Semantik
Kehrt die Reihenfolge der Elemente in operand
entlang der angegebenen dimensions
um und erzeugt einen result
-Tensor. Formaler gesagt, result[result_index] = operand[operand_index]
, wobei:
operand_index[d] = dim(result, d) - result_index[d] - 1
wennd
indimensions
ist.- Andernfalls
operand_index[d] = result_index[d]
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C3) |
(I2) | dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C3) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C3) |
Einschränkung
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Beispiele
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
RNG
Semantik
Erzeugt Zufallszahlen mit dem rng_distribution
-Algorithmus und erzeugt einen result
-Tensor der Form shape
.
Bei rng_distribution = UNIFORM
werden die Zufallszahlen entsprechend der gleichmäßigen Verteilung über das Intervall [a, b)
generiert. Wenn a >= b
, ist das Verhalten nicht definiert.
Bei rng_distribution = NORMAL
werden die Zufallszahlen entsprechend der Normalverteilung mit einem Mittelwert = a
und der Standardabweichung = b
generiert.
Wenn b < 0
, ist das Verhalten nicht definiert.
Die genaue Generierung von Zufallszahlen ist von der Implementierung definiert. Sie können beispielsweise deterministisch sein oder nicht, und sie können einen verborgenen Status verwenden.
In Gesprächen mit vielen Stakeholdern wurde festgestellt, dass diese Operation faktisch verworfen wurde. Daher planen wir, sie in Zukunft zu entfernen (#597).
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | a |
0-dimensionaler Tensor vom Typ „Ganzzahl“, „Boolesch“ oder „Gleitkomma“ | (C1), (C2) |
(I2) | b |
0-dimensionaler Tensor vom Typ „Ganzzahl“, „Boolesch“ oder „Gleitkomma“ | (C1), (C2) |
(I3) | shape |
Eindimensionale Tensorkonstante vom Typ si64 |
(C3) |
(I4) | rng_distribution |
Aufzählung von UNIFORM und NORMAL |
(C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Ganzzahl-, booleschen oder Gleitkommatyps | (C1–C3) |
Einschränkung
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Wenn
rng_distribution = NORMAL
, dannis_float(a)
. - (C3)
shape(result) = shape
.
Beispiele
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantik
Gibt ein output
mit einheitlichen Zufallsbits und einem aktualisierten Ausgabestatus output_state
unter Verwendung des Pseudozufallszahlengenerator-Algorithmus rng_algorithm
bei einem Ausgangszustand initial_state
zurück. Es wird garantiert, dass die Ausgabe eine deterministische Funktion von initial_state
ist. Es ist jedoch nicht garantiert, dass sie zwischen Implementierungen deterministisch ist.
rng_algorithm
ist einer der folgenden Werte:
DEFAULT
: Implementierungsdefinierter Algorithmus.THREE_FRY
: Implementierungsdefinierte Variante des Threefry-Algorithmus.*PHILOX
: Implementierungsdefinierte Variante des Philox-Algorithmus.*
* Siehe Salmon et al. SC 2011. Parallele Zufallszahlen – so einfach wie 1, 2, 3.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | rng_algorithm |
Aufzählung von DEFAULT , THREE_FRY und PHILOX |
(C2) |
(I2) | initial_state |
Eindimensionaler Tensor vom Typ ui64 |
(C1), (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
output_state |
Eindimensionaler Tensor vom Typ ui64 |
(C1) |
output |
Tensor des Ganzzahl- oder Gleitkommatyps |
Einschränkung
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
ist so definiert:- Implementierung definiert, wenn
rng_algorithm = DEFAULT
. 2
, wennrng_algorithm = THREE_FRY
.2
oder3
, wennrng_algorithm = PHILOX
.
- Implementierung definiert, wenn
Beispiele
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantik
Führt eine elementweise Rundung zur nächsten Ganzzahl auf dem Tensor operand
durch und hebt die Verknüpfungen von Null auf, um einen result
-Tensor zu erzeugen. Implementiert den Vorgang roundToIntegralTiesToAway
aus der IEEE-754-Spezifikation. Führt für quantisierte Typen den Befehl dequantize_op_quantize(round_nearest_afz, operand, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantik
Führt eine elementweise Rundung zur nächsten Ganzzahl auf und löst die Verbindungen zur geraden Ganzzahl auf dem Tensor operand
aus und erzeugt einen result
-Tensor. Implementiert den Vorgang roundToIntegralTiesToEven
aus der IEEE-754-Spezifikation. Führt für quantisierte Typen den Befehl dequantize_op_quantize(round_nearest_even, operand, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
RSS
Semantik
Führt eine elementweise reziproke Quadratwurzeloperation auf dem operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
rSqrt
aus IEEE-754. - Bei komplexen Zahlen: komplexe reziproke Quadratwurzel.
- Für quantisierte Typen:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter
Semantik
Erzeugt results
-Tensoren, die den Tensoren inputs
entsprechen, mit der Ausnahme, dass mehrere von scatter_indices
angegebene Segmente mithilfe von update_computation
mit den Werten updates
aktualisiert werden.
Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in updates...
den Elementen in results...
zugeordnet werden. Im Diagramm werden einige Beispiele für updates...
-Indizes ausgewählt und detailliert erläutert, welchen results...
-Indizes sie entsprechen.
Formell für alle update_index
in index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
ist so definiert:scatter_indices[si0, ..., :, ..., siN]
, wobeisi
einzelne Elemente inupdate_scatter_index
sind und:
in den Indexindex_vector_dim
eingefügt wird, wennindex_vector_dim
<rank(scatter_indices)
ist.- Andernfalls
[scatter_indices[update_scatter_index]]
.
- Für
d_input
inaxes(inputs[0])
:full_start_index[d_input] = start_index[d_start]
, wennd_input = scatter_dims_to_operand_dims[d_start]
.- Andernfalls
full_start_index[d_input] = 0
.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, wobeiwi
einzelne Elemente inupdate_window_index
sind und0
an Indizes voninserted_window_dims
eingefügt wird.result_index = full_start_index + full_window_index
.
Daher gilt results = exec(schedule, inputs)
, wobei Folgendes gilt:
schedule
ist eine implementierungsdefinierte Permutation vonindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
, wobei:- Wenn
result_index
innerhalb der Grenzen fürshape(results...)
liegt updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
ist eine Kopie vonresults
, wobeiresults...[result_index]
aufupdated_values...
festgelegt ist.- Andernfalls:
updated_results = results
.
- Wenn
exec([], results) = results
.
Wenn indices_are_sorted
den Wert true
hat, kann die Implementierung davon ausgehen, dass scatter_indices
in Bezug auf scatter_dims_to_operand_dims
sortiert sind. Andernfalls ist das Verhalten nicht definiert. Formell: Für alle i1 < i2
aus indices(result)
gilt: full_start_index(i1)
<= full_start_index(i2)
.
Wenn unique_indices
den Wert true
hat, kann die Implementierung davon ausgehen, dass alle result_index
-Indizes, auf die verstreut sind, eindeutig sind. Wenn unique_indices
den Wert true
hat, die Indizes jedoch nicht eindeutig sind, ist das Verhalten nicht definiert.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1), (C2), (C4-C6), (C10), (C13), (C15-C16) |
(I2) | scatter_indices |
Tensor des Ganzzahltyps | (C4), (C11), (C14) |
(I3) | updates |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C3–C6), (C8) |
(I4) | update_window_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C7), (C8) |
(I5) | inserted_window_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C9), (C10) |
(I6) | scatter_dims_to_operand_dims |
Eindimensionale Tensorkonstante vom Typ si64 |
(C11–C13) |
(I7) | index_vector_dim |
Konstante vom Typ si64 |
(C4), (C11), (C14) |
(I8) | indices_are_sorted |
Konstante vom Typ i1 |
|
(I9) | unique_indices |
Konstante vom Typ i1 |
|
(I10) | update_computation |
Funktion | (C15) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C15–C17) |
Einschränkung
- (C1)
same(shape(inputs...))
. - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
. - (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
, wobei:update_scatter_dim_sizes = shape(scatter_indices)
, allerdings ist die Dimensionsgröße vonscatter_indices
, dieindex_vector_dim
entspricht, nicht enthalten.update_window_dim_sizes <= shape(inputs[0])
, außer dass die Dimensionsgrößen ininputs[0]
, dieinserted_window_dims
entsprechen, nicht enthalten sind.combine
setztupdate_scatter_dim_sizes
an den Achsen, dieupdate_scatter_dims
entsprechen, undupdate_window_dim_sizes
an den Achsen, dieupdate_window_dims
entsprechen.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(inserted_window_dims) and is_sorted(update_window_dims)
. - (C10)
0 <= inserted_window_dims < rank(inputs[0])
. - (C11)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C12)
is_unique(scatter_dims_to_operand_dims)
. - (C13)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C14)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C15)
update_computation
hat den Typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, wobeiis_promotable(element_type(inputs[i]), Ei)
ist. - (C16)
shape(inputs...) = shape(results...)
. - (C17)
element_type(results[i]) = Ei
für allei
in[0,N)
.
Beispiele
// %input: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
// [[1, 2], [5, 6], [7, 8], [7, 8]],
// [[10, 11], [12, 13], [14, 15], [16, 17]],
// [[18, 19], [20, 21], [21, 22], [23, 24]]
// ]
auswählen
Semantik
Erzeugt einen result
-Tensor, bei dem jedes Element aus dem on_true
- oder on_false
-Tensor basierend auf dem Wert des entsprechenden Elements von pred
ausgewählt wird.
Formeller: result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, wobei pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Führt für quantisierte Typen den Befehl dequantize_select_quantize(pred, on_true, on_false, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | pred |
Tensor vom Typ i1 |
(C1) |
(I2) | on_true |
Tensor oder quantisierter Tensor pro Tensor | (C1–C2) |
(I3) | on_false |
Tensor oder quantisierter Tensor pro Tensor | (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C2) |
Einschränkung
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Beispiele
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantik
Streut die Werte vom Tensor source
mit scatter
anhand des Ergebnisses von reduce_window
des Tensors input
unter Verwendung von select
und erzeugt einen result
-Tensor.
Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result
aus operand
und source
berechnet werden.
Formeller:
selected_values = reduce_window_without_init(...)
mit den folgenden Eingaben:- `inputs = [operand].
window_dimensions
,window_strides
undpadding
, die unverändert verwendet werden.base_dilations = windows_dilations = 1
.body
ist definiert als:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
wobei
E = element_type(operand)
undreduce_window_without_init
genau wiereduce_window
funktionieren, mit der Ausnahme, dass derschedule
der zugrunde liegendenreduce
(siehe reduzieren) keine Initialisierungswerte enthält. Derzeit ist nicht festgelegt, was geschieht, wenn das entsprechende Fenster keine Werte enthält (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
, wobei:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
, wennselected_values[source_index]
das Elementoperand
ausoperand_index
enthält.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1–C4), (C6), (C8–C11) |
(I2) | source |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C2) |
(I3) | init_value |
0-dimensionaler Tensor oder quantisierter Tensor pro Tensor | (C3) |
(I4) | window_dimensions |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C6), (C7) |
(I6) | padding |
2-dimensionale Tensorkonstante vom Typ si64 |
(C2), (C8) |
(I7) | select |
Funktion | (C9) |
(I8) | scatter |
Funktion | (C10) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C11–C12) |
Einschränkung
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
, wobei:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
hat den Typ(tensor<E>, tensor<E>) -> tensor<i1>
, wobeiE = element_type(operand)
ist. - (C10)
scatter
hat den Typ(tensor<E>, tensor<E>) -> tensor<E>
, wobeiis_promotable(element_type(operand), E)
ist. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Beispiele
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
Senden
Semantik
Sendet inputs
an einen Kanal channel_id
und erzeugt ein result
-Token.
Wenn is_host_transfer
den Wert true
hat, werden durch den Vorgang Daten an den Host übertragen. Andernfalls werden die Daten auf ein anderes Gerät übertragen. Was das bedeutet, ist
implementierungsdefiniert. Dieses Flag dupliziert die in channel_type
bereitgestellten Informationen. Daher planen wir, in Zukunft nur eines davon zu behalten (#666).
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierten Tensoren | |
(I2) | token |
token |
|
(I3) | channel_id |
Konstante vom Typ si64 |
|
(I4) | channel_type |
Aufzählung von DEVICE_TO_DEVICE und DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
Konstante vom Typ i1 |
(C1) |
Ausgaben
Name | Typ |
---|---|
result |
token |
Einschränkung
- (C1)
channel_type
ist so definiert:DEVICE_TO_HOST
, wennis_host_transfer = true
,- Andernfalls
DEVICE_TO_DEVICE
.
Beispiele
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantik
Führt eine elementweise Linksverschiebungsoperation auf dem lhs
-Tensor mit der Anzahl von rhs
Bits durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor des Ganzzahltyps | (C1) |
(I2) | rhs |
Tensor des Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Ganzzahltyps | (C1) |
Einschränkung
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantik
Führt eine elementweise arithmetische Rechtsverschiebungsoperation auf dem lhs
-Tensor mit der Anzahl von rhs
Bits durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor des Ganzzahltyps | (C1) |
(I2) | rhs |
Tensor des Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Ganzzahltyps | (C1) |
Einschränkung
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantik
Führt eine elementweise logische Rechtsverschiebungsoperation am lhs
-Tensor mit der Anzahl von rhs
Bits durch und erzeugt einen result
-Tensor.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor des Ganzzahltyps | (C1) |
(I2) | rhs |
Tensor des Ganzzahltyps | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Ganzzahltyps | (C1) |
Einschränkung
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
Signieren
Semantik
Gibt das Vorzeichen von operand
elementweise zurück und erzeugt einen result
-Tensor.
Formaler kann die Semantik für jedes Element x
mit der Python-Syntax so ausgedrückt werden:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Führt für quantisierte Typen den Befehl dequantize_op_quantize(sign, operand, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor einer vorzeichenbehafteten Ganzzahl, Gleitkommazahl oder eines komplexen Typs oder eines pro Tensor quantisierten Tensors | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor einer vorzeichenbehafteten Ganzzahl, Gleitkommazahl oder eines komplexen Typs oder eines pro Tensor quantisierten Tensors | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
Sinus
Semantik
Führt eine elementweise Sinusoperation für den operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
sin
aus IEEE-754. - Bei komplexen Zahlen: komplexer Sinus.
- Für quantisierte Typen:
dequantize_op_quantize(sine, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
Slice
Semantik
Extrahiert mithilfe von statisch berechneten Startindexen ein Slice aus dem operand
und erzeugt einen result
-Tensor. start_indices
enthält die Startindexe des Segments für jede Dimension, limit_indices
die Endindexe (ausschließlich) für das Segment für jede Dimension und strides
die Schritte für die einzelnen Dimensionen.
Formeller gesagt: result[result_index] = operand[operand_index]
, wobei operand_index = start_indices + result_index * strides
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor pro Tensor | (C1–C3), (C5) |
(I2) | start_indices |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C3), (C5) |
(I4) | strides |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2), (C4) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor pro Tensor | (C1), (C5) |
Einschränkung
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Beispiele
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = dense<[1, 2]> : tensor<2xi64>,
limit_indices = dense<[3, 4]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
sort
Semantik
Sortiert eindimensionale Segmente von inputs
entlang der Dimension dimension
nach einem comparator
zusammen und erstellt results
.
Im Gegensatz zu ähnlichen Eingaben in anderen Vorgängen erlaubt dimension
mit der unten beschriebenen Semantik negative Werte. Zukünftig kann dies aus Konsistenzgründen nicht zulässig sein (#1377).
Wenn is_stable
auf „true“ gesetzt ist, ist die Sortierung stabil, d. h., die relative Reihenfolge der Elemente, die vom Vergleichsoperator als gleich betrachtet werden, wird beibehalten. Wenn eine einzelne Eingabe vorhanden ist, werden die beiden Elemente e1
und e2
vom Vergleichsoperator als gleichbedeutend angesehen, wenn und nur dann comparator(e1, e2) = comparator(e2, e1) = false
. In der folgenden Formalisierung wird gezeigt, wie sich dies auf mehrere Eingaben verallgemeinern lässt.
Formell für alle result_index
in index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, wobeiriN
einzelne Elemente inresult_index
sind und:
beiadjusted_dimension
eingefügt wird.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- Dabei sortiert
sort
ein eindimensionales Segment in nicht absteigender Reihenfolge und erwartet, dasscomparator_together
true
zurückgibt, wenn das Argument auf der linken Seite kleiner als das zweite Argument ist. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | inputs |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C1–C5) |
(I2) | dimension |
Konstante vom Typ si64 |
(C4) |
(I3) | is_stable |
Konstante vom Typ i1 |
|
(I4) | comparator |
Funktion | (C5) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor | (C2), (C3) |
Einschränkung
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, wobeiR = rank(inputs[0])
. - (C5)
comparator
hat den Typ(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, wobeiEi = element_type(inputs[i])
ist.
Beispiele
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantik
Führt eine elementweise Quadratwurzeloperation auf dem operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
squareRoot
aus IEEE-754. - Bei komplexen Zahlen: komplexe Quadratwurzel.
- Für quantisierte Typen:
dequantize_op_quantize(sqrt, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semantik
Führt die elementweise Subtraktion der beiden Tensoren lhs
und rhs
durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Ganzzahlen: Subtraktion von Ganzzahlen.
- Für Gleitkommazahlen:
subtraction
aus IEEE-754. - Bei komplexen Zahlen: komplexe Subtraktion.
- Für quantisierte Typen:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
(I2) | rhs |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Beispiele
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
Tanh
Semantik
Führt eine elementweise hyperbolische Tangensoperation am operand
-Tensor durch und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Für Gleitkommazahlen:
tanh
aus IEEE-754. - Bei komplexen Zahlen: komplexer hyperbolischer Tangens
- Für quantisierte Typen:
dequantize_op_quantize(tanh, operand, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_type(operand) = baseline_type(result)
.
Beispiele
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
Transponieren
Semantik
Permutet die Dimensionen des Tensors operand
mit permutation
und erzeugt einen result
-Tensor. Formeller gesagt: result[result_index] = operand[operand_index]
, wobei result_index[d] = operand_index[permutation[d]]
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor oder quantisierter Tensor | (C1–C4) |
(I2) | permutation |
Eindimensionale Tensorkonstante vom Typ si64 |
(C2–C4) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor oder quantisierter Tensor | (C1), (C3–C4) |
Einschränkung
- (C1)
element_type(result)
wird gegeben durch:element_type(operand)
, wenn!is_per_axis_quantized(operand)
.element_type(operand)
, allerdings können sichquantization_dimension(operand)
undquantization_dimension(result)
unterscheiden.
- (C2)
permutation
ist eine Permutation vonrange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Wenn
is_per_axis_quantized(result)
, dannquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Beispiele
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantik
Löst Reihen von linearen Gleichungssystemen mit unteren oder oberen Dreiecksmatrizen.
Formell ist bei a
und b
result[i0, ..., iR-3, :, :]
die Lösung für op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
, wenn left_side
true
oder x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
ist, wenn left_side
false
ist. Dabei wird die Variable x
ermittelt, wobei op(a)
durch transpose_a
bestimmt wird. Dabei kann es sich um einen der folgenden Werte handeln:
NO_TRANSPOSE
: Vorgang mita
unverändert ausführen.TRANSPOSE
: Vorgang beim Transponieren vona
ausführen.ADJOINT
: Eine Operation beim konjugierten Transponieren vona
durchführen.
Eingabedaten werden nur aus dem unteren Dreieck von a
gelesen, wenn lower
gleich true
oder dem oberen Dreieck von a
ist. Andernfalls werden die Eingabedaten gelesen. Die Ausgabedaten werden im selben Dreieck zurückgegeben. Die Werte im anderen Dreieck sind implementierungsdefiniert.
Wenn unit_diagonal
„true“ ist, kann die Implementierung davon ausgehen, dass die diagonalen Elemente von a
gleich 1 sind. Andernfalls ist das Verhalten nicht definiert.
Führt für quantisierte Typen den Befehl dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
aus.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | a |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1–C3) |
(I2) | b |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1–C4) |
(I3) | left_side |
Konstante vom Typ i1 |
(C3) |
(I4) | lower |
Konstante vom Typ i1 |
|
(I5) | unit_diagonal |
Konstante vom Typ i1 |
|
(I6) | transpose_a |
Aufzählung von NO_TRANSPOSE , TRANSPOSE und ADJOINT |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor | (C1) |
Einschränkung
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) Die Beziehung zwischen
shape(a)
undshape(b)
ist so definiert:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Beispiele
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semantik
Erzeugt ein result
-Tupel aus den Werten val
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | val |
variadische Anzahl von Werten | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
tuple | (C1) |
Einschränkung
- (C1)
result
hat den Typtuple<E0, ..., EN-1>
, wobeiEi = type(val[i])
.
Beispiele
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantik
Führt die elementweise Konvertierung des quantisierten Tensors operand
in einen Gleitkomma-Tensor result
gemäß den durch den Typ operand
definierten Quantisierungsparametern durch.
Formeller gesagt: result = dequantize(operand)
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
quantisierter Tensor | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor des Gleitkommatyps | (C1), (C2) |
Einschränkung
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Beispiele
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantik
Führt eine elementweise Umwandlung des Gleitkommatensors oder des quantisierten Tensors operand
in einen quantisierten Tensor result
gemäß den durch den Typ result
definierten Quantisierungsparametern durch.
Formell
- Wenn
is_float(operand)
:result = quantize(operand, type(result))
.
- Wenn
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
Tensor des Gleitkomma- oder quantisierten Typs | (C1), (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
quantisierter Tensor | (C1), (C2) |
Einschränkung
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Beispiele
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
während
Semantik
Erzeugt die Ausgabe, indem die Funktion body
null oder mehr Mal ausgeführt wird, während die Funktion cond
den Wert true
ausgibt. Formell kann die Semantik mit der Python-Syntax folgendermaßen ausgedrückt werden:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Das Verhalten einer Endlosschleife ist noch nicht festgelegt (#383).
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | operand |
variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens | (C1–C3) |
(I2) | cond |
Funktion | (C1) |
(I3) | body |
Funktion | (C2) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
results |
variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens | (C3) |
Einschränkung
- (C1)
cond
hat den Typ(T0, ..., TN-1) -> tensor<i1>
, wobeiTi = type(operand[i])
. - (C2)
body
hat den Typ(T0, ..., TN-1) -> (T0, ..., TN-1)
, wobeiTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Beispiele
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
Xor
Semantik
Führt ein elementweises XOR der beiden Tensoren lhs
und rhs
aus und erzeugt einen result
-Tensor. Folgendes wird je nach Elementtyp ausgeführt:
- Bei booleschen Werten: logisches XOR.
- Für Ganzzahlen: bitweises XOR.
Eingaben
Label | Name | Typ | Einschränkung |
---|---|---|---|
(I1) | lhs |
Tensor vom Typ „Boolescher Wert“ oder „Integer“ | (C1) |
(I2) | rhs |
Tensor vom Typ „Boolescher Wert“ oder „Integer“ | (C1) |
Ausgaben
Name | Typ | Einschränkung |
---|---|---|
result |
Tensor vom Typ „Boolescher Wert“ oder „Integer“ | (C1) |
Einschränkung
- (C1)
type(lhs) = type(rhs) = type(result)
.
Beispiele
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Umsetzung
Sequenzielle Ausführung
Ein StableHLO-Programm wird ausgeführt, indem Eingabewerte für die main
-Funktion bereitgestellt und Ausgabewerte berechnet werden. Ausgabewerte einer Funktion werden durch Ausführen des Graphen der Vorgänge berechnet, die im entsprechenden return
-Vorgang verwurzelt sind.
Die Ausführungsreihenfolge ist implementierungsdefiniert, solange sie am Datenfluss ausgerichtet ist, d.h. wenn Vorgänge vor ihrer Verwendung ausgeführt werden. In StableHLO verbrauchen alle Nebenwirkungen ein Token und erzeugen ein Token (mehrere Tokens können über after_all
zu einem Token zusammengefasst werden), sodass die Ausführungsreihenfolge von Nebeneffekten ebenfalls auf Dataflow abgestimmt ist. Mögliche Ausführungsreihenfolgen für das obige Beispielprogramm sind %0
→ %1
→ %2
→ %3
→ %4
→ return
oder %3
→ %0
→ %1
→ %2
→ %4
→ return
.
Formal ist ein StableHLO-Prozess eine Kombination aus: 1) einem StableHLO-Programm, 2) Vorgangsstatus (noch nicht ausgeführt, bereits ausgeführt) und 3) Zwischenwerten, an denen der Prozess arbeitet.
Der Prozess beginnt mit Eingabewerten für die Funktion main
, durchläuft die Grafik der Vorgänge, die Vorgangsstatus und Zwischenwerte aktualisiert, und endet mit Ausgabewerten. Die weitere Formalisierung steht noch nicht fest (#484).
Parallele Ausführung
StableHLO-Programme können parallel ausgeführt werden, organisiert in einem 2D-Prozessraster aus num_replicas
nach num_partitions
, die beide den Typ ui32
haben.
Im StableHLO-Prozessraster werden num_replicas * num_partitions
der StableHLO-Prozesse gleichzeitig ausgeführt. Jeder Prozess hat eine eindeutige process_id = (replica_id, partition_id)
, wobei replica_id
in replica_ids = range(num_replicas)
und partition_id
in partition_ids = range(num_partitions)
den Typ ui32
haben.
Die Größe des Prozessrasters ist für jedes Programm statisch bekannt (in Zukunft planen wir, es zu einem expliziten Teil der StableHLO-Programme #650 zu machen) und die Position innerhalb des Prozessrasters ist für jeden Prozess statisch bekannt. Jeder Prozess hat über die Operationen replica_id
und partition_id
Zugriff auf seine Position im Prozessraster.
Innerhalb des Prozessrasters können die Programme gleich sein (im Stil „Einzelnes Programm, mehrere Daten“), unterschiedlich sein (im Stil „Mehrere Programme, mehrere Daten“) oder etwas dazwischen. Für die Zukunft planen wir, Unterstützung für andere Redewendungen zur Definition paralleler StableHLO-Programme anzubieten, einschließlich GSPMD (#619).
Innerhalb des Prozessrasters sind die Prozesse größtenteils unabhängig voneinander – sie haben unterschiedliche Betriebsstatus sowie separate Eingabe-/Zwischen-/Ausgabewerte und die meisten Vorgänge werden getrennt zwischen Prozessen ausgeführt, mit Ausnahme einer kleinen Anzahl von kollektiven Vorgängen, die unten beschrieben werden.
Da die Ausführung der meisten Vorgänge nur Werte aus demselben Prozess verwendet, ist es normalerweise eindeutig, mit ihren Namen auf diese Werte zu verweisen.
Bei der Beschreibung der Semantik kollektiver Vorgänge ist dies jedoch nicht ausreichend. Daher kann die Notation name@process_id
in einem bestimmten Prozess auf den Wert name
verweisen. Aus dieser Sicht kann nicht qualifizierte name
als Abkürzung für name@(replica_id(), partition_id())
betrachtet werden.
Die prozessübergreifende Ausführungsreihenfolge ist von der Implementierung definiert, mit Ausnahme der Synchronisierung, die durch die Punkt-zu-Punkt-Kommunikation und kollektive Vorgänge wie unten beschrieben eingeführt wird.
Punkt-zu-Punkt-Kommunikation
Stabile HLO-Prozesse können über StableHLO-Kanäle miteinander kommunizieren. Ein Kanal wird durch eine positive ID vom Typ si64
dargestellt. Über verschiedene Vorgänge ist es möglich, Werte an Kanäle zu senden und von Kanälen zu empfangen.
Eine weitere Formulierung, z.B. woher diese Kanal-IDs stammen, wie Prozesse von Programmen erkannt werden und welche Art von Synchronisierung durch sie eingeleitet wird, ist noch nicht festgelegt (#484).
Streamingkommunikation
Jeder StableHLO-Prozess hat Zugriff auf zwei Streaming-Schnittstellen:
- InFeed, aus dem gelesen werden kann.
- Outfeed (Ausgang), in den geschrieben werden kann.
Im Gegensatz zu Channels, die für die Kommunikation zwischen Prozessen verwendet werden und daher Prozesse an beiden Enden enthalten, ist bei Infeeds und Outfeeds die jeweils andere Implementierung definiert.
Eine weitere Formalisierung, z.B. wie sich die Streamingkommunikation auf die Ausführungsreihenfolge auswirkt und welche Art von Synchronisierung damit eingeleitet wird, steht noch nicht fest (#484).
Kollektive Vorgänge
In StableHLO gibt es sechs gemeinsame Vorgänge: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
und reduce_scatter
. Alle diese Vorgänge teilen die Prozesse im StableHLO-Prozessraster in StableHLO-Prozessgruppen auf und führen eine gemeinsame Berechnung innerhalb jeder Prozessgruppe aus, unabhängig von anderen Prozessgruppen.
Innerhalb jeder Prozessgruppe können kollektive Vorgänge eine Synchronisierungsbarriere darstellen. Eine weitere Formulierung, die z.B. erläutert wird, wann genau diese Synchronisierung stattfindet, wie genau die Prozesse zu dieser Barriere gelangen und was passiert, wenn dies nicht der Fall ist, steht noch nicht fest (#484).
Wenn die Prozessgruppe partitionübergreifende Kommunikation umfasst, d.h., wenn sich in der Prozessgruppe Prozesse mit unterschiedlichen Partitions-IDs befinden, benötigt die Ausführung des gemeinsamen Vorgangs einen Kanal und der kollektive Vorgang muss eine positive channel_id
vom Typ si64
bereitstellen. Für die replizübergreifende Kommunikation sind keine Kanäle erforderlich.
Die von den kollektiven Vorgängen durchgeführten Berechnungen sind für einzelne Operationen spezifisch und werden in den einzelnen Vorgangsabschnitten oben beschrieben. Die Strategien, nach denen das Prozessraster in Prozessgruppen aufgeteilt wird, sind jedoch auf diese Vorgänge verteilt und werden in diesem Abschnitt beschrieben. Formal unterstützt StableHLO die folgenden vier Strategien.
cross_replica
Nur die replizübergreifende Kommunikation findet innerhalb jeder Prozessgruppe statt. Diese Strategie verwendet replica_groups
– eine Liste von Listen von Replikat-IDs – und berechnet ein kartesisches Produkt aus replica_groups
nach partition_ids
. replica_groups
muss eindeutige Elemente haben und alle replica_ids
abdecken. Formeller mithilfe der Python-Syntax:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Für replica_groups = [[0, 1], [2, 3]]
und num_partitions = 2
erzeugt cross_replica
beispielsweise [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Nur die partitionübergreifende Kommunikation innerhalb jeder Prozessgruppe. Diese Strategie verwendet partition_groups
– eine Liste mit Listen von Partitions-IDs – und berechnet ein kartesisches Produkt aus partition_groups
nach replica_ids
.
partition_groups
muss eindeutige Elemente haben und alle partition_ids
abdecken.
Formell unter Verwendung der Python-Syntax:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Für partition_groups = [[0, 1]]
und num_replicas = 4
erzeugt cross_partition
beispielsweise [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Sowohl die Replikat- als auch die partitionübergreifende Kommunikation kann innerhalb jeder Prozessgruppe erfolgen. Diese Strategie verwendet replica_groups
(eine Liste mit Listen von Replikat-IDs) und berechnet kartesische Produkte jeder replica_group
nach partition_ids
. replica_groups
muss eindeutige Elemente haben und alle replica_ids
abdecken. Formell unter Verwendung der Python-Syntax:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Für replica_groups = [[0, 1], [2, 3]]
und num_partitions = 2
erzeugt cross_replica_and_partition
beispielsweise [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Diese Strategie nimmt flattened_id_groups
– eine Liste von Listen mit vereinfachten Prozess-IDs im Format replica_id * num_partitions + partition_id
– und wandelt sie in Prozess-IDs um. flattened_id_groups
muss eindeutige Elemente haben und alle process_ids
abdecken. Formell unter Verwendung der Python-Syntax:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Für flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
und num_partitions = 2
erzeugt flattened_ids
beispielsweise [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Genauigkeit
Derzeit bietet StableHLO keine Garantien für die numerische Genauigkeit. Dies kann sich jedoch in Zukunft ändern (#1156).
Fehler
StableHLO-Programme werden durch eine umfassende Reihe von Einschränkungen für einzelne Vorgänge validiert, wodurch viele Fehlerklassen vor der Laufzeit ausgeschlossen werden. Fehlerbedingungen sind jedoch weiterhin möglich, z.B. durch Ganzzahlüberläufe, Zugriffe außerhalb des Bereichs usw. Sofern nicht ausdrücklich anders angegeben, führen alle diese Fehler zu einem implementierungsdefinierten Verhalten, das sich in Zukunft jedoch ändern kann (#1157).
Als Ausnahme von dieser Regel haben Gleitkommaausnahmen in StableHLO-Programmen ein klar definiertes Verhalten. Vorgänge, die zu Ausnahmen führen, die durch den IEEE-754-Standard definiert sind (ungültige Vorgänge, Division durch Null, Überlauf, Unterlauf oder ungenaue Ausnahmen), erzeugen Standardergebnisse (wie im Standard definiert) und werden fortgesetzt, ohne das entsprechende Status-Flag zu setzen, ähnlich wie bei der raiseNoFlag
-Ausnahmebehandlung des Standards. Ausnahmen für nicht standardmäßige Operationen (z.B. komplexe arithmetische und bestimmte transzendentale Funktionen) sind implementierungsdefiniert.
Notation
Zur Beschreibung der Syntax wird in diesem Dokument die modifizierte ISO-Variante der EBNF-Syntax (ISO/IEC 14977:1996, Wikipedia) mit zwei Änderungen verwendet: 1) Regeln werden mit ::=
statt mit =
definiert.
2) Die Verkettung wird durch Gegenüberstellung anstelle von ,
ausgedrückt.
Zur Beschreibung der Semantik (d.h. in den Abschnitten „Typen“, „Konstanten“ und „Ops“) verwenden wir Formeln, die auf Python-Syntax basieren und mit Unterstützung für das prägnante Ausdrucken von Arrayvorgängen wie unten beschrieben unterstützt werden. Dies funktioniert gut bei kleinen Code-Snippets. Wenn jedoch größere Code-Snippets erforderlich sind, verwenden wir in seltenen Fällen eine einfache Python-Syntax, die immer explizit eingeführt wird.
Formeln
Sehen wir uns anhand eines Beispiels aus der Spezifikation dot_general
an, wie Formeln funktionieren. Eine der Einschränkungen für diesen Vorgang sieht so aus: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Die in dieser Formel verwendeten Namen stammen aus zwei Quellen: 1) globale Funktionen, d.h. dim
, 2) Mitgliedsdefinitionen des entsprechenden Programmelements, d.h. Eingaben lhs
, lhs_batching_dimensions
, rhs
und rhs_batching_dimensions
, die im Abschnitt „Eingaben“ von dot_general
definiert sind.
Wie oben erwähnt, basiert die Syntax dieser Formel auf Python und bietet einige Erweiterungen, die auf Genauigkeit ausgerichtet sind. Um die Formel verständlich zu machen, wandeln wir sie in einfache Python-Syntax um.
A) In diesen Formeln verwenden wir =
, um Gleichheit darzustellen. Der erste Schritt zum Abrufen der Python-Syntax besteht also darin, =
durch ==
zu ersetzen: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Diese Formeln unterstützen außerdem Ellipsen (...
), die skalare Ausdrücke in Tensorausdrücke umwandeln. Kurz gesagt bedeutet f(xs...)
ungefähr „für jede skalare x
im Tensor xs
, einen skalaren f(x)
zu berechnen und dann alle diese skalaren Ergebnisse zusammen als Tensorergebnis zurückzugeben“. In der einfachen Python-Syntax wird aus unserer Beispielformel [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Dank Ellipsen ist es oft möglich, die Arbeit auf der Ebene einzelner Skalare zu vermeiden. In einigen kniffligen Fällen kann jedoch eine halbinformelle Syntax mit niedrigerer Stufe wie in der start_indices[bi0, ..., :, ..., biN]
-Formel aus der gather
-Spezifikation verwendet werden. Im Sinne der Präzision bieten wir keinen exakten Formalismus für die Übersetzung einer solchen Syntax in einfaches Python, in der Hoffnung, dass sie von Fall zu Fall dennoch intuitiv verständlich ist.
Bitte teilen Sie uns mit, wenn bestimmte Formeln undurchsichtig erscheinen. Wir werden versuchen, sie zu verbessern.
Sie werden auch feststellen, dass Formeln mithilfe von Ellipsen alle Arten von Listen erweitern, einschließlich Tensoren, Listen von Tensoren (die z.B. aus einer variierenden Anzahl von Tensoren entstehen können) usw. Dies ist ein weiterer Bereich, in dem wir keine exakte Formalität angeben (z.B. Listen sind nicht einmal Teil des StableHLO-Typsystems) und basieren stattdessen auf intuitiver Verständlichkeit.
C) Das letzte nennenswerte Instrument, das wir verwenden, ist die implizite Übertragung. Das StableHLO-Opset unterstützt zwar kein implizites Broadcasting, die Formeln unterstützen dies aber auch im Rahmen der Präzision. Kurz gesagt: Wenn ein Skalar in einem Kontext verwendet wird, in dem ein Tensor erwartet wird, wird der Skalar in der erwarteten Form übertragen.
Um das Beispiel dot_general
fortzusetzen, gibt es eine weitere Einschränkung: 0 <= lhs_batching_dimensions < rank(lhs)
. Wie in der dot_general
-Spezifikation definiert, ist lhs_batching_dimensions
ein Tensor. Sowohl 0
als auch rank(lhs)
sind jedoch Skalare. Nach Anwendung des impliziten Broadcasts wird die Formel zu [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Bei Anwendung auf einen bestimmten dot_general
-Vorgang wird diese Formel als boolescher Tensor ausgewertet. Wenn Formeln als Einschränkungen verwendet werden, gilt die Einschränkung, ob die Formel entweder true
oder einen Tensor auswertet, der nur true
-Elemente hat.
Namen
In Formeln umfasst der lexikalische Geltungsbereich: 1) globale Funktionen, 2) Mitgliederdefinitionen,
3) lokale Definitionen. Die Liste der globalen Funktionen finden Sie unten. Die Liste der Elementdefinitionen hängt vom Programmelement ab, auf das die Notation angewendet wird:
- Bei Vorgängen umfassen Mitgliederdefinitionen Namen, die in den Abschnitten "Eingaben" und "Ausgaben" eingeführt wurden.
- Für alles andere enthalten Mitgliedsdefinitionen strukturelle Teile des Programmelements, die nach den entsprechenden EBNF-Nicht-Terminals benannt sind. In den meisten Fällen werden die Namen dieser Strukturteile durch Konvertierung der Namen der Nicht-Terminals in Snake Case (z. B.
IntegerLiteral
=>integer_literal
) abgerufen. Manchmal werden Namen jedoch während des Prozesses abgekürzt (z. B.QuantizationStorageType
=>storage_type
). In diesem Fall werden die Namen in den Vorgangsspezifikationen ausdrücklich ähnlich wie die Abschnitte „Inputs“ und „Outputs“ eingeführt. - Außerdem enthalten Mitgliedsdefinitionen immer
self
, um auf das entsprechende Programmelement zu verweisen.
Werte
Wenn Formeln ausgewertet werden, funktionieren sie mit den folgenden Wertetypen:
1) Value
(tatsächliche Werte, z.B. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
; sie kennen immer ihren Typ),
2) Placeholder
(zukünftige Werte, z.B. lhs
, rhs
oder result
; ihre tatsächlichen Werte sind noch nicht bekannt, nur ihre Typen sind bekannt),
3) Type
(Typen wie im Abschnitt „Typen“ definiert),
4) Function
(globale Funktionen wie im Abschnitt „Funktionen“ definiert).
Je nach Kontext können sich Namen auf verschiedene Werte beziehen. Genauer gesagt definiert der Abschnitt „Semantik“ für Vorgänge (und Äquivalente für andere Programmelemente) Laufzeitlogik, sodass alle Eingaben als Value
verfügbar sind.
Im Gegensatz dazu wird im Abschnitt „Einschränkungen“ für Vorgänge (und Äquivalente) die Logik für die Kompilierungszeit definiert, d.h. etwas, das normalerweise vor der Laufzeit ausgeführt wird, sodass nur konstante Eingaben als Value
und andere Eingaben nur als Placeholder
verfügbar sind.
Namen | In „Semantik“ | Unter „Einschränkungen“ |
---|---|---|
Globale Funktionen | Function |
Function |
Konstante Eingaben | Value |
Value |
Nicht konstante Eingaben | Value |
Placeholder |
Ausgaben | Value |
Placeholder |
Lokale Definitionen | Hängt von der Definition ab | Hängt von der Definition ab |
Sehen wir uns einen transpose
-Beispielvorgang an:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Für diesen Vorgang ist permutation
eine Konstante und somit sowohl in Bezug auf Semantik als auch Einschränkungen als Value
verfügbar. Im Gegensatz dazu sind operand
und result
in der Semantik als Value
verfügbar, in Einschränkungen jedoch nur als Placeholder
.
Funktionen
Konstruktion von Typen
Es gibt keine Funktionen, die zum Erstellen von Typen verwendet werden können. Stattdessen verwenden wir direkt die Typsyntax, da sie normalerweise prägnanter ist. Beispiel: (tensor<E>, tensor<E>) -> (tensor<E>)
statt function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funktionen für Typen
element_type
ist für Tensortypen und quantisierte Tensortypen definiert und gibt jeweils denTensorElementType
- oderQuantizedTensorElementType
-Teil der entsprechendenTensorType
oderQuantizedTensorType
zurück.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
ist ein Kurzbefehl füris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
ist ein Kürzel füris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
prüft, ob der Typx
zum Typy
hochgestuft werden kann. Wennx
undy
den WertQuantizedTensorElementType
haben, wird das Angebot nur aufstorage_type
angewendet. Diese spezielle Version der Hochstufung wird derzeit im Zusammenhang mit der Berechnung von Reduzierungen verwendet (weitere Informationen finden Sie unter RFC).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
ist eine Kurzform füris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Verfügbar für alle Typen. Beispiel:is_float(x)
gibttrue
zurück, wennx
einFloatType
ist. Wennx
ein Wert oder Platzhalter ist, ist diese Funktion eine Kurzform füris_type_name(type(x))
.max_value(x: Type) -> Value
gibt den Maximalwert vonTensorElementType
zurück. Wennx
keinTensorElementType
ist, wirdNone
zurückgegeben.min_value(x: Type) -> Value
gibt den kleinstmöglichen Wert vonTensorElementType
zurück. Wennx
keinTensorElementType
ist, wirdNone
zurückgegeben.member_name(x: Value | Placeholder | Type) -> Any
: Verfügbar für alle Mitgliedsdefinitionenmember_name
aller Typen.tensor_element_type(x)
gibt beispielsweise denTensorElementType
-Teil einer entsprechendenTensorType
zurück. Wennx
ein Wert oder Platzhalter ist, ist diese Funktion eine Kurzform fürmember_name(type(x))
. Wennx
kein Typ mit einem geeigneten Mitglied oder einem Wert oder Platzhalter dieses Typs ist, wirdNone
zurückgegeben.
Konstruktion von Werten
operation_name(*xs: Value | Type) -> Value
. Für alle Vorgänge verfügbar.add(lhs, rhs)
nimmt beispielsweise die beiden Tensorwertelhs
undrhs
und gibt die Ausgabe der Auswertung des Vorgangsadd
mit diesen Eingaben zurück. Bei einigen Vorgängen, z.B.broadcast_in_dim
, sind die Ausgabentypen „lastig“, d.h. sie werden benötigt, um einen Vorgang zu bewerten. In diesem Fall übernimmt die Funktion diese Typen als Argumente.
Funktion für Werte
Alle Operatoren und Funktionen von Python sind verfügbar. Beispielsweise stehen sowohl subscription- als auch Slicing-Notationen aus Python zur Indexierung in Tensoren, quantisierten Tensoren und Tupeln zur Verfügung.
to_destination_type(x: Value, destination_type: Type) -> Value
ist auf Tensoren definiert und gibt den konvertierten Wert vonx
basierend auftype(x)
unddestination_type
so zurück:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Es gibt bereits eine Diskussion über das Zusammenführen von convert
-, uniform_quantize
- und uniform_dequantize
-Vorgängen (#1576).
Nach der Zusammenführung benötigen wir die obige Funktion nicht und können stattdessen den Vorgangsnamen für convert
verwenden.
is_nan(x: Value) -> Value
ist auf Tensoren definiert und gibttrue
zurück, wenn alle Elemente vonx
NaN
sind, andernfallsfalse
. Wennx
kein Tensor ist, wirdNone
zurückgegeben.is_sorted(x: Value) -> Value
ist auf Tensoren definiert und gibttrue
zurück, wenn Elemente vonx
in aufsteigender Reihenfolge in Bezug auf die aufsteigende lexikografische Reihenfolge ihrer Indexe sortiert werden, andernfalls nachfalse
. Wennx
kein Tensor ist, wirdNone
zurückgegeben.is_unique(x: Value) -> Value
ist auf Tensoren definiert und gibttrue
zurück, wennx
keine doppelten Elemente oder sonstfalse
hat. Wennx
kein Tensor ist, wirdNone
zurückgegeben.member_name(x: Value) -> Any
ist für alle Mitgliedsdefinitionenmember_name
aller Werte definiert.real_part(x)
gibt beispielsweise denRealPart
-Teil einer entsprechendenComplexConstant
zurück. Wennx
kein Wert für ein geeignetes Mitglied ist, wirdNone
zurückgegeben.same(x: Value) -> Value
ist auf Tensoren definiert und gibttrue
zurück, wenn die Elemente vonx
alle gleich sind, oder andernfallsfalse
. Wenn der Tensor keine Elemente hat, zählt dies als „alle gleich“, d.h., die Funktion gibttrue
zurück. Wennx
kein Tensor ist, wirdNone
zurückgegeben.split(x: Value, num_results: Value, axis: Value) -> Value
ist auf Tensoren definiert und gibtnum_results
-Segmente vonx
entlang deraxis
-Achse zurück. Wennx
weder ein Tensor nochdim(x, axis) % num_results != 0
ist, wirdNone
zurückgegeben.
Formberechnungen
axes(x: Value | Placeholder | Type) -> Value
ist eine Kurzform fürrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
ist eine Kurzform fürshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
ist eine Kurzform fürlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
ist auf Tensoren definiert und gibtsize(x)
-Indizes für die entsprechendeTensorType
zurück, sortiert in aufsteigender lexikografischer Reihenfolge, z.B.[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Wennx
kein Tensortyp, kein quantisierter Tensortyp oder ein Wert oder Platzhalter eines dieser Typen ist, wirdNone
zurückgegeben.rank(x: Value | Placeholder | Type) -> Value
ist eine Kurzform fürsize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
wird im Abschnitt „Funktionen für Typen“ übermember_name
definiert.size(x: Value | Placeholder | Type) -> Value
ist eine Kurzform fürreduce(lambda x, y: x * y, shape(x))
.
Quantisierungsberechnungen
def baseline_element_type(x: Value | Placeholder | Type) -> Type
ist ein Kürzel fürelement_type(baseline_type(x))
.baseline_type
wird für Tensortypen und quantisierte Tensortypen definiert und transformiert sie in eine "Basislinie", d.h. einen Typ mit der gleichen Form, aber mit den Quantisierungsparametern des Elementtyps auf Standardwerte zurückgesetzt. Dies wird als praktischer Trick verwendet, um sowohl Tensor- als auch quantisierte Tensortypen einheitlich zu vergleichen, was recht häufig erforderlich ist. Bei quantisierten Typen ermöglicht dies, dass Vergleichstypen die Quantisierungsparameter ignorieren. Das heißt,shape
,storage_type
,expressed_type
,storage_min
,storage_max
undquantization_dimension
(für quantisierten Typ pro Achse) müssen alle übereinstimmen,scales
undzero points
können jedoch abweichen.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
wird für quantisierte Tensortypen definiert und wandelt sie in Gleitkomma-Tensortypen um. Dazu werden quantisierte Elemente, die Ganzzahlwerte des Speichertyps darstellen, in entsprechende Gleitkommawerte des Ausdruckstyps konvertiert. Dabei werden der Nullpunkt und die Skalierung verwendet, die dem quantisierten Elementtyp zugeordnet sind.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
wird für Gleitkomma-Tensortypen definiert und wandelt sie in quantisierte Tensortypen um. Dazu werden Gleitkommawerte des ausgedrückten Typs mithilfe des Nullpunkts und der Skalierung, die dem quantisierten Elementtyp zugeordnet sind, in entsprechende Ganzzahlwerte des Speichertyps konvertiert.
def quantize(x: Value, type: Type) -> Value:
assert is_float(x) and is_quantized(type)
x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
return bitcast_convert(x_storage, type)
dequantize_op_quantize
wird verwendet, um elementweise Berechnungen für quantisierte Tensoren anzugeben. Sie dequantisiert, d.h. quantisiert quantisierte Elemente, wandelt sie also in ihre Ausdruckstypen um, führt dann einen Vorgang aus und quantisiert dann die Ergebnisse, d.h., wandelt die Ergebnisse wieder in ihren Speichertyp um. Derzeit funktioniert diese Funktion nur für die Quantisierung pro Tensor. Die Quantisierung pro Achse wird ausgeführt (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
Rasterberechnungen
cross_partition(replica_groups: Value) -> Value
. Weitere Informationen finden Sie oben im Abschnitt "cross_ Nachbau".cross_replica(replica_groups: Value) -> Value
. Weitere Informationen finden Sie oben im Abschnitt "cross_ Nachbau".cross_replica_and_partition(replica_groups: Value) -> Value
. Weitere Informationen finden Sie oben im Abschnitt "cross_replica_and_partition".flattened_ids(replica_groups: Value) -> Value
. Weitere Informationen finden Sie oben im Abschnitt "flattened_ids".