StableHLO-Spezifikation

StableHLO ist eine Reihe von Operationen für High-Level-Operationen (HLO) in Modellen für maschinelles Lernen (ML). StableHLO dient als Portabilitätsebene zwischen verschiedenen ML-Frameworks und ML-Compilern: ML-Frameworks, die StableHLO-Programme erstellen, sind mit ML-Compilern kompatibel, die StableHLO-Programme verwenden.

Unser Ziel ist es, die ML-Entwicklung zu vereinfachen und zu beschleunigen, indem wir die Interoperabilität zwischen verschiedenen ML-Frameworks (z. B. TensorFlow, JAX und PyTorch) und ML-Compilern (z. B. XLA und IREE) verbessern. Zu diesem Zweck enthält dieses Dokument eine Spezifikation für die StableHLO-Programmiersprache.

Diese Spezifikation enthält drei Hauptabschnitte. Zuerst wird im Abschnitt Programme die Struktur von StableHLO-Programmen beschrieben, die aus StableHLO-Funktionen bestehen, die wiederum aus StableHLO-Operationen bestehen. Im Abschnitt Ops dieser Struktur wird die Semantik der einzelnen Vorgänge angegeben. Im Abschnitt Ausführung finden Sie die Semantik für alle diese Vorgänge, die gemeinsam in einem Programm ausgeführt werden. Im Abschnitt Notation wird die in der gesamten Spezifikation verwendete Notation erläutert.

Wenn Sie die Spezifikation einer früheren Version von StableHLO aufrufen möchten, öffnen Sie das Repository für das entsprechende Release mit Tag. Ein Beispiel ist die StableHLO v0.19.0-Spezifikation. Änderungen, die bei jeder Nebenversionserhöhung von StableHLO aufgetreten sind, finden Sie im Versionslog in VhloDialect.td.

Programme

Program ::= {Func}

StableHLO-Programme bestehen aus einer beliebigen Anzahl von StableHLO-Funktionen. Unten sehen Sie ein Beispielprogramm mit einer Funktion @main, die drei Eingaben (%image, %weights und %bias) und eine Ausgabe hat. Der Funktionskörper hat sechs Vorgänge.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Funktionen

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

StableHLO-Funktionen (auch benannte Funktionen genannt) haben eine Kennung, Ein-/Ausgaben und einen Rumpf. In Zukunft planen wir, zusätzliche Metadaten für Funktionen einzuführen, um eine bessere Kompatibilität mit HLO zu erreichen (#425, #626, #740, #744).

IDs

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

StableHLO-Kennzeichnungen ähneln Kennzeichnungen in vielen Programmiersprachen, haben aber zwei Besonderheiten: 1) Alle Kennzeichnungen haben Sigils, die verschiedene Arten von Kennzeichnungen unterscheiden, 2) Wertkennzeichnungen können vollständig numerisch sein, um die Generierung von StableHLO-Programmen zu vereinfachen.

Typen

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

StableHLO-Typen werden in Werttypen (auch Erstklassentypen genannt), die StableHLO-Werte darstellen, und Nicht-Werttypen, die andere Programmelemente beschreiben, unterteilt. StableHLO-Typen ähneln Typen in vielen Programmiersprachen. Die Besonderheit ist, dass StableHLO domänenspezifisch ist, was zu einigen ungewöhnlichen Ergebnissen führt (z. B. sind skalare Typen keine Wertetypen).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

Tensortypen stellen Tensoren dar, d.h. mehrdimensionale Arrays. Sie haben eine Form und einen Elementtyp. Eine Form stellt nicht negative oder unbekannte Dimensionsgrößen in aufsteigender Reihenfolge der entsprechenden Dimensionen (auch Achsen genannt) dar, die von 0 bis R-1 nummeriert sind. Die Anzahl der Dimensionen R wird als Rang bezeichnet. tensor<2x3xf32> ist beispielsweise ein Tensortyp mit der Form 2x3 und dem Elementtyp f32. Sie hat zwei Dimensionen (oder anders ausgedrückt: zwei Achsen) – die 0. und die 1. Dimension – mit den Größen 2 und 3. Der Rang ist 2.

Formen können teilweise oder vollständig unbekannt sein (dynamisch), z.B. ist tensor<?x2xf64> teilweise unbekannt und tensor<?x?xf64> vollständig unbekannt. Dynamische Dimensionsgrößen werden mit einem ? dargestellt. Formen können nicht aus dem Ranking entfernt werden.

In Zukunft möchten wir die Tensortypen über Dimensionsgrößen und Elementtypen hinaus erweitern, z. B. um Layouts (#629) und Sparsity (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Name Typ Einschränkungen
storage_type Ganzzahltyp (C1–C3), (C8)
storage_min Ganzzahlkonstante (C1), (C3), (C7)
storage_max Ganzzahlkonstante (C2), (C3), (C7)
expressed_type Gleitkommatyp (C4)
quantization_dimension optionale ganzzahlige Konstante (C10-C12)
scales Variadische Anzahl von Gleitkommakonstanten (C4–C6), (C9), (C10), (C13)
zero_points Variadische Anzahl von Ganzzahlkonstanten (C7–C9)

Quantisierte Elementtypen stellen Ganzzahlwerte eines Speichertyp im Bereich von storage_min bis storage_max (einschließlich) dar, die Gleitkommawerten eines ausgedrückten Typs entsprechen. Für einen bestimmten Ganzzahlwert i kann der entsprechende Gleitkommawert f als f = (i - zero_point) * scale berechnet werden, wobei scale und zero_point als Quantisierungsparameter bezeichnet werden. storage_min und storage_max sind in der Grammatik optional, haben aber die Standardwerte min_value(storage_type) bzw. max_value(storage_type). Für quantisierte Elementtypen gelten die folgenden Einschränkungen:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) Wenn is_empty(quantization_dimension), dann size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Derzeit ist QuantizationScale eine Gleitkommakonstante. Es besteht jedoch großes Interesse an ganzzahligen Skalen, die mit Multiplikatoren und Verschiebungen dargestellt werden. Wir planen, dies in naher Zukunft zu untersuchen (#1404).

Es gibt eine laufende Diskussion über die Semantik von QuantizationZeroPoint, einschließlich des Typs, der Werte und der Frage, ob es nur einen oder möglicherweise mehrere Nullpunkte in einem quantisierten Tensortyp geben kann. Basierend auf den Ergebnissen dieser Diskussion kann sich die Spezifikation für null Punkte in Zukunft ändern (#1405).

Eine weitere laufende Diskussion betrifft die Semantik von QuantizationStorageMin und QuantizationStorageMax, um festzustellen, ob Einschränkungen für diese Werte und für die Werte quantisierter Tensoren auferlegt werden sollten (#1406).

Schließlich planen wir, unbekannte Skalen und Nullpunkte ähnlich darzustellen wie unbekannte Dimensionsgrößen (#1407).

Quantisierte Tensortypen stellen Tensoren mit quantisierten Elementen dar. Diese Tensoren sind genau wie reguläre Tensoren, nur dass ihre Elemente quantisierte Elementtypen anstelle von regulären Elementtypen haben.

Bei quantisierten Tensoren kann die Quantisierung pro Tensor erfolgen. Das bedeutet, dass für den gesamten Tensor ein scale- und ein zero_point-Wert vorhanden ist. Sie kann aber auch pro Achse erfolgen. Das bedeutet, dass mehrere scales- und zero_points-Werte vorhanden sind, ein Paar pro Slice einer bestimmten Dimension quantization_dimension. Formaler ausgedrückt: In einem Tensor t mit achsenweiser Quantisierung gibt es dim(t, quantization_dimension) Slices von quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] usw. Alle Elemente im i-ten Slice verwenden scales[i] und zero_points[i] als Quantisierungsparameter. Für quantisierte Tensortypen gelten die folgenden Einschränkungen:

  • Für die Quantisierung pro Tensor:
    • Keine zusätzlichen Einschränkungen.
  • Für die Quantisierung pro Achse:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Tokentypen stellen Tokens dar, d.h. undurchsichtige Werte, die von einigen Vorgängen erstellt und verwendet werden. Tokens werden verwendet, um die Ausführungsreihenfolge von Vorgängen festzulegen, wie im Abschnitt Ausführung beschrieben.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Puffertypen stellen Puffer dar. In XLA sind Puffer beispielsweise mehrdimensionale Arrays mit konsistentem Speicher. Ähnlich wie bei Tensortypen haben Puffertypen eine Form und einen Elementtyp. Eine Form stellt nicht negative oder unbekannte Dimensionsgrößen in aufsteigender Reihenfolge der entsprechenden Dimensionen (auch Achsen genannt) dar, die von 0 bis R-1 nummeriert sind. Die Anzahl der Dimensionen R wird als Rang bezeichnet. memref<2x3xf32> ist beispielsweise ein Puffertyp mit der Form 2x3 und dem Elementtyp f32. Es hat zwei Dimensionen (oder mit anderen Worten zwei Achsen) – die 0. Dimension und die 1. Dimension – mit den Größen 2 und 3. Der Rang ist 2.

Puffer können mit einem custom_call bis CreateBuffer oder Pin zugewiesen und mit einem custom_call bis Unpin freigegeben werden. Nur custom_call-Vorgänge können den Inhalt von Puffern lesen und schreiben. Weitere Informationen finden Sie unter custom_call.

Tupeltypen stellen Tupel dar, d.h. heterogene Listen. Tupel sind eine Legacy-Funktion, die nur zur Kompatibilität mit HLO vorhanden ist. In HLO werden Tupel verwendet, um variadische Ein- und Ausgaben darzustellen. In StableHLO werden variadische Ein- und Ausgaben nativ unterstützt. Tupel werden in StableHLO nur verwendet, um die HLO-ABI umfassend darzustellen, wobei z. B. T, tuple<T> und tuple<tuple<T>> je nach Implementierung erheblich variieren können. In Zukunft planen wir, Änderungen an der HLO-ABI vorzunehmen, die es uns ermöglichen, Tupeltypen aus StableHLO zu entfernen (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Elementtypen stellen Elemente von Tensortypen dar. Anders als in vielen Programmiersprachen sind diese Typen in StableHLO nicht erstklassig. Das bedeutet, dass StableHLO-Programme Werte dieser Typen nicht direkt darstellen können. Daher ist es üblich, skalare Werte vom Typ T mit 0-dimensionalen Tensorwerten vom Typ tensor<T> darzustellen.

  • Der boolesche Typ stellt die booleschen Werte true und false dar.
  • Ganzzahltypen können entweder mit Vorzeichen (si) oder ohne Vorzeichen (ui) sein und eine der unterstützten Bitbreiten (2, 4, 8, 16, 32 oder 64) haben. Mit Vorzeichen versehene siN-Typen stellen Ganzzahlwerte von -2^(N-1) bis einschließlich 2^(N-1)-1 dar und vorzeichenlose uiN-Typen stellen Ganzzahlwerte von 0 bis einschließlich 2^N-1 dar.
  • Gleitkommatypen können einen der folgenden Werte haben:
  • Komplexe Typen stellen komplexe Werte mit einem Realteil und einem Imaginärteil desselben Elementtyps dar. Unterstützte komplexe Typen sind complex<f32> (beide Teile sind vom Typ f32) und complex<f64> (beide Teile sind vom Typ f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Funktionstypen stellen sowohl benannte als auch anonyme Funktionen dar. Sie haben Eingabetypen (die Liste der Typen auf der linken Seite von ->) und Ausgabetypen (die Liste der Typen auf der rechten Seite von ->). In vielen Programmiersprachen sind Funktionstypen erstklassig, in StableHLO jedoch nicht.

StringType ::= 'string'

String-Typ steht für Bytefolgen. Im Gegensatz zu vielen Programmiersprachen ist der String-Typ in StableHLO nicht erstklassig und wird nur verwendet, um statische Metadaten für Programmelemente anzugeben.

Vorgänge

StableHLO-Vorgänge (auch Ops genannt) stellen eine geschlossene Menge von Vorgängen auf hoher Ebene in Machine-Learning-Modellen dar. Wie oben beschrieben, ist die StableHLO-Syntax stark von MLIR inspiriert. Das ist nicht unbedingt die ergonomischste Alternative, aber wohl die beste Lösung für das Ziel von StableHLO, die Interoperabilität zwischen ML-Frameworks und ML-Compilern zu verbessern.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

StableHLO-Vorgänge (auch Ops genannt) haben einen Namen, Ein-/Ausgaben und eine Signatur. Der Name besteht aus dem Präfix stablehlo. und einem Mnemonik, das einen der unterstützten Vorgänge eindeutig identifiziert. Unten finden Sie eine vollständige Liste aller unterstützten Operationen.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Vorgänge nutzen Eingaben und erzeugen Ausgaben. Eingaben werden in Eingabewerte (während der Ausführung berechnet), Eingabefunktionen (statisch bereitgestellt, da Funktionen in StableHLO keine erstklassigen Werte sind) und Eingabeattribute (ebenfalls statisch bereitgestellt) unterteilt. Die Art der Ein- und Ausgaben, die von einem Vorgang verwendet und erzeugt werden, hängt von seinem Mnemonic ab. Der add-Vorgang verwendet beispielsweise 2 Eingabewerte und erzeugt 1 Ausgabewert. Im Vergleich dazu benötigt der Vorgang select_and_scatter drei Eingabewerte, zwei Eingabefunktionen und drei Eingabeattribute.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Eingabefunktionen (auch anonyme Funktionen genannt) ähneln benannten Funktionen sehr, mit folgenden Ausnahmen: 1) Sie haben keine Kennung (daher der Name „anonym“), 2) Sie deklarieren keine Ausgabetypen (Ausgabetypen werden vom return-Vorgang innerhalb der Funktion abgeleitet).

Die Syntax für Eingabefunktionen enthält einen derzeit nicht verwendeten Teil (siehe die Unused-Produktion oben), der aus Gründen der Kompatibilität mit MLIR vorhanden ist. In MLIR gibt es ein allgemeineres Konzept von „Regionen“, die mehrere „Blöcke“ von Vorgängen haben können, die über Sprungvorgänge miteinander verbunden sind. Diese Blöcke haben IDs, die der Unused-Produktion entsprechen, sodass sie voneinander unterschieden werden können. StableHLO hat keine Jump-Operationen, daher wird der entsprechende Teil der MLIR-Syntax nicht verwendet (ist aber weiterhin vorhanden).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Eingabeattribute haben einen Namen und einen Wert, der eine der unterstützten Konstanten ist. Sie sind die primäre Methode zum Angeben statischer Metadaten für Programmelemente. Der concatenate-Vorgang verwendet beispielsweise das Attribut dimension, um die Dimension anzugeben, entlang derer die Eingabewerte verkettet werden. Ebenso verwendet der slice-Vorgang mehrere Attribute wie start_indices und limit_indices, um die Grenzen anzugeben, die zum Aufteilen des Eingabewerts verwendet werden.

Derzeit enthalten StableHLO-Programme manchmal Attribute, die in diesem Dokument nicht beschrieben werden. In Zukunft planen wir, diese Attribute entweder in das StableHLO-Opset aufzunehmen oder zu verhindern, dass sie in StableHLO-Programmen vorkommen. In der Zwischenzeit finden Sie hier die Liste dieser Attribute:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Standortmetadaten (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

Die Op-Signatur besteht aus den Typen aller Eingabewerte (der Liste der Typen auf der linken Seite von ->) und den Typen aller Ausgabewerte (der Liste der Typen auf der rechten Seite von ->). Streng genommen sind Eingabetypen redundant und Ausgabetypen sind fast immer redundant, da für die meisten StableHLO-Vorgänge Ausgabetypen aus Eingaben abgeleitet werden können. Die Signatur des Vorgangs ist jedoch bewusst Teil der StableHLO-Syntax, um die Kompatibilität mit MLIR zu gewährleisten.

Unten sehen Sie ein Beispiel für einen Vorgang mit dem Mnemonic select_and_scatter. Sie verwendet 3 Eingabewerte (%operand, %source und %init_value), 2 Eingabefunktionen und 3 Eingabeattribute (window_dimensions, window_strides und padding). Beachten Sie, dass die Signatur des Vorgangs nur die Typen der Eingabewerte enthält (nicht die Typen der Eingabefunktionen und ‑attribute, die inline angegeben werden).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Konstanten

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

StableHLO-Konstanten haben ein Literal und einen Typ, die zusammen einen StableHLO-Wert darstellen. Im Allgemeinen ist der Typ Teil der Konstantensyntax, es sei denn, er ist eindeutig (z.B. hat eine boolesche Konstante eindeutig den Typ i1, während eine Ganzzahlkonstante mehrere mögliche Typen haben kann).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Boolesche Konstanten stellen die booleschen Werte true und false dar. Boolesche Konstanten haben den Typ i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Ganzzahlkonstanten stellen Ganzzahlwerte über Strings dar, die die Dezimal- oder Hexadezimalnotation verwenden. Andere Basen, z.B. binär oder oktal, werden nicht unterstützt. Für Ganzzahlkonstanten gelten die folgenden Einschränkungen:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Gleitkommakonstanten stellen Gleitkommawerte über Strings dar, die die Dezimal- oder wissenschaftliche Notation verwenden. Außerdem kann die hexadezimale Notation verwendet werden, um die zugrunde liegenden Bits im Gleitkommaformat des entsprechenden Typs direkt anzugeben. Für Gleitkommakonstanten gelten die folgenden Einschränkungen:

  • (C1) Wenn keine hexadezimale Notation verwendet wird,is_wellformed(float_literal, float_type).
  • (C2) Wenn die hexadezimale Notation verwendet wird, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Komplexe Konstanten stellen komplexe Werte mithilfe von Listen mit einem Realteil (kommt zuerst) und einem Imaginärteil (kommt als Zweites) dar. Beispiel: (1.0, 0.0) : complex<f32> steht für 1.0 + 0.0i und (0.0, 1.0) : complex<f32> für 0.0 + 1.0i. Die Reihenfolge, in der diese Teile dann im Arbeitsspeicher gespeichert werden, ist implementierungsabhängig. Für komplexe Konstanten gelten die folgenden Einschränkungen:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Tensorkonstanten stellen Tensorwerte mithilfe von verschachtelten Listen dar, die über die NumPy-Notation angegeben werden. Beispielsweise stellt dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> einen Tensorwert mit der folgenden Zuordnung von Index zu Element dar: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. Die Reihenfolge, in der diese Elemente dann im Speicher abgelegt werden, ist implementierungsabhängig. Für Tensor-Konstanten gelten die folgenden Einschränkungen:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), wobei:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), wobei:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • Andernfalls false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Quantisierte Tensorkonstanten stellen quantisierte Tensorwerte mit derselben Notation wie Tensorkonstanten dar. Die Elemente werden als Konstanten ihres Speichertyps angegeben. Für quantisierte Tensor-Konstanten gelten die folgenden Einschränkungen:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Stringliterale bestehen aus Byte, die mit ASCII-Zeichen und Escape-Sequenzen angegeben werden. Sie sind unabhängig von der Codierung, daher ist die Interpretation dieser Bytes implementierungsdefiniert. Stringliterale haben den Typ string.

Operativer Betrieb

abs

Semantik

Führt eine elementweise „abs“-Operation für den Tensor operand aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Bei Ganzzahlen mit Vorzeichen: Ganzzahlmodul.
  • Für Gleitkommazahlen: abs aus IEEE-754.
  • Für komplexe Zahlen: komplexer Modulus.
  • Für quantisierte Typen: dequantize_op_quantize(abs, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Typ „Ganzzahl mit Vorzeichen“, „Gleitkomma“, „Komplex“ oder „Tensor mit per-Tensor-Quantisierung“ (C1-C2)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl mit Vorzeichen“ oder „Gleitkomma“ oder per-Tensor quantisierter Tensor (C1-C2)

Einschränkungen

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) wird so definiert:
    • complex_element_type(element_type(operand)), wenn is_complex(operand).
    • Andernfalls baseline_element_type(operand).

Beispiele

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]

 Weitere Beispiele

Hinzufügen

Semantik

Führt eine elementweise Addition von zwei Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für boolesche Werte: logisches ODER.
  • Bei Ganzzahlen: Addition von Ganzzahlen.
  • Für Gleitkommazahlen: addition aus IEEE-754.
  • Für komplexe Zahlen: komplexe Addition.
  • Für quantisierte Typen: dequantize_op_quantize(add, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor oder quantisierter Tensor (C1–C6)
(I2) rhs Tensor oder quantisierter Tensor (C1–C5), (C7)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C1–C7)

Einschränkungen

  • Wenn für die Operation nicht quantisierte Tensoren verwendet werden:
    • (C1) type(lhs) = type(rhs) = type(result).
  • Wenn für den Vorgang quantisierte Tensoren verwendet werden:
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Wenn is_per_axis_quantized(lhs), dann quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Wenn is_per_axis_quantized(rhs), dann quantization_dimension(rhs) = quantization_dimension(result).

Beispiele

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]

 Weitere Beispiele

after_all

Semantik

Sorgt dafür, dass die Vorgänge, die inputs erzeugen, vor allen Vorgängen ausgeführt werden, die von result abhängen. Die Ausführung dieses Vorgangs hat keine Auswirkungen. Er ist nur dazu da, Datenabhängigkeiten von result zu inputs herzustellen.

Eingaben

Label Name Typ
(I1) inputs variadische Anzahl von token

Ausgaben

Name Typ
result token

Beispiele

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token

 Weitere Beispiele

all_gather

Semantik

Verkettet in jeder Prozessgruppe im StableHLO-Prozessraster die Werte der operands-Tensoren aus jedem Prozess entlang all_gather_dim und erzeugt results-Tensoren.

Bei der Operation wird das StableHLO-Prozessraster in process_groups aufgeteilt, das so definiert ist:

  • cross_replica(replica_groups) bei channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) bei channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) bei channel_id > 0 and use_global_device_ids = true.

Gehen Sie anschließend in jedem process_group so vor:

  • operands...@receiver = [operand@sender for sender in process_group] für alle receiver in process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) für alle process in process_group.

Eingaben

Label Name Typ Einschränkungen
(I1) operands variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1), (C6)
(I2) all_gather_dim Konstante vom Typ si64 (C1), (C6)
(I3) replica_groups 2‑dimensionaler Tensor vom Typ si64 (C2-C4)
(I4) channel_id Konstante vom Typ si64 (C5)
(I5) use_global_device_ids Konstante vom Typ i1 (C5)

Ausgaben

Name Typ Einschränkungen
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C6)

Einschränkungen

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) wird so definiert:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_replicas, wenn cross_replica_and_partition verwendet wird.
    • num_processes, wenn flattened_ids verwendet wird.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Wenn use_global_device_ids = true, dann channel_id > 0.
  • (C6) type(results...) = type(operands...) außer:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

Beispiele

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
  // channel_id = 0
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
  // use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

 Weitere Beispiele

all_reduce

Semantik

Wendet in jeder Prozessgruppe im StableHLO-Prozessraster eine Reduzierungsfunktion computation auf die Werte der operands-Tensoren aus jedem Prozess an und gibt results-Tensoren aus.

Bei der Operation wird das StableHLO-Prozessraster in process_groups aufgeteilt, das so definiert ist:

  • cross_replica(replica_groups) bei channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) bei channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) bei channel_id > 0 and use_global_device_ids = true.

Gehen Sie anschließend in jedem process_group so vor:

  • results...@process[result_index] = exec(schedule) für einen binären Baum schedule, wobei:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule ist ein implementierungsdefinierter binärer Baum, dessen Inorder-Durchlauf to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])) ist.

Eingaben

Label Name Typ Einschränkungen
(I1) operands variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C5), (C6)
(I2) replica_groups variadische Anzahl von eindimensionalen Tensorkonstanten vom Typ si64 (C1–C3)
(I3) channel_id Konstante vom Typ si64 (C4)
(I4) use_global_device_ids Konstante vom Typ i1 (C4)
(I5) computation Funktion (C5)

Ausgaben

Name Typ Einschränkungen
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C6–C7)

Einschränkungen

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) wird so definiert:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_replicas, wenn cross_replica_and_partition verwendet wird.
    • num_processes, wenn flattened_ids verwendet wird.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Wenn use_global_device_ids = true, dann channel_id > 0.
  • (C5) computation hat den Typ (tensor<E>, tensor<E>) -> (tensor<E>), wobei is_promotable(element_type(operand), E).
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

Beispiele

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
  >replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
  // channel_id = 0
  channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
  // use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

 Weitere Beispiele

all_to_all

Semantik

all_to_all

In jeder Prozessgruppe im StableHLO-Prozessraster werden die Werte der operands-Tensoren entlang split_dimension in Teile aufgeteilt, die auf die Prozesse verteilt, entlang concat_dimension verkettet und results-Tensoren erzeugt. Bei der Operation wird das StableHLO-Prozessraster in process_groups aufgeteilt, das so definiert ist:

  • cross_replica(replica_groups), wenn channel_id <= 0.
  • cross_partition(replica_groups), wenn channel_id > 0.

Gehen Sie anschließend in jedem process_group so vor:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) für alle sender in process_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group], wobei receiver_index = process_group.index(receiver).
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

Eingaben

Label Name Typ Einschränkungen
(I1) operands variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C3), (C9)
(I2) split_dimension Konstante vom Typ si64 (C1), (C2), (C9)
(I3) concat_dimension Konstante vom Typ si64 (C3), (C9)
(I4) split_count Konstante vom Typ si64 (C2), (C4), (C8), (C9)
(I5) replica_groups 2‑dimensionaler Tensor vom Typ si64 (C5–C8)
(I6) channel_id Konstante vom Typ si64

Ausgaben

Name Typ Einschränkungen
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C9)

Einschränkungen

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) wird so definiert:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_partitions, wenn cross_partition verwendet wird.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...), außer wenn split_dimension != concat_dimension:
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

Beispiele

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
  // channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

 Weitere Beispiele

und

Semantik

Führt eine elementweise AND-Operation für zwei Tensoren lhs und rhs aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für boolesche Werte: logisches UND.
  • Für Ganzzahlen: bitweises UND.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Typ „Boolesch“ oder „Ganzzahl“ (C1)
(I2) rhs Tensor vom Typ „Boolesch“ oder „Ganzzahl“ (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Boolesch“ oder „Ganzzahl“ (C1)

Einschränkungen

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]

 Weitere Beispiele

atan2

Semantik

Führt eine elementweise atan2-Operation für die Tensoren lhs und rhs aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: atan2 aus IEEE-754.
  • Für komplexe Zahlen: komplexer atan2.
  • Für quantisierte Typen: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)
(I2) rhs Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

 Weitere Beispiele

batch_norm_grad

Semantik

Berechnet die Gradienten mehrerer Eingaben von batch_norm_training durch Backpropagation von grad_output und erzeugt die Tensoren grad_operand, grad_scale und grad_offset. Formaler kann dieser Vorgang als Zerlegung in vorhandene StableHLO-Vorgänge mit der folgenden Python-Syntax ausgedrückt werden:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Für quantisierte Typen wird dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1–C3), (C5)
(I2) scale 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C2), (C4), (C5)
(I3) mean 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C2), (C4)
(I4) variance 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C2), (C4)
(I5) grad_output Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C2), (C3)
(I6) epsilon Konstante vom Typ f32
(I7) feature_index Konstante vom Typ si64 (C1), (C5)

Ausgaben

Name Typ Einschränkungen
grad_operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C2), (C3)
grad_scale 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C2), (C4)
grad_offset 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C2), (C4)

Einschränkungen

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale und grad_offset haben dieselbe baseline_element_type.
  • (C3) operand, grad_output und grad_operand haben dieselbe Form.
  • (C4) scale, mean, variance, grad_scale und grad_offset haben dieselbe Form.
  • (C5) size(scale) = dim(operand, feature_index).

Beispiele

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
 <    tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semantik

Normalisiert den operand-Tensor über alle Dimensionen hinweg, mit Ausnahme der feature_index-Dimension, und erzeugt einen result-Tensor. Formaler kann dieser Vorgang als Zerlegung in vorhandene StableHLO-Vorgänge mit der folgenden Python-Syntax ausgedrückt werden:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Für quantisierte Typen wird dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1–C7)
(I2) scale 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C2), (C3)
(I3) offset 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C2), (C4)
(I4) mean 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C5)
(I5) variance 1‑dimensionaler Tensor vom Gleitkomma- oder Tensor-quantisierten Typ (C2), (C6)
(I6) epsilon Konstante vom Typ f32
(I7) feature_index Konstante vom Typ si64 (C1), (C3–C6)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C2), (C7)

Einschränkungen

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance und result haben dieselbe baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semantik

Berechnet den Mittelwert und die Varianz über alle Dimensionen hinweg, mit Ausnahme der feature_index-Dimension, und normalisiert den operand-Tensor, wodurch die Tensoren output, batch_mean und batch_var entstehen. Formaler kann dieser Vorgang als Zerlegung in vorhandene StableHLO-Vorgänge mit der folgenden Python-Syntax ausgedrückt werden:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Für quantisierte Typen wird dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)
(I2) scale 1‑dimensionaler Tensor mit Gleitkomma- oder Tensor-quantisierten Werten (C2), (C3)
(I3) offset 1‑dimensionaler Tensor mit Gleitkomma- oder Tensor-quantisierten Werten (C2), (C4)
(I4) epsilon Konstante vom Typ f32 (C1), (C3–C6)
(I5) feature_index Konstante vom Typ si64 (C1), (C3–C6)

Ausgaben

Name Typ Einschränkungen
output Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C7)
batch_mean 1‑dimensionaler Tensor mit Gleitkomma- oder Tensor-quantisierten Werten (C2), (C5)
batch_var 1‑dimensionaler Tensor mit Gleitkomma- oder Tensor-quantisierten Werten (C2), (C6)

Einschränkungen

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var und output haben dieselbe baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Beispiele

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
 <   (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semantik

Führt einen Bitcast-Vorgang für den operand-Tensor aus und erzeugt einen result-Tensor, in dem die Bits des gesamten operand-Tensors mit dem Typ des result-Tensors neu interpretiert werden.

Formal ausgedrückt: Bei E = element_type(operand), E' = element_type(result) und R = rank(operand) gilt Folgendes:

  • Wenn num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Wenn num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Wenn num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits gibt die In-Memory-Darstellung eines bestimmten Werts zurück. Das Verhalten ist implementierungsdefiniert, da die genaue Darstellung von Tensoren und Elementtypen implementierungsdefiniert ist.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder quantisierter Tensor (C1-C2)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C1-C2)

Einschränkungen

  • (C1) Angesichts von E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) und R = rank(operand):
    • Wenn num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Wenn num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) für alle 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Wenn num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) für alle 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Wenn is_complex(operand) or is_complex(result), dann is_complex(operand) and is_complex(result).

Beispiele

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

 Weitere Beispiele

broadcast_in_dim

Semantik

broadcast_in_dim

Erweitert die Dimensionen und/oder den Rang eines Eingabetensors, indem die Daten im operand-Tensor dupliziert werden, und erzeugt einen result-Tensor. Formaler ausgedrückt: result[result_index] = operand[operand_index], wobei für alle d in axes(operand) gilt:

  • operand_index[d] = 0, wenn dim(operand, d) = 1.
  • Andernfalls operand_index[d] = result_index[broadcast_dimensions[d]].

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder quantisierter Tensor (C1-C2), (C5-C6)
(I2) broadcast_dimensions 1-dimensionaler Tensor vom Typ si64 (C2-C6)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C1), (C3), (C5–C6)

Einschränkungen

  • (C1) element_type(result) wird so berechnet:
    • element_type(operand), wenn !is_per_axis_quantized(operand).
    • element_type(operand), mit der Ausnahme, dass quantization_dimension(operand), scales(operand) und zero_points(operand) von quantization_dimension(result), scales(result) und zero_points(result) abweichen können.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Für alle d in axes(operand):
    • dim(operand, d) = 1 oder
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Wenn is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Bei dim(operand, quantization_dimension(operand)) = 1, dann scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Beispiele

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 Weitere Beispiele

Supportanfrage

Semantik

Gibt die Ausgabe zurück, die durch die Ausführung genau einer Funktion aus branches entsteht, je nach dem Wert von index. Formaler ausgedrückt: result = selected_branch(), wobei gilt:

  • selected_branch = branches[index], wenn 0 <= index < size(branches).
  • Andernfalls selected_branch = branches[-1].

Eingaben

Label Name Typ Einschränkungen
(I1) index 0-dimensionaler Tensor vom Typ si32
(I2) branches variadische Anzahl von Funktionen (C1–C4)

Ausgaben

Name Typ Einschränkungen
results Variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C4)

Einschränkungen

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Beispiele

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
  "stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]

 Weitere Beispiele

cbrt

Semantik

Führt eine elementweise Kubikwurzeloperation für den operand-Tensor aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: rootn(x, 3) aus IEEE-754.
  • Für komplexe Zahlen: komplexe Kubikwurzel.
  • Für quantisierte Typen: dequantize_op_quantize(cbrt, operand, type(result))

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]

 Weitere Beispiele

ceil

Semantik

Führt die elementweise Aufrundung des operand-Tensors aus und gibt einen result-Tensor aus. Implementiert den roundToIntegralTowardPositive-Vorgang aus der IEEE-754-Spezifikation. Für quantisierte Typen wird dequantize_op_quantize(ceil, operand, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

 Weitere Beispiele

cholesky

Semantik

Berechnet die Cholesky-Zerlegung einer Reihe von Matrizen.

Formaler ausgedrückt: Für alle i in index_space(result) ist result[i0, ..., iR-3, :, :] eine Cholesky-Zerlegung von a[i0, ..., iR-3, :, :] in Form einer unteren Dreiecksmatrix (wenn lower gleich true ist) oder einer oberen Dreiecksmatrix (wenn lower gleich false ist). Die Ausgabewerte im gegenüberliegenden Dreieck, d.h. im strikten oberen oder strikten unteren Dreieck, sind implementierungsabhängig.

Wenn es i gibt, für das die Eingabematrix keine hermitesche positiv definite Matrix ist, ist das Verhalten nicht definiert.

Für quantisierte Typen wird dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) a Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1–C3)
(I2) lower Konstante vom Typ i1

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Beispiele

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

einschränken

Semantik

Beschneidet jedes Element des operand-Tensors auf einen Minimal- und Maximalwert und gibt einen result-Tensor aus. Formaler ausgedrückt: result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), wobei min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Für quantisierte Typen wird dequantize_op_quantize(clamp, min, operand, max, type(result)) ausgeführt.

Die Festlegung einer Reihenfolge für komplexe Zahlen birgt überraschende Semantiken. Daher planen wir, die Unterstützung für komplexe Zahlen für diesen Vorgang in Zukunft zu entfernen (#560).

Eingaben

Label Name Typ Einschränkungen
(I1) min Tensor oder Tensor mit Tensor-Quantisierung (C1), (C3)
(I2) operand Tensor oder Tensor mit Tensor-Quantisierung (C1–C4)
(I3) max Tensor oder Tensor mit Tensor-Quantisierung (C2), (C3)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C4)

Einschränkungen

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Beispiele

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]

 Weitere Beispiele

collective_broadcast

Semantik

Senden Sie in jeder Prozessgruppe im StableHLO-Prozessraster den Wert des operand-Tensors vom Quellprozess an die Zielprozesse und erstellen Sie einen result-Tensor.

Bei der Operation wird das StableHLO-Prozessraster in process_groups aufgeteilt, das so definiert ist:

  • cross_replica(replica_groups), wenn channel_id <= 0.
  • cross_partition(replica_groups), wenn channel_id > 0.

Danach wird result@process so berechnet:

  • operand@process_groups[i, 0], wenn es ein i gibt, sodass der Prozess in process_groups[i] ist.
  • Andernfalls broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C3)
(I2) replica_groups variadische Anzahl von eindimensionalen Tensorkonstanten vom Typ si64 (C1), (C2)
(I3) channel_id Konstante vom Typ si64

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C3)

Einschränkungen

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, wobei N so definiert ist:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_partitions, wenn cross_partition verwendet wird.
  • (C3) type(result) = type(operand).

Beispiele

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semantik

In jeder Prozessgruppe im StableHLO-Prozessraster wird der Wert des Tensors operand vom Quellprozess an den Zielprozess gesendet und ein Tensor result erstellt.

Bei der Operation wird das StableHLO-Prozessraster in process_groups aufgeteilt, das so definiert ist:

  • cross_replica(source_target_pairs), wenn channel_id <= 0.
  • cross_partition(source_target_pairs), wenn channel_id > 0.

Danach wird result@process so berechnet:

  • operand@process_groups[i, 0], wenn ein i vorhanden ist, sodass process_groups[i, 1] = process.
  • Andernfalls broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C5)
(I2) source_target_pairs 2‑dimensionaler Tensor vom Typ si64 (C1–C4)
(I3) channel_id Konstante vom Typ si64

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1)

Einschränkungen

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, wobei N so definiert ist:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_partitions, wenn cross_partition verwendet wird.
  • (C5) type(result) = type(operand).

Beispiele

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

 Weitere Beispiele

vergleichen

Semantik

Führt einen elementweisen Vergleich der Tensoren lhs und rhs gemäß comparison_direction und compare_type durch und gibt einen result-Tensor aus.

Die Werte von comparison_direction und compare_type haben die folgende Semantik:

Für boolesche und Ganzzahlelementtypen:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

Für Gleitkomma-Elementtypen mit compare_type = FLOAT implementiert der Vorgang die folgenden IEEE-754-Operationen:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

Für Gleitkomma-Elementtypen mit compare_type = TOTALORDER verwendet der Vorgang die Kombination aus totalOrder- und compareQuietEqual-Vorgängen aus IEEE-754.

Bei komplexen Elementtypen wird ein lexikografischer Vergleich von (real, imag)-Paaren mit den bereitgestellten comparison_direction und compare_type durchgeführt. Die Festlegung einer Reihenfolge für komplexe Zahlen birgt überraschende Semantiken. Daher planen wir, die Unterstützung für komplexe Zahlen in Zukunft zu entfernen, wenn comparison_direction GE, GT, LE oder LT ist (#560).

Für quantisierte Typen wird dequantize_compare(lhs, rhs, comparison_direction) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor oder Tensor mit Tensor-Quantisierung (C1–C3)
(I2) rhs Tensor oder Tensor mit Tensor-Quantisierung (C1-C2)
(I3) comparison_direction Enumeration von EQ, NE, GE, GT, LE und LT
(I4) compare_type enum von FLOAT, TOTALORDER, SIGNED und UNSIGNED (C3)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Boolesch“ (C2)

Einschränkungen

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type wird so definiert:
    • SIGNED, wenn is_signed_integer(element_type(lhs)).
    • UNSIGNED, wenn is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT oder TOTALORDER, wenn is_float(element_type(lhs)).
    • FLOAT, wenn is_complex(element_type(lhs)).

Beispiele

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = <#stablehlocomparison_di>rection LT,
  compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]

 Weitere Beispiele

komplex

Semantik

Führt die elementweise Konvertierung in einen komplexen Wert aus einem Paar von reellen und imaginären Werten, lhs und rhs, durch und erzeugt einen result-Tensor.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Typ f32 oder f64 (C1–C3)
(I2) rhs Tensor vom Typ f32 oder f64 (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom komplexen Typ (C2), (C3)

Einschränkungen

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) hat den Typ complex<E>, wobei E = element_type(lhs).

Beispiele

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]

 Weitere Beispiele

Zusammengesetzt

Semantik

Kapselt einen Vorgang, der aus anderen StableHLO-Vorgängen besteht, wobei inputs und composite_attributes verwendet und results ausgegeben werden. Die Semantik des Vorgangs wird durch das Attribut decomposition implementiert. Der Vorgang composite kann durch seine Zerlegung ersetzt werden, ohne dass sich die Programmsemantik ändert. Wenn durch das Inlining der Dekomposition nicht dieselbe Semantik für den Vorgang erreicht wird, sollten Sie custom_call verwenden.

Das Feld version (Standardwert: 0) wird verwendet, um anzugeben, wann sich die Semantik eines Composites ändert.

Eingaben

Label Name Typ
(I1) inputs variable Anzahl von Werten
(I2) name Konstante vom Typ string
(I3) composite_attributes Attributwörterbuch
(I4) decomposition Konstante vom Typ string
(I5) version Konstante vom Typ si32

Ausgaben

Name Typ
results variable Anzahl von Werten

Einschränkungen

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Beispiele

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
 < ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32

 Weitere Beispiele

concatenate

Semantik

Verkettet inputs entlang der Dimension dimension in derselben Reihenfolge wie die angegebenen Argumente und gibt einen result-Tensor aus. Formaler ausgedrückt: result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], wobei gilt:

  1. id = d0 + ... + dk-1 + kd.
  2. d entspricht dimension und d0, ... sind die Größen der d-ten Dimension von inputs.

Eingaben

Label Name Typ Einschränkungen
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C6)
(I2) dimension Konstante vom Typ si64 (C2), (C4), (C6)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C5–C6)

Einschränkungen

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) außer dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) mit Ausnahme von:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Beispiele

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

 Weitere Beispiele

Konstante

Semantik

Erstellt einen output-Tensor aus einer konstanten value.

Eingaben

Label Name Typ Einschränkungen
(I1) value Konstante (C1)

Ausgaben

Name Typ Einschränkungen
output Tensor oder quantisierter Tensor (C1)

Einschränkungen

  • (C1) type(value) = type(output).

Beispiele

%output = "stablehlo.constant"() {
  val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]

 Weitere Beispiele

eine Conversion ausführen

Semantik

Führt eine elementweise Konvertierung von einem Elementtyp in einen anderen für den operand-Tensor aus und gibt einen result-Tensor zurück.

Bei boolean-to-any-supported-type-Konvertierungen wird der Wert false in null und der Wert true in eins konvertiert. Bei any-supported-type-to-boolean-Konvertierungen wird ein Nullwert in false und ein Wert ungleich null in true konvertiert. Unten erfahren Sie, wie das bei komplexen Typen funktioniert.

Bei Konvertierungen, die Ganzzahl-zu-Ganzzahl, Ganzzahl-zu-Gleitkomma oder Gleitkomma-zu-Gleitkomma umfassen, ist der Ergebniswert die genaue Darstellung des Quellwerts im Zieltyp, sofern der Quellwert im Zieltyp genau dargestellt werden kann. Andernfalls ist das Verhalten noch nicht definiert (#180).

Bei Konvertierungen, bei denen floating-point-to-integer umgewandelt werden, wird der Bruchteil abgeschnitten. Wenn der gekürzte Wert nicht im Zieltyp dargestellt werden kann, ist das Verhalten noch nicht definiert (#180).

Bei der Konvertierung von komplex zu komplex wird das gleiche Verhalten wie bei der Konvertierung von Gleitkommazahl zu Gleitkommazahl verwendet, um Real- und Imaginärteile zu konvertieren.

Bei complex-to-any-other-type- und any-other-type-to-complex-Konvertierungen wird der imaginäre Quellwert ignoriert bzw. der imaginäre Zielwert auf null gesetzt. Die Konvertierung des Realteils folgt den Gleitkommakonvertierungen.

Grundsätzlich könnte dieser Vorgang die Dequantisierung (Konvertierung von quantisierten Tensoren in reguläre Tensoren), die Quantisierung (Konvertierung von regulären Tensoren in quantisierte Tensoren) und die Requantisierung (Konvertierung zwischen quantisierten Tensoren) ausdrücken. Derzeit haben wir jedoch dedizierte Vorgänge dafür: uniform_dequantize für den ersten Anwendungsfall und uniform_quantize für den zweiten und dritten Anwendungsfall. In Zukunft werden diese beiden Vorgänge möglicherweise in convert zusammengeführt (#1576).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor (C1)

Einschränkungen

  • (C1) shape(operand) = shape(result).

Beispiele

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

 Weitere Beispiele

Faltung

Semantik

Berechnet Punktprodukte zwischen Fenstern von lhs und Slices von rhs und gibt result aus. Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result aus lhs und rhs berechnet werden.

Faltung

Formaler ausgedrückt: Die Eingaben werden in Bezug auf lhs neu formuliert, um Zeiträume von lhs ausdrücken zu können:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Für diese Umformulierung werden die folgenden Hilfsfunktionen verwendet:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1], wobei j[d] = i[permutation[d]].

Wenn feature_group_count = 1 und batch_group_count = 1, dann für alle output_spatial_index in index_space(dim(result, output_spatial_dimensions...)): result[result_shape(:, output_spatial_index, :)] = dot_product, wobei:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Diese Funktion wird offenbar nicht verwendet. Wir planen daher, sie in Zukunft zu entfernen (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Wenn feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Wenn batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Führt für quantisierte Typen dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)) aus.

Für hybride quantisierte Typen wird hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor oder Tensor mit Tensor-Quantisierung (C1), (C10–C11), (C14), (C25), (C27–C28), (C31–C32), (C34)
(I2) rhs Tensor oder quantisierter Tensor (C1), (C14–C16), (C25), (C27–C29), (C31–C34)
(I3) window_strides 1-dimensionaler Tensor vom Typ si64 (C2-C3), (C25)
(I4) padding 2‑dimensionaler Tensor vom Typ si64 (C4), (C25)
(I5) lhs_dilation 1-dimensionaler Tensor vom Typ si64 (C5–C6), (C25)
(I6) rhs_dilation 1-dimensionaler Tensor vom Typ si64 (C7–C8), (C25)
(I7) window_reversal 1-dimensionaler Tensor vom Typ i1 (C9)
(I8) input_batch_dimension Konstante vom Typ si64 (C10), (C13), (C25)
(I9) input_feature_dimension Konstante vom Typ si64 (C11), (C13–C14)
(I10) input_spatial_dimensions 1-dimensionaler Tensor vom Typ si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension Konstante vom Typ si64 (C14), (C18)
(I12) kernel_output_feature_dimension Konstante vom Typ si64 (C15–C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions 1-dimensionaler Tensor vom Typ si64 (C17–C18), (C25)
(I14) output_batch_dimension Konstante vom Typ si64 (C20), (C25)
(I15) output_feature_dimension Konstante vom Typ si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions 1-dimensionaler Tensor vom Typ si64 (C19–C20), (C25)
(I17) feature_group_count Konstante vom Typ si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count Konstante vom Typ si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config Variadische Anzahl von Enums von DEFAULT, HIGH und HIGHEST (C24)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C25–C28), (C30), (C32–34)

Einschränkungen

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Angenommen, input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Angesichts von kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Angesichts von output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) wird so definiert:
    • dim(lhs, input_batch_dimension) / batch_group_count, wenn result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension), wenn result_dim = output_feature_dimension.
    • Andernfalls num_windows, wobei gilt:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Wenn für die Operation nicht quantisierte Tensoren verwendet werden:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Wenn für den Vorgang quantisierte Tensoren verwendet werden:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Wenn is_per_axis_quantized(rhs), dann quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Wenn is_per_axis_quantized(result), dann quantization_dimension(result) = output_feature_dimension.
    • Wenn is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Wenn is_per_tensor_quantized(rhs), dann is_per_tensor_quantized(result).
    • Wenn !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Beispiele

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strid<es = arra>yi64: 4, 4,
  paddi<n>g = dense<0 : ten>sor2x2xi64,
  lhs_dilati<on = arra>yi64: 2, 2,
  rhs_dilati<on = arra>yi64: 1, 1,
  window_revers<al = arrayi1: fa>lse, false,
  // In the StableHLO dialect, dimension numbers are encoded vi<a:
  // `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" a<re spatial dimensions.
  d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
  batch_group_count = 1 : i64,
  fea<ture_group_count >= 1 : i64,
 < precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

 Weitere Beispiele

Kosinus

Semantik

Führt eine elementweise Kosinusoperation für den operand-Tensor aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: cos aus IEEE-754.
  • Für komplexe Zahlen: komplexer Kosinus.
  • Für quantisierte Typen: dequantize_op_quantize(cosine, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]

 Weitere Beispiele

count_leading_zeros

Semantik

Führt eine elementweise Zählung der Anzahl der führenden Nullbits im operand-Tensor durch und erzeugt einen result-Tensor.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Ganzzahltyp (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Ganzzahltyp (C1)

Einschränkungen

  • (C1) type(operand) = type(result).

Beispiele

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]

 Weitere Beispiele

custom_call

Semantik

Kapselt einen implementierungsdefinierten Vorgang call_target_name, der inputs und called_computations verwendet und results erzeugt. has_side_effect, backend_config und api_version können verwendet werden, um zusätzliche implementierungsdefinierte Metadaten bereitzustellen.

Derzeit enthält dieser Vorgang eine ziemlich unorganisierte Sammlung von Metadaten, die die organische Entwicklung des entsprechenden Vorgangs im XLA-Compiler widerspiegelt. In Zukunft planen wir, diese Metadaten zu vereinheitlichen (#741).

Eingaben

Label Name Typ
(I1) inputs variable Anzahl von Werten
(I2) call_target_name Konstante vom Typ string
(I3) has_side_effect Konstante vom Typ i1
(I4) backend_config Konstante vom Typ string oder Attribut-Dictionary
(I5) api_version Konstante vom Typ si32
(I6) called_computations variadische Anzahl von Konstanten vom Typ string
(I7) output_operand_aliases Geben Sie die Aliasing-Teile in den Ausgaben und Operanden an.

Ausgaben

Name Typ
results variable Anzahl von Werten

(XLA-GPU-Unterstützung) Spezielle custom_call-Ziele

Es gibt drei spezielle call_target_name im Zusammenhang mit buffer-Typen: CreateBuffer erstellt eine nicht initialisierte buffer, Pin erstellt eine initialisierte buffer und Unpin gibt eine buffer frei und gibt den Inhalt der buffer zurück.

%uninitialized_buffer = "stablehlo.custom_call"() {
  call_target_name = "CreateBuffer",
  api_version> = 4 : <i32,
>} : () - memref4xf64

%initialized_buffer = "stablehlo.custom_call"(%init_value) {
  call_target_name = "Pin&quo<t;,
 > ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64

%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
  call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
  api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64

Alias

Bei einigen custom_call-Vorgängen muss ein Teil der Ausgaben und ein Teil der Operanden denselben Speicher verwenden. Das lässt sich mit output_operand_aliases ausdrücken. Eine Alias-Paar-Darstellung besteht aus einer Liste von Ausgabetupel-Indexen, die den Ausgabeteil darstellen, und einem operand_index zusammen mit einer Liste von Operanden-Tupel-Indexen, die den Operandenteil darstellen. Die Liste der Ausgabeargument- oder Operanden-Tupelindizes ist leer, wenn der entsprechende Typ kein tuple-Typ ist. Sie kann für einen beliebig geschachtelten Tupeltyp beliebig lang sein. Dies ähnelt der XLA-Aliasdarstellung.

Der Ausgabeteil und der Eingabeteil in einem Aliaspaar müssen denselben Typ haben. Bei custom_call-Vorgängen, die nicht CreateBuffer, Pin und Unpin aufrufen, kann ein buffer-Operand in höchstens einem Aliaspaar und eine buffer-Ausgabe in einem Aliaspaar vorkommen.

Beispiele

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64

%updated_buffer = "stablehlo.custom_call"(%buffer) {
  call_target_name = "Update",
  api_version = 4 : i32,
  output_operand_aliases< = [
    #stablehlo.output_operand_aliasoutput_tuple_indices = [],
      operand_ind>ex = 0,
     < oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64

Dividieren

Semantik

Führt die elementweise Division der Tensoren „dividend“ lhs und „divisor“ rhs aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Bei Ganzzahlen: Ganzzahldivision, bei der der algebraische Quotient ohne Bruchteil zurückgegeben wird.
  • Für Gleitkommazahlen: division aus IEEE-754.
  • Für komplexe Zahlen: komplexe Division.
  • Für quantisierte Typen:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder Tensor mit quantisierten Werten pro Tensor (C1)
(I2) rhs Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder Tensor mit quantisierten Werten pro Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

 Weitere Beispiele

dot_general

Semantik

Berechnet die Skalarprodukte zwischen Slices von lhs und Slices von rhs und gibt einen result-Tensor aus.

Formaler ausgedrückt: result[result_index] = dot_product, wobei gilt:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index wobei size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) und size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Führt für quantisierte Typen dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)) aus.

Für hybride quantisierte Typen wird hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs) ausgeführt.

precision_config steuert den Kompromiss zwischen Geschwindigkeit und Genauigkeit für Berechnungen auf Accelerator-Back-Ends. Dies kann einer der folgenden Werte sein (derzeit sind die Semantiken dieser Enum-Werte nicht genau festgelegt, aber wir planen, dies in #755 zu ändern):

  • DEFAULT: Schnellste Berechnung, aber ungenaueste Annäherung an die ursprüngliche Zahl.
  • HIGH: Langsamere Berechnung, aber genauere Annäherung an die ursprüngliche Zahl.
  • HIGHEST: Langsamste Berechnung, aber genaueste Annäherung an die Originalzahl.

Ein DotAlgorithm definiert die Haupteigenschaften des Algorithmus, der zur Implementierung der Punktoperation verwendet wird. Dadurch wird auch die Genauigkeit definiert. Wenn die Felder für das Algorithmusattribut festgelegt sind, muss precision_config DEFAULT sein. DotAlgorithms haben keinen Standardwert, da die Standardparameter implementierungsdefiniert sind. Daher können alle Felder des Punktalgorithmus auf None gesetzt werden, um einen leeren Punktalgorithmus anzugeben, der stattdessen den Wert precision_config verwendet.

DotAlgorithm-Felder:

  • lhs_precision_type und rhs_precision_type: Die Genauigkeiten, auf die die linke und rechte Seite des Vorgangs gerundet werden. Die Genauigkeitstypen sind unabhängig von den Speichertypen der Ein- und Ausgabe.
  • accumulation_type die für die Akkumulierung verwendete Genauigkeit.
  • lhs_component_count, rhs_component_count und num_primitive_operations werden angewendet, wenn wir einen Algorithmus verwenden, der die linke und/oder rechte Seite in mehrere Komponenten zerlegt und mehrere „primitive“ Punktoperationen für diese Werte ausführt, in der Regel um eine höhere Genauigkeit zu emulieren (z. B. bfloat16-Datentyp für künstliche Intelligenz für Berechnungen mit höherer Genauigkeit nutzen: bf16_6x tf32_3x usw.). Bei Algorithmen ohne Dekomposition sollten diese Werte auf 1 gesetzt werden.
  • allow_imprecise_accumulation, um anzugeben, ob die Akkumulierung in niedrigerer Genauigkeit für einige Schritte zulässig ist (z.B. CUBLASLT_MATMUL_DESC_FAST_ACCUM).

Beispielattribute für DotAlgorithm:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

Es liegt an den Implementierungen, welche Kombinationen unterstützt werden. Im Allgemeinen kann nicht garantiert werden, dass jeder Algorithmus von jedem Beschleunigertyp unterstützt wird. Wenn ein bestimmter Algorithmus nicht unterstützt wird, sollte ein Fehler ausgegeben werden, anstatt auf eine Alternative zurückzugreifen. Die StableHLO-Überprüfung erfolgt nach dem Best-Effort-Prinzip und verhindert die Verwendung von Algorithmen, die auf keiner Hardware unterstützt werden.

Einige unterstützte Algorithmuswerte finden Sie unter xla_data.proto > Algorithm. Im Ticket 2483 wird der Plan beschrieben, ein zentrales Dokument zu den unterstützten Algorithmen nach Backend zu erstellen.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor oder Tensor mit Tensor-Quantisierung (C5–C6), (C9–C10), (C12–C14), (C17–C18), (C20)
(I2) rhs Tensor oder quantisierter Tensor (C7–C10), (C12–C20)
(I3) lhs_batching_dimensions 1-dimensionaler Tensor vom Typ si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions 1-dimensionaler Tensor vom Typ si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions 1-dimensionaler Tensor vom Typ si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions 1-dimensionaler Tensor vom Typ si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config Variadische Anzahl von Enums von DEFAULT, HIGH und HIGHEST (C11), (C21)
(I8) lhs_precision_type FloatType oder TensorFloat32 (C21)
(I9) rhs_precision_type FloatType oder TensorFloat32 (C21)
(I10) accumulation_type FloatType oder TensorFloat32 (C21)
(I11) lhs_component_count Konstante vom Typ si32 (C21), (C22)
(I12) rhs_component_count Konstante vom Typ si32 (C21), (C23)
(I13) num_primitive_operations Konstante vom Typ si32 (C21), (C24)
(I14) allow_imprecise_accumulation Konstante vom Typ bool (C21)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C12), (C14), (C18–C20)

Einschränkungen

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Wenn für die Operation nicht quantisierte Tensoren verwendet werden:
    • (C13) element_type(lhs) = element_type(rhs).
  • Wenn für den Vorgang quantisierte Tensoren verwendet werden:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Wenn is_per_axis_quantized(rhs), dann quantization_dimension(rhs) nicht in rhs_contracting_dimensions.
    • Wenn is_quantized(lhs):
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) Wenn is_per_tensor_quantized(rhs), dann is_per_tensor_quantized(result).
    • Wenn !is_quantized(lhs):
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • Wenn !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

Beispiele

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #sta<blehlo.dot
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimension>s = [1]
  ,
  precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
  algorithm = #stablehlo.dot<_algorithm
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation >= false
  
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

 Weitere Beispiele

dynamic_broadcast_in_dim

Semantik

Diese Operation ist funktional identisch mit der Operation broadcast_in_dim, aber die Ergebnisform wird dynamisch über output_dimensions angegeben.

Der Vorgang akzeptiert auch die optionalen Attribute known_expanding_dimensions und known_nonexpanding_dimensions, um statisches Wissen über das Expansionsverhalten von Dimensionen auszudrücken. Wenn nichts angegeben ist, wird davon ausgegangen, dass alle Dimensionen möglicherweise erweitert werden.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder quantisierter Tensor (C1–C2), (C5–C6), (C9)
(I2) output_dimensions 1‑dimensionaler Tensor vom Typ „Ganzzahl“ (C7)
(I3) broadcast_dimensions 1‑dimensionaler konstanter Tensor vom Typ „Ganzzahl“ (C2-C6)
(I4) known_expanding_dimensions 1‑dimensionaler konstanter Tensor vom Typ „Ganzzahl“ (C8–C9)
(I5) known_nonexpanding_dimensions 1‑dimensionaler konstanter Tensor vom Typ „Ganzzahl“ (C8–C9)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C1), (C3), (C5–C7)

Einschränkungen

  • (C1) element_type(result) wird so berechnet:
    • element_type(operand), wenn !is_per_axis_quantized(operand).
    • element_type(operand), mit der Ausnahme, dass quantization_dimension(operand), scales(operand) und zero_points(operand) von quantization_dimension(result), scales(result) und zero_points(result) abweichen können.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Für alle d in axes(operand):
    • dim(operand, d) = 1 oder
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Wenn is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Bei dim(operand, quantization_dimension(operand)) = 1, dann scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_nonexpanding_dimensions < rank(operand).

Beispiele

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensio<ns = arra>yi64: 2, 1,
  known_expanding_dimensio<ns = a>rrayi64: 0,
  known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 Weitere Beispiele

dynamic_conv

Semantik

Diese Operation ist funktional identisch mit der Faltungsoperation, aber das Padding wird dynamisch über padding angegeben.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor oder Tensor mit Tensor-Quantisierung (C1), (C10–C11), (C14), (C25), (C26–C27), (C30–C31), (C33)
(I2) rhs Tensor oder quantisierter Tensor (C1), (C14–C16), (C26–C28), (C30–C33)
(I3) padding 2‑dimensionaler Tensor vom Typ „Ganzzahl“ (C4)
(I4) window_strides 1-dimensionaler Tensor vom Typ si64 (C2-C3)
(I5) lhs_dilation 1-dimensionaler Tensor vom Typ si64 (C5–C6)
(I6) rhs_dilation 1-dimensionaler Tensor vom Typ si64 (C7–C8)
(I7) window_reversal 1-dimensionaler Tensor vom Typ i1 (C9)
(I8) input_batch_dimension Konstante vom Typ si64 (C10), (C13)
(I9) input_feature_dimension Konstante vom Typ si64 (C11), (C13–C14)
(I10) input_spatial_dimensions 1-dimensionaler Tensor vom Typ si64 (C12), (C13)
(I11) kernel_input_feature_dimension Konstante vom Typ si64 (C14), (C18)
(I12) kernel_output_feature_dimension Konstante vom Typ si64 (C15–C16), (C18), (C28)
(I13) kernel_spatial_dimensions 1-dimensionaler Tensor vom Typ si64 (C17–C18)
(I14) output_batch_dimension Konstante vom Typ si64 (C20)
(I15) output_feature_dimension Konstante vom Typ si64 (C20), (C29)
(I16) output_spatial_dimensions 1-dimensionaler Tensor vom Typ si64 (C19–C20)
(I17) feature_group_count Konstante vom Typ si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count Konstante vom Typ si64 (C10), (C15), (C22), (C23)
(I19) precision_config Variadische Anzahl von Enums von DEFAULT, HIGH und HIGHEST (C24)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C25–C27), (C29), (C31–C33)

Einschränkungen

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Angenommen, input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Angesichts von kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Angesichts von output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) wird so definiert:
    • dim(lhs, input_batch_dimension) / batch_group_count, wenn result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension), wenn result_dim = output_feature_dimension.
    • Andernfalls num_windows, wobei gilt:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Wenn für die Operation nicht quantisierte Tensoren verwendet werden:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Wenn für den Vorgang quantisierte Tensoren verwendet werden:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Wenn is_per_axis_quantized(rhs), dann quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Wenn is_per_axis_quantized(result), dann quantization_dimension(result) = output_feature_dimension.
    • Wenn is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Wenn is_per_tensor_quantized(rhs), dann is_per_tensor_quantized(result).
    • Wenn !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Beispiele

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strid<es = arra>yi64: 4, 4,
  lhs_dilati<on = arra>yi64: 2, 2,
  rhs_dilati<on = arra>yi64: 1, 1,
  window_revers<al = arrayi1: fa>lse, false,
  dimension_numbers = #stab<lehlo.convraw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions => [1, 2]
  ,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

 Weitere Beispiele

dynamic_gather

Semantik

Dieser Vorgang ist funktional identisch mit dem gather-Vorgang, wobei slice_sizes dynamisch als Wert angegeben wird.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1), (C7), (C10–C12), (C14)
(I2) start_indices Tensor vom Ganzzahltyp (C2), (C3), (C13)
(I3) slice_sizes 1‑dimensionaler Tensor vom Typ „Ganzzahl“ (C8), (C11–C13)
(I4) offset_dims 1-dimensionaler Tensor vom Typ si64 (C1), (C4–C5), (C13)
(I5) collapsed_slice_dims 1-dimensionaler Tensor vom Typ si64 (C1), (C6–C8), (C13)
(I6) start_index_map 1-dimensionaler Tensor vom Typ si64 (C3), (C9), (C10)
(I7) index_vector_dim Konstante vom Typ si64 (C2), (C3), (C13)
(I8) indices_are_sorted Konstante vom Typ i1

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C5), (C13–C14)

Einschränkungen

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) wobei:
    • batch_dim_sizes = shape(start_indices), mit der Ausnahme, dass die Dimensionsgröße von start_indices, die index_vector_dim entspricht, nicht enthalten ist.
    • offset_dim_sizes = shape(slice_sizes), mit der Ausnahme, dass die Dimensionsgrößen in slice_sizes, die collapsed_slice_dims entsprechen, nicht enthalten sind.
    • Mit combine wird batch_dim_sizes auf Achsen platziert, die batch_dims entsprechen, und offset_dim_sizes auf Achsen, die offset_dims entsprechen.
  • (C14) element_type(operand) = element_type(result).

Beispiele

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stable<hlo.gather
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vect>or_dim = 2,
  indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

 Weitere Beispiele

dynamic_iota

Semantik

Diese Operation ist funktional identisch mit der iota-Operation, aber die Ergebnisform wird dynamisch über output_shape angegeben.

Eingaben

Label Name Typ Einschränkungen
(I1) output_shape 1‑dimensionaler Tensor vom Typ „Ganzzahl“ (C1), (C2)
(I2) iota_dimension si64 (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C2)

Einschränkungen

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Beispiele

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

 Weitere Beispiele

dynamic_pad

Semantik

Dieser Vorgang ist funktional identisch mit dem Vorgang pad, aber edge_padding_low, edge_padding_high und interior_padding werden dynamisch als Werte angegeben.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1), (C2), (C4)
(I2) padding_value 0-dimensionaler Tensor oder Tensor mit Tensor-Quantisierung (C1)
(I3) edge_padding_low 1‑dimensionaler Tensor vom Typ „Ganzzahl“ (C1), (C4)
(I4) edge_padding_high 1‑dimensionaler Tensor vom Typ „Ganzzahl“ (C1), (C4)
(I5) interior_padding 1‑dimensionaler Tensor vom Typ „Ganzzahl“ (C2-C4)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C3–C6)

Einschränkungen

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Beispiele

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 Weitere Beispiele

dynamic_reshape

Semantik

Dieser Vorgang ist funktional identisch mit dem Vorgang reshape, aber die Ergebnisform wird dynamisch über output_shape angegeben.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder quantisierter Tensor (C1–C3)
(I2) output_shape 1‑dimensionaler Tensor vom Typ „Ganzzahl“ (C4)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C1–C4)

Einschränkungen

  • (C1) element_type(result) wird so berechnet:
    • element_type(operand), wenn !is_per_axis_quantized(operand).
    • element_type(operand), mit der Ausnahme, dass sich quantization_dimension(operand) und quantization_dimension(result) unterscheiden können.
  • (C2) size(operand) = size(result).
  • (C3) Wenn is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Beispiele

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]

 Weitere Beispiele

dynamic_slice

Semantik

Extrahiert einen Ausschnitt aus operand mithilfe von dynamisch berechneten Startindizes und gibt einen result-Tensor aus. start_indices enthält die Startindexe des Slices für jede Dimension, die möglicherweise angepasst wird, und slice_sizes enthält die Größen des Slices für jede Dimension. Formeller ausgedrückt: result[result_index] = operand[operand_index], wobei gilt:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1), (C2), (C4)
(I2) start_indices variadische Anzahl von 0-dimensionalen Tensoren vom Typ „integer“ (C2), (C3)
(I3) slice_sizes 1-dimensionaler Tensor vom Typ si64 (C2), (C4), (C5)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1), (C5)

Einschränkungen

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Beispiele

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

 Weitere Beispiele

dynamic_update_slice

Semantik

Erstellt einen result-Tensor, der dem operand-Tensor entspricht, mit der Ausnahme, dass das Segment, das bei start_indices beginnt, mit den Werten in update aktualisiert wird. Formaler wird result[result_index] so definiert:

  • update[update_index] wenn 0 <= update_index < shape(update), wobei:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • Andernfalls operand[result_index].

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1–C4), (C6)
(I2) update Tensor oder Tensor mit Tensor-Quantisierung (C2), (C3), (C6)
(I3) start_indices variadische Anzahl von 0-dimensionalen Tensoren vom Typ „integer“ (C4), (C5)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1)

Einschränkungen

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Beispiele

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
 < : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

 Weitere Beispiele

Exponentialfunktionen

Semantik

Führt eine elementweise Exponentialoperation für den operand-Tensor aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: exp aus IEEE-754.
  • Für komplexe Zahlen: komplexes Exponential.
  • Für quantisierte Typen: dequantize_op_quantize(exponential, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

 Weitere Beispiele

exponential_minus_one

Semantik

Führt eine elementweise Operation „Exponential minus eins“ für den Tensor operand aus und gibt einen Tensor result zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: expm1 aus IEEE-754.
  • Für komplexe Zahlen: komplexes Exponential minus 1.
  • Für quantisierte Typen: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]

 Weitere Beispiele

fft

Semantik

Führt die Vorwärts- und Rücktransformation der Fourier-Transformation für reelle und komplexe Ein-/Ausgaben aus.

fft_type ist einer der folgenden Werte:

  • FFT: Leitet die komplexe FFT weiter.
  • IFFT: Inverse komplex-zu-komplex-FFT.
  • RFFT: Forward Real-to-Complex FFT.
  • IRFFT: Inverse Real-to-Complex-FFT (d.h. sie nimmt komplexe Zahlen entgegen und gibt reelle Zahlen zurück).

Formaler ausgedrückt: Angenommen, die Funktion fft, die eindimensionale Tensoren komplexer Typen als Eingabe verwendet, eindimensionale Tensoren desselben Typs als Ausgabe erzeugt und die diskrete Fourier-Transformation berechnet:

Bei fft_type = FFT ist result das Endergebnis einer Reihe von L Berechnungen, wobei L = size(fft_length). Zum Beispiel für L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Außerdem wird die Funktion ifft mit derselben Typsignatur angegeben, die die Umkehrung von fft berechnet:

Bei fft_type = IFFT wird result als das Inverse der Berechnungen für fft_type = FFT definiert. Zum Beispiel für L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Außerdem wird die Funktion rfft betrachtet, die 1-dimensionale Tensoren von Gleitkommatypen akzeptiert, 1-dimensionale Tensoren von komplexen Typen mit derselben Gleitkommasemantik erzeugt und so funktioniert:

  • rfft(real_operand) = truncated_result, wobei
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

Wenn die diskrete Fourier-Transformation für reelle Operanden berechnet wird, definieren die ersten N/2 + 1 Elemente des Ergebnisses den Rest des Ergebnisses eindeutig. Das Ergebnis von rfft wird daher gekürzt, um die Berechnung redundanter Elemente zu vermeiden.

Bei fft_type = RFFT ist result das Endergebnis einer Reihe von L Berechnungen, wobei L = size(fft_length). Zum Beispiel für L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Angenommen, es gibt die Funktion irfft mit derselben Typsignatur, die die Inverse von rfft berechnet:

Bei fft_type = IRFFT wird result als das Inverse der Berechnungen für fft_type = RFFT definiert. Zum Beispiel für L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ (C1), (C2), (C4), (C5)
(I2) fft_type enum von FFT, IFFT, RFFT und IRFFT (C2), (C5)
(I3) fft_length 1-dimensionaler Tensor vom Typ si64 (C1), (C3), (C4)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ (C2), (C4), (C5)

Einschränkungen

  • (C1) size(fft_length) <= rank(operand).
  • (C2) Die Beziehung zwischen den Elementtypen operand und result variiert:
    • Wenn fft_type = FFT, element_type(operand) und element_type(result) denselben komplexen Typ haben.
    • Wenn fft_type = IFFT, element_type(operand) und element_type(result) denselben komplexen Typ haben.
    • Wenn fft_type = RFFT, element_type(operand) ein Gleitkommatyp und element_type(result) ein komplexer Typ mit derselben Gleitkommasemantik ist.
    • Wenn fft_type = IRFFT, element_type(operand) ein komplexer Typ und element_type(result) ein Gleitkommatyp mit derselben Gleitkommasemantik ist.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Wenn unter operand und result ein Tensor real vom Typ „Gleitkomma“ vorhanden ist, dann shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand), mit Ausnahme von:
    • Wenn fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Wenn fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Beispiele

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = <#stablehloff>t_type FFT,
  fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

Boden

Semantik

Führt die elementweise Abrundung des operand-Tensors aus und gibt einen result-Tensor aus. Implementiert den roundToIntegralTowardNegative-Vorgang aus der IEEE-754-Spezifikation. Für quantisierte Typen wird dequantize_op_quantize(floor, operand, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

 Weitere Beispiele

gather

Semantik

Sammelt Slices aus dem operand-Tensor anhand der in start_indices angegebenen Offsets und gibt einen result-Tensor aus.

Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result Elementen in operand zugeordnet werden. Im Diagramm werden einige Beispiel-result Indizes ausgewählt und detailliert erläutert, welchen operand-Indizes sie entsprechen.

gather

Formaler ausgedrückt: result[result_index] = operand[operand_index], wobei gilt:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index wird so definiert:
    • start_indices[bi0, ..., :, ..., biN], wobei bi einzelne Elemente in batch_index sind und : am Index index_vector_dim eingefügt wird, wenn index_vector_dim < rank(start_indices).
    • Andernfalls [start_indices[batch_index]].
  • Für d_operand in axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) bei d_operand = start_index_map[d_start].
    • Andernfalls full_start_index[d_operand] = 0.
  • Für d_operand in axes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] bei d_operand = operand_batching_dims[i_batching] und d_start = start_indices_batching_dims[i_batching].
    • Andernfalls full_batching_index[d_operand] = 0.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], wobei oi einzelne Elemente in offset_index sind und 0 an den Indexen von collapsed_slice_dims und operand_batching_dims eingefügt wird.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

Wenn indices_are_sorted gleich true ist, kann die Implementierung davon ausgehen, dass start_indices in Bezug auf start_index_map sortiert ist. Andernfalls ist das Verhalten nicht definiert. Formaler ausgedrückt: Für alle i1 < i2 aus indices(result) gilt full_start_index(i1) <= full_start_index(i2).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1), (C8), (C11), (C17), (C19–C21), (C23)
(I2) start_indices Tensor vom Ganzzahltyp (C2-C3), (C14), (C17), (C22)
(I3) offset_dims 1-dimensionaler Tensor vom Typ si64 (C1), (C4–C5), (C22)
(I4) collapsed_slice_dims 1-dimensionaler Tensor vom Typ si64 (C1), (C6–C9), (C22)
(I5) operand_batching_dims 1-dimensionaler Tensor vom Typ si64 (C1), (C6), (C10–C12), (C16–C18), (C22)
(I6) start_indices_batching_dims 1-dimensionaler Tensor vom Typ si64 (C13–C17)
(I7) start_index_map 1-dimensionaler Tensor vom Typ si64 (C3), (C18–C19)
(I8) index_vector_dim Konstante vom Typ si64 (C2-C3), (C15), (C22)
(I9) slice_sizes 1-dimensionaler Tensor vom Typ si64 (C9), (C12), (C20–C22)
(I10) indices_are_sorted Konstante vom Typ i1

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C5), (C22–C23)

Einschränkungen

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) wobei:
    • batch_dim_sizes = shape(start_indices), mit der Ausnahme, dass die Dimensionsgröße von start_indices, die index_vector_dim entspricht, nicht enthalten ist.
    • offset_dim_sizes = slice_sizes, mit der Ausnahme, dass die Dimensionsgrößen in slice_sizes, die collapsed_slice_dims und operand_batching_dims entsprechen, nicht enthalten sind.
    • Mit combine wird batch_dim_sizes auf Achsen platziert, die batch_dims entsprechen, und offset_dim_sizes auf Achsen, die offset_dims entsprechen.
  • (C23) element_type(operand) = element_type(result).

Beispiele

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stable<hlo.gather
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vect>or_dim = 3,
  slice_siz<es = arrayi64: >1, 1, 2, 2,
  indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

 Weitere Beispiele

get_dimension_size

Semantik

Gibt die Größe des angegebenen dimension des operand zurück. Formaler ausgedrückt: result = dim(operand, dimension). Die Semantik bezieht sich nur auf die Formkomponente des Typs. Der Elementtyp kann ein beliebiger Typ sein.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder quantisierter Tensor (C1)
(I2) dimension Konstante vom Typ si64 (C1)

Ausgaben

Name Typ
result 0-dimensionaler Tensor vom Typ si32

Einschränkungen

  • (C1) 0 <= dimension < rank(operand).

Beispiele

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3

 Weitere Beispiele

get_tuple_element

Semantik

Extrahiert das Element an der Position index des operand-Tupels und gibt ein result aus. Formaler ausgedrückt: result = operand[index].

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tupel (C1), (C2)
(I2) index Konstante vom Typ si32 (C1), (C2)

Ausgaben

Name Typ Einschränkungen
result beliebiger Wert (C2)

Einschränkungen

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Beispiele

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]

 Weitere Beispiele

wenn

Semantik

Gibt die Ausgabe zurück, die durch die Ausführung genau einer Funktion aus true_branch oder false_branch generiert wird, je nach Wert von pred. Formaler ausgedrückt: result = pred ? true_branch() : false_branch().

Eingaben

Label Name Typ Einschränkungen
(I1) pred 0-dimensionaler Tensor vom Typ i1
(I2) true_branch Funktion (C1–C3)
(I3) false_branch Funktion (C1), (C2)

Ausgaben

Name Typ Einschränkungen
results Variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C3)

Einschränkungen

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Beispiele

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
  "stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10

 Weitere Beispiele

imag

Semantik

Extrahiert den imaginären Teil elementweise aus operand und gibt einen result-Tensor aus. Formal ausgedrückt: Für jedes Element x gilt: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ (C1), (C2)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkommatyp (C1), (C2)

Einschränkungen

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) wird so definiert:
    • complex_element_type(element_type(operand)), wenn is_complex(operand).
    • Andernfalls element_type(operand).

Beispiele

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]

 Weitere Beispiele

Infeed

Semantik

Liest Daten aus dem Infeed und erstellt results.

Die Semantik von infeed_config ist implementierungsdefiniert.

results besteht aus Nutzlastwerten, die zuerst kommen, und einem Token, das zuletzt kommt. In Zukunft planen wir, die Nutzlast und das Token in zwei separate Ausgaben aufzuteilen, um die Übersichtlichkeit zu verbessern (#670).

Eingaben

Label Name Typ
(I1) token token
(I2) infeed_config Konstante vom Typ string

Ausgaben

Name Typ Einschränkungen
results Variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C1–C3)

Einschränkungen

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) oder is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Beispiele

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

 Weitere Beispiele

iota

Semantik

Füllt einen output-Tensor mit Werten in aufsteigender Reihenfolge ab null entlang der iota_dimension-Dimension. Formeller ausgedrückt:

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Eingaben

Label Name Typ Einschränkungen
(I1) iota_dimension si64 (C1)

Ausgaben

Name Typ Einschränkungen
output Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) 0 <= iota_dimension < rank(output).

Beispiele

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

 Weitere Beispiele

is_finite

Semantik

Führt eine elementweise Prüfung durch, ob der Wert in x endlich ist (d.h. weder +Inf, -Inf noch NaN), und gibt einen y-Tensor aus. Implementiert den isFinite-Vorgang aus der IEEE-754-Spezifikation. Bei quantisierten Typen ist das Ergebnis immer true.

Eingaben

Label Name Typ Einschränkungen
(I1) x Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
y Tensor vom Typ „Boolesch“ (C1)

Einschränkungen

  • (C1) shape(x) = shape(y).

Beispiele

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]

 Weitere Beispiele

log

Semantik

Führt eine elementweise Logarithmusoperation für den operand-Tensor aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: log aus IEEE-754.
  • Für komplexe Zahlen: komplexer Logarithmus.
  • Für quantisierte Typen: dequantize_op_quantize(log, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

 Weitere Beispiele

log_plus_one

Semantik

Führt eine elementweise Logarithmus-plus-1-Operation für den operand-Tensor aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: logp1 aus IEEE-754.
  • Für komplexe Zahlen: complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))
  • Für quantisierte Typen: dequantize_op_quantize(log_plus_one, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

 Weitere Beispiele

logistisch

Semantik

Führt eine elementweise logistische Operation für den operand-Tensor aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: division(1, addition(1, exp(-x))) aus IEEE-754.
  • Für komplexe Zahlen: „complex logistic“.
  • Für quantisierte Typen: dequantize_op_quantize(logistic, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

 Weitere Beispiele

Karte

Semantik

Wendet eine Map-Funktion computation auf inputs entlang der dimensions an und gibt einen result-Tensor zurück.

Formaler ausgedrückt: result[result_index] = computation(inputs...[result_index]).

Eingaben

Label Name Typ Einschränkungen
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C4)
(I2) dimensions 1-dimensionaler Tensor vom Typ si64 (C3)
(I3) computation Funktion (C4)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1), (C4)

Einschränkungen

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation hat den Typ (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, wobei Ei = element_type(inputs[i]) und E' = element_type(result).

Beispiele

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
    stablehlo.return %<0 :> tensori64
}) {
  dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]

 Weitere Beispiele

Maximum

Semantik

Führt eine elementweise Max-Operation für die Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für boolesche Werte: logisches ODER.
  • Für Ganzzahlen: maximaler Ganzzahlwert.
  • Für Gleitkommazahlen: maximum aus IEEE-754.
  • Für komplexe Zahlen: lexikografisches Maximum für das (real, imaginary)-Paar. Die Festlegung einer Reihenfolge für komplexe Zahlen birgt überraschende Semantiken. Daher planen wir, die Unterstützung für komplexe Zahlen für diesen Vorgang in Zukunft zu entfernen (#560).
  • Für quantisierte Typen:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor oder Tensor mit Tensor-Quantisierung (C1)
(I2) rhs Tensor oder Tensor mit Tensor-Quantisierung (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1)

Einschränkungen

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]

 Weitere Beispiele

Minimum

Semantik

Führt eine elementweise MIN-Operation für die Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für boolesche Werte: logisches UND.
  • Für Ganzzahlen: Mindestwert für Ganzzahlen.
  • Für Gleitkommazahlen: minimum aus IEEE-754.
  • Für komplexe Zahlen: lexikografisches Minimum für das (real, imaginary)-Paar. Die Festlegung einer Reihenfolge für komplexe Zahlen birgt überraschende Semantiken. Daher planen wir, die Unterstützung für komplexe Zahlen für diesen Vorgang in Zukunft zu entfernen (#560).
  • Für quantisierte Typen:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor oder Tensor mit Tensor-Quantisierung (C1)
(I2) rhs Tensor oder Tensor mit Tensor-Quantisierung (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1)

Einschränkungen

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]

 Weitere Beispiele

multiplizieren

Semantik

Führt das elementweise Produkt von zwei Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für boolesche Werte: logisches UND.
  • Bei Ganzzahlen: Ganzzahlmultiplikation.
  • Für Gleitkommazahlen: multiplication aus IEEE-754.
  • Für komplexe Zahlen: komplexe Multiplikation.
  • Für quantisierte Typen:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor oder Tensor mit Tensor-Quantisierung (C1)
(I2) rhs Tensor oder Tensor mit Tensor-Quantisierung (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]

 Weitere Beispiele

negate

Semantik

Führt die elementweise Negation des operand-Tensors aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Bei vorzeichenbehafteten Ganzzahlen: Negation von Ganzzahlen.
  • Bei vorzeichenlosen Ganzzahlen: Bitcast zu vorzeichenbehafteter Ganzzahl, Ganzzahlnegation, Bitcast zurück zu vorzeichenloser Ganzzahl.
  • Für Gleitkommazahlen: negate aus IEEE-754.
  • Für komplexe Zahlen: komplexe Negation.
  • Für quantisierte Typen: dequantize_op_quantize(negate, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]

 Weitere Beispiele

nicht

Semantik

Führt eine elementweise NOT-Operation für den Tensor operand aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für boolesche Werte: logisches NICHT.
  • Für Ganzzahlen: bitweises NOT.

Argumente

Name Typ Einschränkungen
operand Tensor vom Typ „Boolesch“ oder „Ganzzahl“ (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Boolesch“ oder „Ganzzahl“ (C1)

Einschränkungen

  • (C1) type(operand) = type(result).

Beispiele

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]

 Weitere Beispiele

optimization_barrier

Semantik

Sorgt dafür, dass die Vorgänge, die das operand erzeugen, vor allen Vorgängen ausgeführt werden, die vom result abhängen, und verhindert, dass Compiler-Transformationen Vorgänge über die Barriere hinweg verschieben. Ansonsten ist die Operation eine Identität, d.h. result = operand.

Argumente

Name Typ Einschränkungen
operand variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens pro Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens pro Tensor (C1)

Einschränkungen

  • (C1) type(operand...) = type(result...).

Beispiele

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0

 Weitere Beispiele

oder

Semantik

Führt eine elementweise ODER-Operation für zwei Tensoren lhs und rhs aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für boolesche Werte: logisches ODER.
  • Für Ganzzahlen: bitweises OR.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Typ „Ganzzahl“ oder „Boolescher Wert“ (C1)
(I2) rhs Tensor vom Typ „Ganzzahl“ oder „Boolescher Wert“ (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl“ oder „Boolescher Wert“ (C1)

Einschränkungen

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]

 Weitere Beispiele

Ausgang

Semantik

Schreibt inputs in den Outfeed und erzeugt ein result-Token.

Die Semantik von outfeed_config ist implementierungsdefiniert.

Eingaben

Label Name Typ
(I1) inputs variadische Anzahl von Tensoren oder quantisierten Tensoren
(I2) token token
(I3) outfeed_config Konstante vom Typ string

Ausgaben

Name Typ
result token

Beispiele

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token

 Weitere Beispiele

pad

Semantik

Erweitert operand durch Auffüllen des Tensors sowie zwischen den Elementen des Tensors mit dem angegebenen padding_value.

edge_padding_low und edge_padding_high geben die Menge an Padding an, die jeweils am unteren Ende (neben Index 0) und am oberen Ende (neben dem höchsten Index) jeder Dimension hinzugefügt wird. Der Betrag des Padding kann negativ sein. Der absolute Wert des negativen Padding gibt die Anzahl der Elemente an, die aus der angegebenen Dimension entfernt werden sollen.

interior_padding gibt den Abstand zwischen zwei Elementen in jeder Dimension an. Er darf nicht negativ sein. Die innere Auffüllung erfolgt vor der Auffüllung am Rand. Bei einer negativen Auffüllung am Rand werden Elemente aus dem inneren aufgefüllten Operanden entfernt.

Formaler wird result[result_index] so definiert:

  • operand[operand_index], wenn result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • Andernfalls padding_value.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1), (C2), (C4)
(I2) padding_value 0-dimensionaler Tensor oder Tensor mit Tensor-Quantisierung (C1)
(I3) edge_padding_low 1-dimensionaler Tensor vom Typ si64 (C1), (C4)
(I4) edge_padding_high 1-dimensionaler Tensor vom Typ si64 (C1), (C4)
(I5) interior_padding 1-dimensionaler Tensor vom Typ si64 (C2-C4)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C3–C6)

Einschränkungen

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Beispiele

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_l<ow = arra>yi64: 0, 1,
  edge_padding_hi<gh = arra>yi64: 2, 1,
  interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

 Weitere Beispiele

partition_id

Semantik

Erstellt partition_id des aktuellen Prozesses.

Ausgaben

Name Typ
result 0-dimensionaler Tensor vom Typ ui32

Beispiele

%result = "stablehlo.partition_id">;() : (<) - >tensorui32

 Weitere Beispiele

popcnt

Semantik

Führt eine elementweise Zählung der Anzahl der im operand-Tensor festgelegten Bits durch und gibt einen result-Tensor aus.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Ganzzahltyp (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Ganzzahltyp (C1)

Einschränkungen

  • (C1) type(operand) = type(result).

Beispiele

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]

 Weitere Beispiele

Leistung

Semantik

Führt die elementweise Potenzierung des lhs-Tensors mit dem rhs-Tensor aus und gibt einen result-Tensor aus. Je nach Elementtyp wird Folgendes ausgeführt:

  • Bei Ganzzahlen: Ganzzahlexponentiation.
  • Für Gleitkommazahlen: pow aus IEEE-754.
  • Für komplexe Zahlen: komplexe Potenzierung.
  • Für quantisierte Typen: dequantize_op_quantize(power, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)
(I2) rhs Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

 Weitere Beispiele

real

Semantik

Extrahiert den Realteil elementweise aus operand und gibt einen result-Tensor zurück. Formal ausgedrückt: Für jedes Element x gilt: real(x) = is_complex(x) ? real_part(x) : x.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ (C1), (C2)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkommatyp (C1), (C2)

Einschränkungen

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) wird so definiert:
    • complex_element_type(element_type(operand)), wenn is_complex(operand).
    • Andernfalls element_type(operand).

Beispiele

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]

 Weitere Beispiele

recv

Semantik

Empfängt Daten von einem Channel mit channel_id und gibt results aus.

Wenn is_host_transfer true ist, werden Daten vom Host übertragen. Andernfalls werden Daten von einem anderen Gerät basierend auf den Werten von source_target_pairs übertragen. Dieses Flag dupliziert die Informationen in channel_type. Daher planen wir, in Zukunft nur eines der beiden Flags beizubehalten (#666). Wenn is_host_transfer = false und source_target_pairs None oder leer ist, gilt dies als undefiniertes Verhalten.

results besteht aus Nutzlastwerten, die zuerst kommen, und einem Token, das zuletzt kommt. In Zukunft planen wir, die Nutzlast und das Token in zwei separate Ausgaben aufzuteilen, um die Übersichtlichkeit zu verbessern (#670).

Eingaben

Label Name Typ Einschränkungen
(I1) token token
(I2) channel_id Konstante vom Typ si64
(I3) channel_type Aufzählung von DEVICE_TO_DEVICE und DEVICE_TO_HOST (C5)
(I4) is_host_transfer Konstante vom Typ i1 (C5–C6)
(I5) source_target_pairs 2‑dimensionaler Tensor vom Typ si64 (C1–C4), (C6)

Ausgaben

Name Typ Einschränkungen
results Variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C2-C4)

Einschränkungen

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, wobei N so definiert ist:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_partitions, wenn cross_partition verwendet wird.
  • (C5) channel_type wird so definiert:
    • DEVICE_TO_HOST wenn is_host_transfer = true,
    • Andernfalls DEVICE_TO_DEVICE.

Beispiele

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
  is_host_transfer = false,
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)

 Weitere Beispiele

reduce

Semantik

Wendet eine Reduktionsfunktion body auf inputs und init_values entlang der dimensions an und gibt results-Tensoren aus.

Die Reihenfolge der Reduzierungen ist implementierungsabhängig. Das bedeutet, dass body und init_values ein Monoid bilden müssen, damit die Operation für alle Eingaben in allen Implementierungen dieselben Ergebnisse liefert. Diese Bedingung gilt jedoch nicht für viele beliebte Reduktionen. Die Gleitkomma-Addition für body und Null für init_values bilden beispielsweise kein Monoid, da die Gleitkomma-Addition nicht assoziativ ist.

Formaler ausgedrückt: results...[j0, ..., jR-1] = reduce(input_slices_converted), wobei gilt:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], wobei : an der Stelle dimensions eingefügt werden.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) für einen binären Baum schedule, wobei:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule ist ein implementierungsdefinierter vollständiger binärer Baum, dessen Inorder-Traversierung aus Folgendem besteht:
    • input_slices_converted...[index]-Werte für alle index in index_space(input_slices_converted) in aufsteigender lexikografischer Reihenfolge von index.
    • An implementierungsdefinierten Positionen mit einer implementierungsdefinierten Anzahl von init_values_converted.

Eingaben

Label Name Typ Einschränkungen
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C4), (C6), (C7)
(I2) init_values Variadische Anzahl von 0-dimensionalen Tensoren oder per Tensor quantisierte Tensoren (C2), (C3)
(I3) dimensions 1-dimensionaler Tensor vom Typ si64 (C4), (C5), (C7)
(I4) body Funktion (C6)

Ausgaben

Name Typ Einschränkungen
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C3), (C7), (C8)

Einschränkungen

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body hat den Typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), wobei is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), mit Ausnahme der Dimensionsgrößen von inputs..., die dimensions entsprechen.
  • (C8) element_type(results[i]) = Ei für alle i in [0,N).

Beispiele

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
  dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]

 Weitere Beispiele

reduce_precision

Semantik

Führt die elementweise Konvertierung von operand in einen anderen Gleitkommatyp mit exponent_bits und mantissa_bits und zurück in den ursprünglichen Gleitkommatyp aus und gibt einen output-Tensor zurück.

Formeller ausgedrückt:

  • Die Mantissenbits des ursprünglichen Werts werden aktualisiert, um den ursprünglichen Wert mit mantissa_bits-Semantik auf den nächsten mit roundToIntegralTiesToEven darstellbaren Wert zu runden.
  • Wenn mantissa_bits kleiner als die Anzahl der Mantissenbits des ursprünglichen Werts ist, werden die Mantissenbits auf mantissa_bits gekürzt.
  • Wenn die Exponentenbits des Zwischenergebnisses nicht in den von exponent_bits angegebenen Bereich passen, wird das Zwischenergebnis mit dem ursprünglichen Vorzeichen auf unendlich oder mit dem ursprünglichen Vorzeichen auf null gesetzt.
  • Führt für quantisierte Typen dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)) aus.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)
(I2) exponent_bits Konstante vom Typ si32 (C2)
(I3) mantissa_bits Konstante vom Typ si32 (C3)

Ausgaben

Name Typ Einschränkungen
output Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Beispiele

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

 Weitere Beispiele

reduce_scatter

Semantik

reduce_scatter

In jeder Prozessgruppe im StableHLO-Prozessraster wird die Reduzierung mit computations für die Werte des operand-Tensors aus jedem Prozess durchgeführt. Das Reduzierungsergebnis wird entlang scatter_dimension in Teile aufgeteilt und die aufgeteilten Teile werden zwischen den Prozessen verteilt, um den result zu erzeugen.

Bei der Operation wird das StableHLO-Prozessraster in process_groups aufgeteilt, das so definiert ist:

  • cross_replica(replica_groups) bei channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) bei channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) bei channel_id > 0 and use_global_device_ids = true.

Gehen Sie anschließend in jedem process_group so vor:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] für alle sender in process_group, wobei receiver_index = process_group.index(receiver).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1), (C2), (C7), (C8)
(I2) scatter_dimension Konstante vom Typ si64 (C1), (C2), (C8)
(I3) replica_groups 2‑dimensionaler Tensor vom Typ si64 (C3–C5)
(I4) channel_id Konstante vom Typ si64 (C6)
(I5) use_global_device_ids Konstante vom Typ i1 (C6)
(I6) computation Funktion (C7)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C8–C9)

Einschränkungen

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) wird so definiert:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_replicas, wenn cross_replica_and_partition verwendet wird.
    • num_processes, wenn flattened_ids verwendet wird.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Wenn use_global_device_ids = true, dann channel_id > 0.
  • (C7) computation hat den Typ (tensor<E>, tensor<E>) -> (tensor<E>), wobei is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand), außer:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Beispiele

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
  %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
  "stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
  scatter_dimension = 1 :< i64,
  >replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
  channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

 Weitere Beispiele

reduce_window

Semantik

Wendet eine Reduktionsfunktion body auf Fenster von inputs und init_values an und gibt results aus.

Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in results... aus inputs... berechnet werden.

reduce_window

Formaler ausgedrückt: results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (siehe reduce), wobei gilt:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Eingaben

Label Name Typ Einschränkungen
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values Variadische Anzahl von 0-dimensionalen Tensoren oder per Tensor quantisierte Tensoren (C1), (C13)
(I3) window_dimensions 1-dimensionaler Tensor vom Typ si64 (C4), (C5), (C15)
(I4) window_strides 1-dimensionaler Tensor vom Typ si64 (C6), (C7), (C15)
(I5) base_dilations 1-dimensionaler Tensor vom Typ si64 (C8), (C9), (C15)
(I6) window_dilations 1-dimensionaler Tensor vom Typ si64 (C10), (C11), (C15)
(I7) padding 2‑dimensionaler Tensor vom Typ si64 (C12), (C15)
(I8) body Funktion (C13)

Ausgaben

Name Typ Einschränkungen
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1), (C14–C16)

Einschränkungen

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body hat den Typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), wobei is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows wobei:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei für alle i in [0,N).

Beispiele

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
  wind>ow_dimensions = arrayi64: <2, 1,
  w>indow_strides = arrayi64: <4, 1,
  b>ase_dilations = arrayi64: 2,< 1,
  win>dow_dilations = arr<ayi64: 3, 1,
  p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]

 Weitere Beispiele

Restwert

Semantik

Berechnet den elementweisen Rest der Tensoren für Dividend lhs und Divisor rhs und gibt einen result-Tensor zurück.

Genauer gesagt, das Vorzeichen des Ergebnisses wird vom Dividenden übernommen und der absolute Wert des Ergebnisses ist immer kleiner als der absolute Wert des Divisors. Der Rest wird als lhs - d * rhs berechnet, wobei d so definiert ist:

  • Für Ganzzahlen: stablehlo.divide(lhs, rhs).
  • Für Gleitkommazahlen: division(lhs, rhs) aus IEEE-754 mit dem Rundungsattribut roundTowardZero.
  • Für komplexe Zahlen: TBD (#997).
  • Für quantisierte Typen:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Bei Gleitkomma-Elementtypen steht diese Operation im Gegensatz zur remainder-Operation aus der IEEE-754-Spezifikation, bei der d ein ganzzahliger Wert ist, der dem genauen Wert von lhs/rhs am nächsten liegt, wobei bei Gleichheit die gerade Zahl gewählt wird.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder Tensor mit quantisierten Werten pro Tensor (C1)
(I2) rhs Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder Tensor mit quantisierten Werten pro Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder Tensor mit quantisierten Werten pro Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]

 Weitere Beispiele

replica_id

Semantik

Erstellt replica_id des aktuellen Prozesses.

Ausgaben

Name Typ
result 0-dimensionaler Tensor vom Typ ui32

Beispiele

%result = "stablehlo.replica_id">;() : (<) - >tensorui32

 Weitere Beispiele

reshape

Semantik

Führt eine Umformung des operand-Tensors in einen result-Tensor durch. Konzeptionell entspricht das dem Beibehalten derselben kanonischen Darstellung, wobei sich die Form ändern kann, z.B. von tensor<2x3xf32> zu tensor<3x2xf32> oder tensor<6xf32>.

Formaler ausgedrückt: result[result_index] = operand[operand_index], wobei result_index und operand_index dieselbe Position in der lexikografischen Reihenfolge von index_space(result) und index_space(operand) haben.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder quantisierter Tensor (C1–C3)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C1–C3)

Einschränkungen

  • (C1) element_type(result) wird so berechnet:
    • element_type(operand), wenn !is_per_axis_quantized(operand).
    • element_type(operand), mit der Ausnahme, dass sich quantization_dimension(operand) und quantization_dimension(result) unterscheiden können.
  • (C2) size(operand) = size(result).
  • (C3) Wenn is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Beispiele

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]

 Weitere Beispiele

umkehren

Semantik

Kehrt die Reihenfolge der Elemente in operand entlang der angegebenen dimensions um und erzeugt einen result-Tensor. Formeller ausgedrückt: result[result_index] = operand[operand_index], wobei gilt:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 wenn d in dimensions.
  • Andernfalls operand_index[d] = result_index[d].

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1), (C3)
(I2) dimensions 1-dimensionaler Tensor vom Typ si64 (C2), (C3)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1), (C3)

Einschränkungen

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Beispiele

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]

 Weitere Beispiele

rng

Semantik

Generiert Zufallszahlen mit dem rng_distribution-Algorithmus und gibt einen result-Tensor mit einer bestimmten Form shape aus.

Falls rng_distribution = UNIFORM, werden die Zufallszahlen gemäß der Gleichverteilung über dem Intervall [a, b) generiert. Wenn a >= b, ist das Verhalten nicht definiert.

Wenn rng_distribution = NORMAL, werden die Zufallszahlen gemäß der Normalverteilung mit Mittelwert = a und Standardabweichung = b generiert. Wenn b < 0, ist das Verhalten nicht definiert.

Die genaue Art und Weise, wie Zufallszahlen generiert werden, ist implementierungsdefiniert. Sie können beispielsweise deterministisch sein oder nicht und einen verborgenen Status verwenden oder nicht.

In Gesprächen mit vielen Stakeholdern wurde diese Operation als effektiv eingestellt bezeichnet. Daher planen wir, sie in Zukunft zu entfernen (#597).

Eingaben

Label Name Typ Einschränkungen
(I1) a 0-dimensionaler Tensor vom Typ „Ganzzahl“, „Boolescher Wert“ oder „Gleitkommazahl“ (C1), (C2)
(I2) b 0-dimensionaler Tensor vom Typ „Ganzzahl“, „Boolescher Wert“ oder „Gleitkommazahl“ (C1), (C2)
(I3) shape 1-dimensionaler Tensor vom Typ si64 (C3)
(I4) rng_distribution Aufzählung von UNIFORM und NORMAL (C2)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl“, „Boolescher Wert“ oder „Gleitkommazahl“ (C1–C3)

Einschränkungen

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Wenn rng_distribution = NORMAL, dann is_float(a).
  • (C3) shape(result) = shape.

Beispiele

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semantik

Gibt einen mit gleichmäßigen Zufallsbits gefüllten output und einen aktualisierten Ausgabezustand output_state zurück, wobei der Pseudozufallszahlengenerator-Algorithmus rng_algorithm mit einem Anfangszustand initial_state verwendet wird. Die Ausgabe ist garantiert eine deterministische Funktion von initial_state, aber nicht garantiert deterministisch zwischen Implementierungen.

rng_algorithm ist einer der folgenden Werte:

  • DEFAULT: Implementierungsdefinierter Algorithmus.
  • THREE_FRY: Implementierungsdefinierte Variante des Threefry-Algorithmus.*
  • PHILOX: Implementierungsdefinierte Variante des Philox-Algorithmus.*

* Siehe Salmon et al. SC 2011. Parallele Zufallszahlen: So einfach wie das Einmaleins.

Eingaben

Label Name Typ Einschränkungen
(I1) rng_algorithm Aufzählung von DEFAULT, THREE_FRY und PHILOX (C2)
(I2) initial_state 1-dimensionaler Tensor vom Typ ui64 (C1), (C2)

Ausgaben

Name Typ Einschränkungen
output_state 1-dimensionaler Tensor vom Typ ui64 (C1)
output Tensor vom Ganzzahl- oder Gleitkommatyp

Einschränkungen

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) wird so definiert:
    • Implementierungsdefiniert, wenn rng_algorithm = DEFAULT.
    • 2, wenn rng_algorithm = THREE_FRY.
    • 2 oder 3, wenn rng_algorithm = PHILOX.

Beispiele

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semantik

Führt eine elementweise Rundung auf die nächste Ganzzahl durch, wobei bei Gleichheit von null weg gerundet wird, und gibt einen result-Tensor zurück.operand Implementiert den roundToIntegralTiesToAway-Vorgang aus der IEEE-754-Spezifikation. Für quantisierte Typen wird dequantize_op_quantize(round_nearest_afz, operand, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

 Weitere Beispiele

round_nearest_even

Semantik

Führt eine elementweise Rundung auf die nächste Ganzzahl durch. Bei Gleichstand wird auf die gerade Ganzzahl gerundet. Das Ergebnis ist ein result-Tensor.operand Implementiert den roundToIntegralTiesToEven-Vorgang aus der IEEE-754-Spezifikation. Für quantisierte Typen wird dequantize_op_quantize(round_nearest_even, operand, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkommatyp oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

 Weitere Beispiele

rsqrt

Semantik

Führt eine elementweise Operation zur Berechnung der reziproken Quadratwurzel für den operand-Tensor aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: rSqrt aus IEEE-754.
  • Für komplexe Zahlen: komplexe reziproke Quadratwurzel.
  • Für quantisierte Typen: dequantize_op_quantize(rsqrt, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

 Weitere Beispiele

aufteilen

Semantik

Erstellt results-Tensoren, die mit inputs-Tensoren identisch sind, mit der Ausnahme, dass mehrere von scatter_indices angegebene Slices mit den Werten updates aktualisiert werden, wobei update_computation verwendet wird.

Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in updates... Elementen in results... zugeordnet werden. Im Diagramm werden einige Beispiel-updates...-Indexe ausgewählt und detailliert erläutert, welchen results...-Indexen sie entsprechen.

aufteilen

Formeller ausgedrückt: Für alle update_index in index_space(updates[0]) gilt:

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index wird so definiert:
    • scatter_indices[si0, ..., :, ..., siN], wobei si einzelne Elemente in update_scatter_index sind und : am Index index_vector_dim eingefügt wird, wenn index_vector_dim < rank(scatter_indices).
    • Andernfalls [scatter_indices[update_scatter_index]].
  • Für d_input in axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start], wenn d_input = scatter_dims_to_operand_dims[d_start].
    • Andernfalls full_start_index[d_input] = 0.
  • Für d_input in axes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] bei d_input = input_batching_dims[i_batching] und d_start = scatter_indices_batching_dims[i_batching].
    • Andernfalls full_batching_index[d_input] = 0.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], wobei wi einzelne Elemente in update_window_index sind und 0 an den Indexen von inserted_window_dims und input_batching_dims eingefügt wird.
  • result_index = full_start_index + full_batching_index + full_window_index.

Daher gilt: results = exec(schedule, inputs), wobei:

  • schedule ist eine implementierungsdefinierte Permutation von index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results), wobei:
    • Wenn result_index innerhalb des Bereichs für shape(results...) liegt
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results ist eine Kopie von results, wobei results...[result_index] auf updated_values... gesetzt ist.
    • Andernfalls:
    • updated_results = results.
  • exec([], results) = results.

Wenn indices_are_sorted gleich true ist, kann die Implementierung davon ausgehen, dass scatter_indices in Bezug auf scatter_dims_to_operand_dims sortiert ist. Andernfalls ist das Verhalten nicht definiert. Formaler ausgedrückt: für alle i1 < i2 aus indices(result), full_start_index(i1) <= full_start_index(i2).

Wenn unique_indices gleich true ist, kann bei der Implementierung davon ausgegangen werden, dass alle result_index-Indizes, an die Daten verteilt werden, eindeutig sind. Wenn unique_indices gleich true ist, die Indexe, in die die Werte verteilt werden, aber nicht eindeutig sind, ist das Verhalten nicht definiert.

Eingaben

Label Name Typ Einschränkungen
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1), (C2), (C4–C6), (C11), (C13), (C18), (C21), (C23–C24)
(I2) scatter_indices Tensor vom Ganzzahltyp (C4), (C15), (C19), (C22)
(I3) updates variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C3–C6), (C8)
(I4) update_window_dims 1-dimensionaler Tensor vom Typ si64 (C2), (C4), (C7–C8)
(I5) inserted_window_dims 1-dimensionaler Tensor vom Typ si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims 1-dimensionaler Tensor vom Typ si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims 1-dimensionaler Tensor vom Typ si64 (C14–C18)
(I8) scatter_dims_to_operand_dims 1-dimensionaler Tensor vom Typ si64 (C19–C21)
(I9) index_vector_dim Konstante vom Typ si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted Konstante vom Typ i1
(I11) unique_indices Konstante vom Typ i1
(I12) update_computation Funktion (C23)

Ausgaben

Name Typ Einschränkungen
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C24-C25)

Einschränkungen

  • (C1) same(shape(inputs...)).
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes) wobei:
    • update_scatter_dim_sizes = shape(scatter_indices), mit Ausnahme der Dimensionsgröße von scatter_indices, die index_vector_dim entspricht.
    • update_window_dim_sizes <= shape(inputs[0]), mit Ausnahme der Dimensionsgrößen in inputs[0], die inserted_window_dims und input_batching_dims entsprechen.
    • combine platziert update_scatter_dim_sizes auf Achsen, die update_scatter_dims entsprechen, und update_window_dim_sizes auf Achsen, die update_window_dims entsprechen.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation hat den Typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), wobei is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei für alle i in [0,N).

Beispiele

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//           ],
//           [
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
  scatter_dimensio<n_numbers = #stablehlo.scatter
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2>, 1],
    index_vector_dim = 3,
  indices_are_sorted = false,
  uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

 Weitere Beispiele

auswählen

Semantik

Erstellt einen result-Tensor, in dem jedes Element basierend auf dem Wert des entsprechenden Elements von pred aus dem on_true- oder on_false-Tensor ausgewählt wird. Formaler ausgedrückt: result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], wobei pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Für quantisierte Typen wird dequantize_select_quantize(pred, on_true, on_false, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) pred Tensor vom Typ i1 (C1)
(I2) on_true Tensor oder Tensor mit Tensor-Quantisierung (C1-C2)
(I3) on_false Tensor oder Tensor mit Tensor-Quantisierung (C2)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C2)

Einschränkungen

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Beispiele

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]

 Weitere Beispiele

select_and_scatter

Semantik

Die Werte aus dem source-Tensor werden mithilfe von scatter basierend auf dem Ergebnis von reduce_window des input-Tensors mit select verteilt und es wird ein result-Tensor erstellt.

Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result aus operand und source berechnet werden.

select_and_scatter

Formeller ausgedrückt:

  • selected_values = reduce_window_without_init(...) mit den folgenden Eingaben:

    • inputs = [operand].
    • window_dimensions, window_strides und padding, die unverändert verwendet werden.
    • base_dilations = windows_dilations = 1.
    • body ist so definiert:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    Dabei funktioniert E = element_type(operand) und reduce_window_without_init genau wie reduce_window, mit dem Unterschied, dass die schedule der zugrunde liegenden reduce (siehe reduce) keine Initialisierungswerte enthält. Es ist derzeit nicht angegeben, was passiert, wenn das entsprechende Fenster keine Werte hat (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter) wobei:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index, wenn selected_values[source_index] das operand-Element aus operand_index enthält.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1–C4), (C6), (C8–C11)
(I2) source Tensor oder Tensor mit Tensor-Quantisierung (C1), (C2)
(I3) init_value 0-dimensionaler Tensor oder Tensor mit Tensor-Quantisierung (C3)
(I4) window_dimensions 1-dimensionaler Tensor vom Typ si64 (C2), (C4), (C5)
(I5) window_strides 1-dimensionaler Tensor vom Typ si64 (C2), (C6), (C7)
(I6) padding 2‑dimensionaler Tensor vom Typ si64 (C2), (C8)
(I7) select Funktion (C9)
(I8) scatter Funktion (C10)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C11-C12)

Einschränkungen

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows wobei:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select hat den Typ (tensor<E>, tensor<E>) -> tensor<i1>, wobei E = element_type(operand).
  • (C10) scatter hat den Typ (tensor<E>, tensor<E>) -> tensor<E>, wobei is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Beispiele

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_di<rection = #stablehlocom>parison_directio<n G>E
    } <: (>ten>sori64,< t>ensori64) - tensori1
    "stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
  ^bb0(%<arg>0: tensori64, %arg1: tensori64):
    %0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
  >  "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
  window_dim<ensions => arrayi64: 3, 1,
  <window_strides => arrayi64<: 2, 1,>
  padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

 Weitere Beispiele

senden

Semantik

Sendet inputs an den Kanal channel_id. Die Eingaben werden dann in der von source_target_pairs angegebenen Reihenfolge an andere Geräte gesendet. Bei dem Vorgang wird ein result-Token generiert.

Wenn is_host_transfer true ist, werden Daten an den Host übertragen. Andernfalls werden Daten basierend auf den Werten von source_target_pairs auf ein anderes Gerät übertragen. Dieses Flag dupliziert die Informationen in channel_type. Daher planen wir, in Zukunft nur eines der beiden Flags beizubehalten (#666). Wenn is_host_transfer = false und source_target_pairs None oder leer ist, gilt dies als undefiniertes Verhalten.

Eingaben

Label Name Typ Einschränkungen
(I1) inputs variadische Anzahl von Tensoren oder quantisierten Tensoren
(I2) token token
(I3) channel_id Konstante vom Typ si64
(I4) channel_type Aufzählung von DEVICE_TO_DEVICE und DEVICE_TO_HOST (C5)
(I5) is_host_transfer Konstante vom Typ i1 (C5–C6)
(I6) source_target_pairs 2‑dimensionaler Tensor vom Typ si64 (C1–C4), (C6)

Ausgaben

Name Typ
result token

Einschränkungen

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, wobei N so definiert ist:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_partitions, wenn cross_partition verwendet wird.
  • (C5) channel_type wird so definiert:
    • DEVICE_TO_HOST wenn is_host_transfer = true,
    • Andernfalls DEVICE_TO_DEVICE.

Beispiele

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
  is_host_transfer = false,
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token

 Weitere Beispiele

shift_left

Semantik

Führt eine elementweise Linksverschiebung des Tensors lhs um rhs Bits aus und gibt einen Tensor result zurück.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Ganzzahltyp (C1)
(I2) rhs Tensor vom Ganzzahltyp (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Ganzzahltyp (C1)

Einschränkungen

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]

 Weitere Beispiele

shift_right_arithmetic

Semantik

Führt eine elementweise arithmetische Rechtsverschiebung des Tensors lhs um rhs Bits aus und gibt einen Tensor result zurück.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Ganzzahltyp (C1)
(I2) rhs Tensor vom Ganzzahltyp (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Ganzzahltyp (C1)

Einschränkungen

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]

 Weitere Beispiele

shift_right_logical

Semantik

Führt eine elementweise logische Rechtsverschiebung des Tensors lhs um rhs Bits aus und gibt einen Tensor result zurück.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Ganzzahltyp (C1)
(I2) rhs Tensor vom Ganzzahltyp (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Ganzzahltyp (C1)

Einschränkungen

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]

 Weitere Beispiele

Signieren

Semantik

Gibt das Vorzeichen des operand-Elements zurück und erzeugt einen result-Tensor. Formaler lässt sich die Semantik für jedes Element x mit der folgenden Python-Syntax ausdrücken:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Für quantisierte Typen wird dequantize_op_quantize(sign, operand, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Typ „Ganzzahl mit Vorzeichen“, „Gleitkomma“, „Komplex“ oder „Tensor mit per-Tensor-Quantisierung“ (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl mit Vorzeichen“, „Gleitkomma“, „Komplex“ oder „Tensor mit per-Tensor-Quantisierung“ (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

 Weitere Beispiele

Sinus

Semantik

Führt eine elementweise Sinusoperation für den Tensor operand aus und gibt einen Tensor result zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: sin aus IEEE-754.
  • Für komplexe Zahlen: komplexer Sinus.
  • Für quantisierte Typen: dequantize_op_quantize(sine, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]

 Weitere Beispiele

Slice

Semantik

Extrahiert einen Ausschnitt aus dem operand-Tensor mithilfe von statisch berechneten Startindizes und gibt einen result-Tensor aus. start_indices enthält die Startindexe des Slices für jede Dimension, limit_indices die Endindexe (ausschließlich) des Slices für jede Dimension und strides die Strides für jede Dimension.

Formaler ausgedrückt: result[result_index] = operand[operand_index], wobei operand_index = start_indices + result_index * strides.

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder Tensor mit Tensor-Quantisierung (C1–C3), (C5)
(I2) start_indices 1-dimensionaler Tensor vom Typ si64 (C2), (C3), (C5)
(I3) limit_indices 1-dimensionaler Tensor vom Typ si64 (C2), (C3), (C5)
(I4) strides 1-dimensionaler Tensor vom Typ si64 (C2), (C4)

Ausgaben

Name Typ Einschränkungen
result Tensor oder Tensor mit Tensor-Quantisierung (C1), (C5)

Einschränkungen

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Beispiele

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indic<es = arra>yi64: 1, 2,
  limit_indic<es = arra>yi64: 3, 4,
  strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

 Weitere Beispiele

Sortieren

Semantik

Sortiert eindimensionale Slices von inputs entlang der Dimension dimension zusammen nach einem comparator und gibt results aus.

Im Gegensatz zu ähnlichen Eingaben bei anderen Vorgängen sind bei dimension negative Werte zulässig, mit der unten beschriebenen Semantik. Aus Konsistenzgründen kann dies in Zukunft nicht mehr zulässig sein (#1377).

Wenn is_stable „true“ ist, ist die Sortierung stabil. Das bedeutet, dass die relative Reihenfolge von Elementen, die vom Comparator als gleich betrachtet werden, beibehalten wird. Bei einer einzelnen Eingabe gelten zwei Elemente e1 und e2 nur dann als gleich, wenn comparator(e1, e2) = comparator(e2, e1) = false. Unten finden Sie die Formalisierung, wie sich dies auf mehrere Eingaben verallgemeinert.

Formeller ausgedrückt: Für alle result_index in index_space(results[0]) gilt:

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1], wobei riN einzelne Elemente in result_index sind und : bei adjusted_dimension eingefügt wird.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • Dabei sortiert sort einen eindimensionalen Ausschnitt in aufsteigender Reihenfolge. Es wird erwartet, dass comparator_together true zurückgibt, wenn das Argument auf der linken Seite kleiner als das zweite Argument auf der rechten Seite ist.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Eingaben

Label Name Typ Einschränkungen
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C5)
(I2) dimension Konstante vom Typ si64 (C4)
(I3) is_stable Konstante vom Typ i1
(I4) comparator Funktion (C5)

Ausgaben

Name Typ Einschränkungen
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C2), (C3)

Einschränkungen

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, wobei R = rank(inputs[0]).
  • (C5) comparator hat den Typ (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, wobei Ei = element_type(inputs[i]).

Beispiele

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_di<rection = #stablehlocom>parison_directio<n G>T
    } <: (>ten>sori64,< t>ensori64) - tensori1
    "stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
  dimension = 0 : i64,
<  is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

 Weitere Beispiele

sqrt

Semantik

Führt eine elementweise Quadratwurzeloperation für den Tensor operand aus und erzeugt einen Tensor result. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: squareRoot aus IEEE-754.
  • Für komplexe Zahlen: komplexe Quadratwurzel.
  • Für quantisierte Typen: dequantize_op_quantize(sqrt, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]

 Weitere Beispiele

subtract

Semantik

Führt eine elementweise Subtraktion von zwei Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Je nach Elementtyp wird Folgendes ausgeführt:

  • Bei Ganzzahlen: Subtraktion von Ganzzahlen.
  • Für Gleitkommazahlen: subtraction aus IEEE-754.
  • Für komplexe Zahlen: komplexe Subtraktion.
  • Für quantisierte Typen:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)
(I2) rhs Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Ganzzahl“, „Gleitkomma“ oder „Komplex“ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]

 Weitere Beispiele

tan

Semantik

Führt eine elementweise Tangensoperation für den Tensor operand aus und erzeugt einen Tensor result. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: tan aus IEEE-754.
  • Für komplexe Zahlen: komplexer Tangens.
  • Für quantisierte Typen: dequantize_op_quantize(tan, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

 Weitere Beispiele

tanh

Semantik

Führt eine elementweise Tangens-Hyperbolicus-Operation für den operand-Tensor aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für Gleitkommazahlen: tanh aus IEEE-754.
  • Für komplexe Zahlen: komplexer hyperbolischer Tangens.
  • Für quantisierte Typen:
    • dequantize_op_quantize(tanh, operand, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]

 Weitere Beispiele

Transponieren

Semantik

Permutiert die Dimensionen des operand-Tensors mit permutation und erzeugt einen result-Tensor. Formaler ausgedrückt: result[result_index] = operand[operand_index], wobei result_index[d] = operand_index[permutation[d]].

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor oder quantisierter Tensor (C1–C4)
(I2) permutation 1-dimensionaler Tensor vom Typ si64 (C2-C4)

Ausgaben

Name Typ Einschränkungen
result Tensor oder quantisierter Tensor (C1), (C3–C4)

Einschränkungen

  • (C1) element_type(result) wird so berechnet:
    • element_type(operand), wenn !is_per_axis_quantized(operand).
    • element_type(operand), mit der Ausnahme, dass sich quantization_dimension(operand) und quantization_dimension(result) unterscheiden können.
  • (C2) permutation ist eine Permutation von range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Wenn is_per_axis_quantized(result), dann quantization_dimension(operand) = permutation(quantization_dimension(result)).

Beispiele

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

 Weitere Beispiele

triangular_solve

Semantik

Löst Batches von linearen Gleichungssystemen mit unteren oder oberen Dreieckskoeffizientenmatrizen.

Formaler ausgedrückt: Bei gegebenen a und b ist result[i0, ..., iR-3, :, :] die Lösung für op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :], wenn left_side gleich true oder x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] ist, wenn left_side gleich false ist. Dabei wird die Variable x berechnet, wobei op(a) durch transpose_a bestimmt wird, was einer der folgenden Werte sein kann:

  • NO_TRANSPOSE: Vorgang mit a unverändert ausführen.
  • TRANSPOSE: Vorgang für die Transponierte von a ausführen.
  • ADJOINT: Vorgang für die konjugierte Transponierte von a ausführen.

Eingabedaten werden nur aus dem unteren Dreieck von a gelesen, wenn lower gleich true ist, andernfalls aus dem oberen Dreieck von a. Ausgabedaten werden im selben Dreieck zurückgegeben. Die Werte im anderen Dreieck sind implementierungsabhängig.

Wenn unit_diagonal „true“ ist, kann bei der Implementierung davon ausgegangen werden, dass die Diagonalelemente von a gleich 1 sind. Andernfalls ist das Verhalten nicht definiert.

Für quantisierte Typen wird dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)) ausgeführt.

Eingaben

Label Name Typ Einschränkungen
(I1) a Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1–C3)
(I2) b Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1–C4)
(I3) left_side Konstante vom Typ i1 (C3)
(I4) lower Konstante vom Typ i1
(I5) unit_diagonal Konstante vom Typ i1
(I6) transpose_a Aufzählung von NO_TRANSPOSE, TRANSPOSE und ADJOINT

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkomma- oder komplexen Typ oder per-Tensor quantisierter Tensor (C1)

Einschränkungen

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) Die Beziehung zwischen shape(a) und shape(b) wird so definiert:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Beispiele

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

Tupel

Semantik

Erstellt ein result-Tupel aus den Werten val.

Eingaben

Label Name Typ Einschränkungen
(I1) val variable Anzahl von Werten (C1)

Ausgaben

Name Typ Einschränkungen
result Tupel (C1)

Einschränkungen

  • (C1) result hat den Typ tuple<E0, ..., EN-1>, wobei Ei = type(val[i]).

Beispiele

// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))

 Weitere Beispiele

uniform_dequantize

Semantik

Führt die elementweise Konvertierung des quantisierten Tensors operand in einen Gleitkommatensor result gemäß den durch den Typ operand definierten Quantisierungsparametern aus.

Formaler ausgedrückt: result = dequantize(operand).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Quantisierter Tensor (C1), (C2)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Gleitkommatyp (C1), (C2)

Einschränkungen

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Beispiele

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]

uniform_quantize

Semantik

Führt die elementweise Konvertierung des Gleitkomma-Tensors oder des quantisierten Tensors operand in einen quantisierten Tensor result gemäß den durch den Typ result definierten Quantisierungsparametern durch.

Formeller ausgedrückt:

  • Wenn is_float(operand):
    • result = quantize(operand, type(result)).
  • Wenn is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Eingaben

Label Name Typ Einschränkungen
(I1) operand Tensor vom Gleitkomma- oder quantisierten Typ (C1), (C2)

Ausgaben

Name Typ Einschränkungen
result Quantisierter Tensor (C1), (C2)

Einschränkungen

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Beispiele

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]

während

Semantik

Erstellt die Ausgabe aus der Ausführung der Funktion body null oder mehrmals, während die Funktion cond true ausgibt. Die Semantik kann formaler mit der folgenden Python-Syntax ausgedrückt werden:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Das Verhalten einer Endlosschleife muss noch festgelegt werden (#383).

Eingaben

Label Name Typ Einschränkungen
(I1) operand variable Anzahl von Werten (C1–C3)
(I2) cond Funktion (C1)
(I3) body Funktion (C2)

Ausgaben

Name Typ Einschränkungen
results variable Anzahl von Werten (C3)

Einschränkungen

  • (C1) cond hat den Typ (T0, ..., TN-1) -> tensor<i1>, wobei Ti = type(operand[i]).
  • (C2) body hat den Typ (T0, ..., TN-1) -> (T0, ..., TN-1), wobei Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Beispiele

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_di<rection = #stablehlocom>parison_directio<n L>T
    } <: (>ten>sori64,< t>ensori64) - tensori1
    stablehlo.r<et>urn %cond : tensori1
  }, {
<  ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
    %new_sum = stablehlo.add <%ar>g1, %one : tensori64
    %new_i = stablehlo.add <%ar>g0, %one : tensori64
    stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10

 Weitere Beispiele

xor

Semantik

Führt eine elementweise XOR-Operation für zwei Tensoren lhs und rhs aus und gibt einen result-Tensor zurück. Je nach Elementtyp wird Folgendes ausgeführt:

  • Für boolesche Werte: logisches XOR.
  • Für Ganzzahlen: bitweises XOR.

Eingaben

Label Name Typ Einschränkungen
(I1) lhs Tensor vom Typ „Boolesch“ oder „Ganzzahl“ (C1)
(I2) rhs Tensor vom Typ „Boolesch“ oder „Ganzzahl“ (C1)

Ausgaben

Name Typ Einschränkungen
result Tensor vom Typ „Boolesch“ oder „Ganzzahl“ (C1)

Einschränkungen

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]

 Weitere Beispiele

Dialekt-Interop

Derzeit enthalten StableHLO-Programme manchmal Vorgänge, die nicht durch StableHLO definiert sind.

Modul, Funktion, Aufruf und Rückgabe

StableHLO verwendet Upstream-MLIR-Vorgänge für ModuleOp, FuncOp, CallOp und ReturnOp. Dies wurde für eine bessere Interoperabilität mit vorhandenen MLIR-Mechanismen durchgeführt, da viele nützliche Durchläufe auf FuncOp und ModuleOp ausgerichtet sind und viele Kompilierungspipelines das Vorhandensein dieser Vorgänge erwarten. Für diese Vorgänge gelten vollständige Kompatibilitätsgarantien. Sollte sich an diesen Vorgängen jemals etwas in inkompatibler Weise ändern (z.B. durch Entfernen), werden StableHLO-Entsprechungen hinzugefügt, um die Kompatibilität zu wahren.

CHLO

Das CHLO-Opset enthält Vorgänge auf höherer Ebene, die in StableHLO zerlegt werden. Derzeit gibt es keine Kompatibilitätsgarantien für CHLO. Um die Kompatibilität zu gewährleisten, muss der chlo-legalize-to-stablehlo-Pass vor der Serialisierung verwendet werden.

Formvorgänge

Es ist ein häufiger Anwendungsfall in der Community, bestimmte Operationen aus den MLIR-Kerndialekten in dynamischen StableHLO-Programmen für Formberechnungen zu verwenden. Am häufigsten sind das shape-Dialekt-Operationen wie shape_of oder num_elements, tensor-Dialekt-Operationen wie dim oder from_elements und der integrierte Typ index.

Im Dynamism RFC > O2 werden diese als nicht relevant bezeichnet. Aus Interoperabilitätsgründen ist jedoch eine gewisse Unterstützung für index-Typen enthalten. Für diese Vorgänge oder Typen gibt es keine Kompatibilitätsgarantien. Mit dem Pass shape-legalize-to-stablehlo können diese Vorgänge in vollständig unterstützte StableHLO-Vorgänge konvertiert werden.

Veraltete Vorgänge

Es gibt mehrere StableHLO-Vorgänge, die von MHLO übernommen wurden und die eingestellt werden. Vollständige Informationen zu diesen Entfernungen finden Sie hier. Das Tracker-Problem für diese Einstellungen ist #2340.

Diese Vorgänge lassen sich in einige Kategorien einteilen:

  • Kategorie „Nicht in HLO“ von StableHLO-Vorgängen – sie waren ursprünglich Teil des StableHLO-Opsets, wurden aber später als nicht passend eingestuft: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (#3).
  • Nicht verwendete Vorgänge: Diese Vorgänge waren möglicherweise irgendwann nützlich, aber sie waren entweder unterentwickelt oder die Pipelines, in denen sie verwendet wurden, wurden so umgestaltet, dass sie nicht mehr erforderlich sind. Dazu gehören map, tuple (#598), get_tuple_element, rng, complex-Vergleiche #560 und die Faltung window_reversal (#1181).

Einige dieser Vorgänge können problemlos entfernt werden, da sie mit vorhandenen Vorgängen (broadcast, create_token, cross-replica-sum, dot, unary_einsum) ausgedrückt werden können. Sie werden nach Ablauf des vorhandenen Kompatibilitätszeitraums (6 Monate) entfernt. Andere werden noch auf die Möglichkeit der Entfernung hin untersucht (einsum, get_tuple_element, map, rng torch_index_select, tuple, complex-Vergleiche, window_reversal). Je nach Community-Feedback werden diese Operationen entweder entfernt oder mit vollständiger Unterstützung in die Spezifikation aufgenommen. Bis diese Vorgänge bekannt sind, wird nur eine Kompatibilität von 6 Monaten garantiert.

Ausführung

Sequenzielle Ausführung

Ein StableHLO-Programm wird ausgeführt, indem der main-Funktion Eingabewerte bereitgestellt und Ausgabewerte berechnet werden. Ausgabewerte einer Funktion werden berechnet, indem der Graph der Operationen ausgeführt wird, der in der entsprechenden return-Operation verwurzelt ist.

Die Ausführungsreihenfolge ist implementierungsdefiniert, solange sie mit dem Datenfluss übereinstimmt, d.h. wenn Vorgänge vor ihrer Verwendung ausgeführt werden. In StableHLO verarbeiten alle Operationen mit Nebeneffekten ein Token und geben ein Token aus. Mehrere Tokens können über after_all in ein Token gemultiplext werden. Die Ausführungsreihenfolge von Nebeneffekten wird also auch am Datenfluss ausgerichtet. Im folgenden Programm gibt es beispielsweise zwei mögliche Ausführungsreihenfolgen: %0 → %1 → %2 → return und %1 → %0 → %2 → return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Formaler ausgedrückt ist ein StableHLO-Prozess eine Kombination aus: 1) einem StableHLO-Programm, 2) Vorgangsstatus (noch nicht ausgeführt, bereits ausgeführt) und 3) Zwischenwerten, die im Prozess verwendet werden. Der Prozess beginnt mit Eingabewerten für die main-Funktion, durchläuft den Diagramm der Vorgänge, in dem Vorgangsstatus und Zwischenwerte aktualisiert werden, und endet mit Ausgabewerten. Weitere Formalisierung ist noch nicht festgelegt (#484).

Parallele Ausführung

StableHLO-Programme können parallel ausgeführt werden. Sie sind in einem 2D-Prozessraster von num_replicas × num_partitions organisiert, wobei beide den Typ ui32 haben.

Im StableHLO-Prozessraster werden num_replicas * num_partitions StableHLO-Prozesse gleichzeitig ausgeführt. Jeder Prozess hat eine eindeutige process_id = (replica_id, partition_id), wobei replica_id in replica_ids = range(num_replicas) und partition_id in partition_ids = range(num_partitions) beide vom Typ ui32 sind.

Die Größe des Prozessrasters ist für jedes Programm statisch bekannt (in Zukunft soll sie expliziter Bestandteil von StableHLO-Programmen sein #650). Die Position im Prozessraster ist für jeden Prozess statisch bekannt. Jeder Prozess hat über die Vorgänge replica_id und partition_id Zugriff auf seine Position im Prozessraster.

Im Prozessraster können die Programme alle gleich sein („Einzelnes Programm, mehrere Daten“), alle unterschiedlich sein („Mehrere Programme, mehrere Daten“) oder etwas dazwischen. In Zukunft planen wir, Unterstützung für andere Idiome zur Definition paralleler StableHLO-Programme einzuführen, einschließlich GSPMD (#619).

Im Prozessraster sind die Prozesse größtenteils unabhängig voneinander. Sie haben separate Vorgangsstatus, separate Ein-/Zwischen-/Ausgabewerte und die meisten Vorgänge werden separat zwischen den Prozessen ausgeführt, mit Ausnahme einer kleinen Anzahl von kollektiven Vorgängen, die unten beschrieben werden.

Da bei der Ausführung der meisten Vorgänge nur Werte aus demselben Prozess verwendet werden, ist es in der Regel eindeutig, auf diese Werte anhand ihrer Namen zu verweisen. Bei der Beschreibung der Semantik kollektiver Operationen ist das jedoch nicht ausreichend. Daher wird die Notation name@process_id verwendet, um sich auf den Wert name innerhalb eines bestimmten Prozesses zu beziehen. Aus dieser Perspektive kann ein nicht qualifiziertes name als Abkürzung für name@(replica_id(), partition_id()) angesehen werden.

Die Ausführungsreihenfolge für Prozesse ist implementierungsdefiniert, mit Ausnahme der Synchronisierung, die durch die Punkt-zu-Punkt-Kommunikation und kollektive Operationen eingeführt wird, wie unten beschrieben.

Punkt-zu-Punkt-Kommunikation

StableHLO-Prozesse können über StableHLO-Channels miteinander kommunizieren. Ein Channel wird durch eine positive ID vom Typ si64 dargestellt. Über verschiedene Vorgänge ist es möglich, Werte an Channels zu senden und von Channels zu empfangen.

Weitere Formalisierungen, z.B. woher diese Kanal-IDs stammen, wie Prozesse auf sie aufmerksam werden und welche Art der Synchronisierung durch sie eingeführt wird, sind noch nicht festgelegt (#484).

Streaming-Kommunikation

Jeder StableHLO-Prozess hat Zugriff auf zwei Streaming-Schnittstellen:

  • Infeed, aus dem gelesen werden kann.
  • Outfeed, in den geschrieben werden kann.

Im Gegensatz zu Channels, die für die Kommunikation zwischen Prozessen verwendet werden und daher Prozesse an beiden Enden haben, ist das andere Ende von Infeeds und Outfeeds implementierungsdefiniert.

Weitere Formalisierungen, z.B. wie sich die Streaming-Kommunikation auf die Ausführungsreihenfolge auswirkt und welche Art von Synchronisierung dadurch eingeführt wird, sind noch nicht festgelegt (#484).

Kollektive Vorgänge

Es gibt sechs kollektive Operationen in StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute und reduce_scatter. Bei all diesen Vorgängen werden die Prozesse im StableHLO-Prozessraster in StableHLO-Prozessgruppen aufgeteilt und eine gemeinsame Berechnung innerhalb jeder Prozessgruppe unabhängig von anderen Prozessgruppen ausgeführt.

Innerhalb jeder Prozessgruppe können kollektive Operationen eine Synchronisationsbarriere einführen. Eine weitere Formalisierung, z.B. die genaue Beschreibung, wann diese Synchronisierung erfolgt, wie die Prozesse genau an diese Barriere gelangen und was passiert, wenn sie es nicht tun, steht noch aus (#484).

Wenn die Prozessgruppe die partitionierte Kommunikation umfasst, d.h. es Prozesse in der Prozessgruppe gibt, deren Partitions-IDs unterschiedlich sind, ist für die Ausführung des kollektiven Vorgangs ein Channel erforderlich und der kollektive Vorgang muss einen positiven channel_id vom Typ si64 bereitstellen. Für die replikaübergreifende Kommunikation sind keine Channels erforderlich.

Die Berechnungen, die von den kollektiven Operationen ausgeführt werden, sind spezifisch für einzelne Operationen und werden oben in den Abschnitten zu den einzelnen Operationen beschrieben. Die Strategien, nach denen das Prozessraster in Prozessgruppen aufgeteilt wird, sind jedoch für diese Vorgänge gleich und werden in diesem Abschnitt beschrieben. Genauer gesagt unterstützt StableHLO die folgenden vier Strategien.

cross_replica

Die Kommunikation zwischen Replikaten erfolgt nur innerhalb der einzelnen Prozessgruppen. Bei dieser Strategie wird replica_groups – eine Liste mit Listen von Replikat-IDs – verwendet und ein kartesisches Produkt von replica_groups und partition_ids berechnet. replica_groups muss eindeutige Elemente haben und alle replica_ids abdecken. Formaler ausgedrückt, in Python-Syntax:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Für replica_groups = [[0, 1], [2, 3]] und num_partitions = 2 wird mit cross_replica beispielsweise [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]] ausgegeben.

cross_partition

Die Kommunikation zwischen Partitionen erfolgt nur innerhalb der einzelnen Prozessgruppen. Bei dieser Strategie wird partition_groups verwendet, eine Liste von Listen mit Partitions-IDs, und es wird ein kartesisches Produkt von partition_groups und replica_ids berechnet. partition_groups muss eindeutige Elemente haben und alle partition_ids abdecken. Formaler ausgedrückt, in Python-Syntax:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Für partition_groups = [[0, 1]] und num_replicas = 4 wird mit cross_partition beispielsweise [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]] ausgegeben.

cross_replica_and_partition

Sowohl die replikats- als auch die partitionsübergreifende Kommunikation kann innerhalb jeder Prozessgruppe erfolgen. Bei dieser Strategie wird replica_groups verwendet, eine Liste mit Listen von Replikat-IDs, und es werden kartesische Produkte der einzelnen replica_group mit partition_ids berechnet. replica_groups muss eindeutige Elemente haben und alle replica_ids abdecken. Formaler ausgedrückt, in Python-Syntax:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Für replica_groups = [[0, 1], [2, 3]] und num_partitions = 2 wird mit cross_replica_and_partition beispielsweise [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]] ausgegeben.

flattened_ids

Diese Strategie verwendet flattened_id_groups – eine Liste von Listen mit „vereinfachten“ Prozess-IDs im Format replica_id * num_partitions + partition_id – und wandelt sie in Prozess-IDs um. flattened_id_groups muss eindeutige Elemente haben und alle process_ids abdecken. Formaler ausgedrückt, in Python-Syntax:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Für flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 und num_partitions = 2 wird mit flattened_ids beispielsweise [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]] generiert.

Genauigkeit

Derzeit gibt es in StableHLO keine Garantien für die numerische Genauigkeit. Das kann sich jedoch in Zukunft ändern (#1156).

Ausführungssemantik von quantisierten Vorgängen

Die Interpretation von quantisierten StableHLO-Vorgängen kann je nach Hardwareanforderungen und ‑funktionen variieren. Einige Hardware kann beispielsweise quantisierte Vorgänge mit der Strategie „Dequantisieren, Gleitkommaoperation ausführen und schließlich quantisieren“ interpretieren. Andere führen die gesamte Berechnung mit Ganzzahlarithmetik durch. Die Interpretation von quantisierten StableHLO-Operationen wird daher ausschließlich durch die jeweilige Implementierung bestimmt. Die Interpretation der hybriden Quantisierung (#1575) sollte auf der Semantik basieren, die in der Spezifikation (über 1792) vorgeschrieben ist.

Fehler

StableHLO-Programme werden anhand einer umfangreichen Reihe von Einschränkungen für einzelne Vorgänge validiert, wodurch viele Fehlerklassen vor der Laufzeit ausgeschlossen werden. Fehlerbedingungen sind jedoch weiterhin möglich, z.B. durch Ganzzahlüberläufe, Zugriffe außerhalb des zulässigen Bereichs usw. Sofern nicht explizit angegeben, führen alle diese Fehler zu einem implementierungsdefinierten Verhalten, dies kann sich jedoch in Zukunft ändern (#1157).

Gleitkommaausnahmen

Als Ausnahme von dieser Regel haben Gleitkomma-Ausnahmen in StableHLO-Programmen ein genau definiertes Verhalten. Operationen, die zu Ausnahmen führen, die durch den IEEE-754-Standard definiert sind (ungültige Operation, Division durch null, Überlauf, Unterlauf oder ungenaue Ausnahmen), erzeugen Standardergebnisse (wie im Standard definiert) und die Ausführung wird fortgesetzt, ohne das entsprechende Statusflag zu setzen. Dies ähnelt der raiseNoFlag-Ausnahmebehandlung aus dem Standard. Ausnahmen für nicht standardmäßige Operationen (z.B. komplexe Arithmetik und bestimmte transzendentale Funktionen) sind implementierungsdefiniert.

Formularfehler

StableHLO unterstützt dynamisch geformte Tensoren. Die Formen müssen jedoch zur Laufzeit übereinstimmen, da das Verhalten sonst nicht definiert ist. StableHLO bietet keinen expliziten Vorgang, mit dem zur Laufzeit geprüft werden kann, ob ein Tensor eine bestimmte Form hat. Die Verantwortung für die Generierung von korrektem Code liegt beim Produzenten.

Das folgende Programm ist ein konkretes Beispiel für ein gültiges Programm. Zur Laufzeit müssen die genauen Formen von %arg0 und %arg1 jedoch identisch sein, da das Verhalten des Programms ansonsten nicht definiert ist:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notation

Zur Beschreibung der Syntax wird in diesem Dokument die modifizierte ISO-Version der EBNF-Syntax (ISO/IEC 14977:1996, Wikipedia) verwendet, mit zwei Änderungen: 1) Regeln werden mit ::= anstelle von = definiert.

2) Die Verkettung wird durch Nebeneinanderstellung anstelle von , ausgedrückt.

Zur Beschreibung der Semantik (d.h. in den Abschnitten „Types“, „Constants“ und „Ops“) verwenden wir Formeln, die auf der Python-Syntax basieren und um Unterstützung für die präzise Darstellung von Array-Operationen erweitert wurden, wie unten beschrieben. Das funktioniert gut für kleine Code-Snippets. In seltenen Fällen, in denen größere Code-Snippets erforderlich sind, verwenden wir die reine Python-Syntax, die immer explizit eingeführt wird.

Formeln

Sehen wir uns an einem Beispiel aus der dot_general-Spezifikation an, wie Formeln funktionieren. Eine der Einschränkungen für diesen Vorgang sieht so aus: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Die in dieser Formel verwendeten Namen stammen aus zwei Quellen: 1) globale Funktionen, z.B. dim, 2) Mitgliedsdefinitionen des entsprechenden Programmelements, z.B. die Eingaben lhs, lhs_batching_dimensions, rhs und rhs_batching_dimensions, die im Abschnitt „Eingaben“ von dot_general definiert sind.

Wie oben erwähnt, basiert die Syntax dieser Formel auf Python und enthält einige auf Kürze ausgerichtete Erweiterungen. Um die Formel zu verstehen, wandeln wir sie in die normale Python-Syntax um.

A) In diesen Formeln verwenden wir = für Gleichheit. Der erste Schritt zur Erstellung der Python-Syntax besteht also darin, = durch == zu ersetzen:dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Außerdem unterstützen diese Formeln Auslassungspunkte (...), mit denen Skalarausdrücke in Tensorausdrücke umgewandelt werden. Kurz gesagt: f(xs...) bedeutet in etwa „Berechne für jeden Skalar x im Tensor xs einen Skalar f(x) und gib dann alle diese skalaren Ergebnisse zusammen als Tensorergebnis zurück“. In der reinen Python-Syntax wird unsere Beispielformel zu: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Dank Ellipsen ist es oft möglich, nicht auf der Ebene einzelner Skalare zu arbeiten. In einigen schwierigen Fällen kann jedoch eine Syntax auf niedrigerer Ebene verwendet werden, wie in der start_indices[bi0, ..., :, ..., biN]-Formel aus der gather-Spezifikation. Aus Gründen der Übersichtlichkeit stellen wir keinen genauen Formalismus für die Übersetzung einer solchen Syntax in reines Python bereit. Wir hoffen, dass sie im Einzelfall dennoch intuitiv verständlich ist. Wenn Sie bestimmte Formeln nicht nachvollziehen können, teilen Sie uns das bitte mit. Wir werden dann versuchen, sie zu verbessern.

Außerdem werden in Formeln Auslassungspunkte verwendet, um alle Arten von Listen zu erweitern, einschließlich Tensoren, Listen von Tensoren (die z.B. aus einer variablen Anzahl von Tensoren entstehen können) usw. Auch hier wird kein genauer Formalismus verwendet (z.B. sind Listen nicht einmal Teil des StableHLO-Typsystems), sondern auf intuitive Verständlichkeit gesetzt.

C) Das letzte bemerkenswerte Notationsmittel, das wir verwenden, ist das implizite Broadcasting. Das StableHLO-Opset unterstützt zwar kein implizites Broadcasting, die Formeln aber schon, was ebenfalls der Übersichtlichkeit dient. Kurz gesagt: Wenn ein Skalar in einem Kontext verwendet wird, in dem ein Tensor erwartet wird, wird der Skalar auf die erwartete Form übertragen.

Hier ist eine weitere Einschränkung für das dot_general-Beispiel: 0 <= lhs_batching_dimensions < rank(lhs). Wie in der dot_general-Spezifikation definiert, ist lhs_batching_dimensions ein Tensor, während 0 und rank(lhs) Skalare sind. Nachdem wir das implizite Broadcasting angewendet haben, lautet die Formel [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Wenn diese Formel auf einen bestimmten dot_general-Vorgang angewendet wird, ergibt sich ein Tensor mit booleschen Werten. Wenn Formeln als Einschränkungen verwendet werden, gilt die Einschränkung, wenn die Formel entweder true oder einen Tensor mit nur true-Elementen ergibt.

Namen

In Formeln umfasst der lexikalische Bereich: 1) globale Funktionen, 2) Mitgliederdefinitionen,

3. lokale Definitionen. Die Liste der globalen Funktionen finden Sie unten. Die Liste der Elementdefinitionen hängt vom Programmelement ab, auf das die Notation angewendet wird:

  • Für Vorgänge umfassen die Memberdefinitionen Namen, die in den Abschnitten „Eingaben“ und „Ausgaben“ eingeführt werden.
  • Für alles andere enthalten die Mitgliedsdefinitionen strukturelle Teile des Programmelements, die nach den entsprechenden EBNF-Nichtterminalen benannt sind. In den meisten Fällen werden die Namen dieser strukturellen Teile durch Konvertieren der Namen der nicht terminalen Symbole in Snake Case abgerufen (z. B. IntegerLiteral => integer_literal). Manchmal werden Namen dabei jedoch abgekürzt (z. B. QuantizationStorageType => storage_type). In diesem Fall werden die Namen explizit eingeführt, ähnlich wie in den Abschnitten „Eingaben“ / „Ausgaben“ in den Betriebsspezifikationen.
  • Außerdem enthalten Mitgliedsdefinitionen immer self, um auf das entsprechende Programmelement zu verweisen.

Werte

Bei der Auswertung von Formeln werden die folgenden Arten von Werten verwendet: 1) Value (tatsächliche Werte, z.B. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; ihre Typen sind immer bekannt), 2) Placeholder (zukünftige Werte, z.B. lhs, rhs oder result; ihre tatsächlichen Werte sind noch nicht bekannt, nur ihre Typen), 3) Type (Typen, wie im Abschnitt „Typen“ definiert), 4) Function (globale Funktionen, wie im Abschnitt „Funktionen“ definiert).

Je nach Kontext können sich Namen auf unterschiedliche Werte beziehen. Im Abschnitt „Semantik“ für Vorgänge (und Entsprechungen für andere Programmelemente) wird die Laufzeitlogik definiert. Daher sind alle Eingaben als Value verfügbar. Im Gegensatz dazu wird im Abschnitt „Constraints“ für Vorgänge (und Entsprechungen) die Logik für die Kompilierungszeit definiert, d. h. etwas, das normalerweise vor der Laufzeit ausgeführt wird. Daher sind nur konstante Eingaben als Value verfügbar und andere Eingaben nur als Placeholder.

Namen Unter „Semantik“ Unter „Einschränkungen“
Globale Funktionen Function Function
Konstante Eingaben Value Value
Nicht konstante Eingaben Value Placeholder
Ausgaben Value Placeholder
Lokale Definitionen Abhängig von der Definition Abhängig von der Definition

Sehen wir uns ein Beispiel für einen transpose-Vorgang an:

%result = "stablehlo.transpose"(%operand) {
  permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32

Bei diesem Vorgang ist permutation eine Konstante. Sie ist also sowohl in der Semantik als auch in den Einschränkungen als Value verfügbar. operand und result sind dagegen in der Semantik als Value, in den Einschränkungen aber nur als Placeholder verfügbar.

Funktionen

Konstruktion von Typen

Es gibt keine Funktionen, mit denen Typen erstellt werden können. Stattdessen verwenden wir direkt die Typsyntax, da sie in der Regel prägnanter ist. Beispiel: (tensor<E>, tensor<E>) -> (tensor<E>) statt function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Funktionen für Typen

  • element_type ist für Tensortypen und quantisierte Tensortypen definiert und gibt jeweils den TensorElementType- oder QuantizedTensorElementType-Teil des entsprechenden TensorType oder QuantizedTensorType zurück.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value ist eine Tastenkombination für is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value ist eine Abkürzung für is_quantized(x) and quantization_dimension(x) is None.

  • Mit is_promotable(x: Type, y: Type) -> bool wird geprüft, ob der Typ x auf den Typ y hochgestuft werden kann. Wenn x und y QuantizedTensorElementTypes sind, wird das Angebot nur auf die storage_type angewendet. Diese spezielle Version des Angebots wird derzeit im Rahmen der Berechnung der Ermäßigung verwendet. Weitere Informationen finden Sie im RFC.

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value ist eine Tastenkombination für is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Für alle Typen verfügbar. Beispiel: is_float(x) gibt true zurück, wenn x ein FloatType ist. Wenn x ein Wert oder Platzhalter ist, ist diese Funktion eine Abkürzung für is_type_name(type(x)).

  • max_value(x: Type) -> Value gibt den Maximalwert eines TensorElementType zurück. Wenn x kein TensorElementType ist, wird None zurückgegeben.

  • min_value(x: Type) -> Value gibt den kleinstmöglichen Wert eines TensorElementType zurück. Wenn x kein TensorElementType ist, wird None zurückgegeben.

  • member_name(x: Value | Placeholder | Type) -> Any. Für alle Memberdefinitionen member_name aller Typen verfügbar. tensor_element_type(x) gibt beispielsweise den TensorElementType-Teil eines entsprechenden TensorType zurück. Wenn x ein Wert oder Platzhalter ist, ist diese Funktion eine Abkürzung für member_name(type(x)). Wenn x kein Typ mit einem entsprechenden Element oder ein Wert oder Platzhalter eines solchen Typs ist, wird None zurückgegeben.

  • is_empty_algorithm(*args: Type) prüft, ob alle Felder des Punktalgorithmus auf None gesetzt sind. Das ist erforderlich, da für Punktalgorithmen implementierungsdefinierte Standardverhalten gelten. Die Angabe eines Standardwerts wäre also falsch.

Konstruktion von Werten

  • operation_name(*xs: Value | Type) -> Value. Für alle Vorgänge verfügbar. add(lhs, rhs) nimmt beispielsweise zwei Tensorwerte lhs und rhs entgegen und gibt die Ausgabe der Auswertung des Vorgangs add mit diesen Eingaben zurück. Bei einigen Vorgängen, z.B. broadcast_in_dim, sind die Typen ihrer Ausgaben „tragend“, d.h. sie werden zur Auswertung eines Vorgangs benötigt. In diesem Fall akzeptiert die Funktion diese Typen als Argumente.

Funktionen für Werte

  • Alle Operatoren und Funktionen von Python sind verfügbar. Sowohl die Abonnement- als auch die Slicing-Notation aus Python sind verfügbar, um Tensoren, quantisierte Tensoren und Tupel zu indexieren.

  • to_destination_type(x: Value, destination_type: Type) -> Value wird für Tensoren definiert und gibt den konvertierten Wert von x basierend auf type(x) und destination_type zurück:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Es gibt erste Diskussionen über das Zusammenführen der Vorgänge convert, uniform_quantize und uniform_dequantize (#1576). Nach dem Zusammenführen benötigen wir die obige Funktion nicht mehr und können stattdessen den Vorgangsnamen für convert verwenden.

  • is_nan(x: Value) -> Value wird für Tensoren definiert und gibt true zurück, wenn alle Elemente von x NaN sind, andernfalls false. Wenn x kein Tensor ist, wird None zurückgegeben.

  • is_sorted(x: Value) -> Value wird für Tensoren definiert und gibt true zurück, wenn die Elemente von x in aufsteigender Reihenfolge in Bezug auf die aufsteigende lexikografische Reihenfolge ihrer Indexe sortiert sind, andernfalls false. Wenn x kein Tensor ist, wird None zurückgegeben.

  • is_unique(x: Value) -> Value ist für Tensoren definiert und gibt true zurück, wenn x keine doppelten Elemente enthält, andernfalls false. Wenn x kein Tensor ist, wird None zurückgegeben.

  • member_name(x: Value) -> Any ist für alle Memberdefinitionen member_name aller Werte definiert. real_part(x) gibt beispielsweise den RealPart-Teil eines entsprechenden ComplexConstant zurück. Wenn x kein Wert mit einem entsprechenden Mitglied ist, wird None zurückgegeben.

  • same(x: Value) -> Value wird für Tensoren definiert und gibt true zurück, wenn alle Elemente von x gleich sind, andernfalls false. Wenn der Tensor keine Elemente hat, gilt das als „alle gleich“, d.h., die Funktion gibt true zurück. Wenn x kein Tensor ist, wird None zurückgegeben.

  • split(x: Value, num_results: Value, axis: Value) -> Value wird für Tensoren definiert und gibt num_results-Slices von x entlang der Achse axis zurück. Wenn x kein Tensor oder dim(x, axis) % num_results != 0 ist, wird None zurückgegeben.

  • is_defined_in_parent_scope(x: Value) -> Value ist für Strings definiert und gibt true zurück, wenn x der Name einer Funktion ist, die im selben Bereich wie die übergeordnete Funktion des relevanten Vorgangs definiert ist.

  • is_namespaced_op_name(x: Value) -> Value ist für Strings definiert und gibt true zurück, wenn x ein gültiger Operationsname ist, d. h., wenn er dem folgenden regulären Ausdruck entspricht: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

Formberechnungen

  • axes(x: Value | Placeholder | Type) -> Value ist eine Tastenkombination für range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value ist eine Tastenkombination für shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List ist eine Tastenkombination für list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value wird für Tensoren definiert und gibt size(x)-Indizes für die entsprechenden TensorType zurück, die in aufsteigender lexikografischer Reihenfolge sortiert sind, d.h. [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Wenn x kein Tensortyp, quantisierter Tensortyp oder Wert oder Platzhalter eines dieser Typen ist, wird None zurückgegeben.

  • rank(x: Value | Placeholder | Type) -> Value ist eine Tastenkombination für size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value wird im Abschnitt „Funktionen für Typen“ über member_name definiert.

  • size(x: Value | Placeholder | Type) -> Value ist eine Tastenkombination für reduce(lambda x, y: x * y, shape(x)).

Berechnungen für die Quantisierung

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type ist eine Abkürzung für element_type(baseline_type(x)).

  • baseline_type wird für Tensortypen und quantisierte Tensortypen definiert und wandelt sie in einen „Baseline“-Typ um, d.h. einen Typ mit derselben Form, bei dem die Quantisierungsparameter des Elementtyps auf Standardwerte zurückgesetzt werden. Dies ist ein praktischer Trick, um sowohl Tensor- als auch quantisierte Tensortypen einheitlich zu vergleichen, was häufig erforderlich ist. Bei quantisierten Typen können so Typen verglichen werden, ohne die Quantisierungsparameter zu berücksichtigen. Das bedeutet, dass shape, storage_type, expressed_type, storage_min, storage_max und quantization_dimension (für den achsenbezogenen quantisierten Typ) alle übereinstimmen müssen, scales und zero points jedoch abweichen dürfen.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize wird für quantisierte Tensortypen definiert und wandelt sie in Gleitkomma-Tensortypen um. Dazu werden quantisierte Elemente, die Ganzzahlwerte des Speichertyps darstellen, mithilfe des Nullpunkts und des Skalierungsfaktors, die dem quantisierten Elementtyp zugeordnet sind, in entsprechende Gleitkommawerte des ausgedrückten Typs konvertiert.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize wird für Gleitkomma-Tensortypen definiert und wandelt sie in quantisierte Tensortypen um. Dazu werden Gleitkommawerte des angegebenen Typs mithilfe des Nullpunkts und des Skalierungsfaktors, die dem quantisierten Elementtyp zugeordnet sind, in entsprechende Ganzzahlwerte des Speichertyps umgewandelt.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize wird verwendet, um elementweise Berechnungen für quantisierte Tensoren anzugeben. Er dequantisiert, d.h. er wandelt quantisierte Elemente in ihre ausgedrückten Typen um, führt dann einen Vorgang aus und quantisiert dann, d.h. er wandelt die Ergebnisse wieder in ihre Speichertypen um. Derzeit funktioniert diese Funktion nur für die Quantisierung pro Tensor. Die achsenweise Quantisierung ist in Arbeit (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op wird verwendet, um die reine Gewichtsquantisierung für einen hybriden Vorgang anzugeben, bei dem der linke Operand vom Typ „Gleitkomma“ und der rechte Operand vom Typ „quantisiert“ ist. Die quantisierten Eingaben werden in ihre angegebenen Typen dequantisiert und die Berechnung wird in Gleitkommazahlen durchgeführt. Der Elementtyp des Gleitkomma-LHS-Tensors und der ausgedrückte Typ des quantisierten RHS-Tensors sollten identisch sein.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Rasterberechnungen

  • cross_partition(replica_groups: Value) -> Value. Weitere Informationen finden Sie oben im Abschnitt „cross_replica“.

  • cross_replica(replica_groups: Value) -> Value. Weitere Informationen finden Sie oben im Abschnitt „cross_replica“.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Weitere Informationen finden Sie oben im Abschnitt „cross_replica_and_partition“.

  • flattened_ids(replica_groups: Value) -> Value. Weitere Informationen finden Sie oben im Abschnitt „flattened_ids“.

Dynamik

StableHLO-Werte können dynamische Dimensionsgrößen haben, z.B. tensor<?xi64>. StableHLO-Werte können jedoch keine dynamische Anzahl von Dimensionen haben (unranked Dynamism, z.B. tensor<*xi64>). Operanden und Ergebnisse dürfen dynamische Dimensionsgrößen verwenden, auch wenn es Einschränkungen für die Größen gibt. Einschränkungen werden, sofern möglich, statisch überprüft. Andernfalls werden sie auf die Laufzeit verschoben und Abweichungen führen zu undefiniertem Verhalten. Siehe Beispiele unten.

Formabweichungen bei unären elementweisen Operationen

Sehen Sie sich das folgende Beispielprogramm an:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Ein solches Programm ist ungewöhnlich, da es nicht üblich ist, die Form des Ergebnisses, aber nicht die Form der Eingabe zu kennen. Trotzdem ist dies ein gültiges StableHLO-Programm. Der Vorgang abs kann in diesem Programm nicht statisch validiert werden, da die genaue Form des Operanden unbekannt ist. Die Formen sind jedoch kompatibel und das kann statisch geprüft werden: ? kann zur Laufzeit 2 sein und es gäbe kein Problem. ? kann aber auch eine andere Ganzzahl sein. In diesem Fall ist das Verhalten nicht definiert.

Wenn die Größe einer Dimension im Ergebnis dynamisch ist, kann es kein undefiniertes Verhalten geben. Es gibt keine „erwartete“ Größe, daher kann es auch keine Abweichung geben.

Formabweichungen bei binären elementweisen Operationen

Sehen Sie sich das folgende Beispielprogramm an:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

Bei binären elementweisen Operationen müssen die Formen der Eingaben und des Ergebnisses zur Laufzeit übereinstimmen. Zur Kompilierzeit müssen statische Dimensionen gleich sein. Andernfalls müssen sie lediglich kompatibel sein. Wenn eine Dimension in den Eingaben dynamisch ist, kann es zur Laufzeit zu undefiniertem Verhalten kommen, da die dynamische Größe möglicherweise nicht mit der entsprechenden Größe im anderen Operanden (statisch oder dynamisch) übereinstimmt. Wenn alle Eingaben statisch sind, spielt es keine Rolle, ob das Ergebnis dynamisch ist oder nicht: Statisch bekannte Dimensionen werden statisch geprüft und dynamische Dimensionen unterliegen keinen Einschränkungen.

Formabweichungen für Vorgänge, bei denen die Ausgabegröße als Operand verwendet wird

Sehen Sie sich das folgende Beispielprogramm an:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

Die Werte im Formoperanden zur Laufzeit müssen mit der Form des Ergebnisses übereinstimmen. Andernfalls ist das Verhalten nicht definiert. Das bedeutet, dass %arg0 zur Laufzeit den Wert dense<[3, 4]> : tensor<2xi32> haben muss. Wenn der Formoperand konstant ist, kann dies statisch überprüft werden. Wenn die Ergebnisform vollständig dynamisch ist, kann es keine Abweichung geben.