StableHLO-Spezifikation

Stabile HLO ist ein Vorgangssatz für High-Level-Vorgänge (HLO) in Modellen für maschinelles Lernen (ML). StableHLO dient als Übertragbarkeitsebene zwischen verschiedenen ML-Frameworks und ML-Compilern: ML-Frameworks, die StableHLO-Programme produzieren, sind mit ML-Compilern kompatibel, die StableHLO-Programme verwenden.

Unser Ziel ist es, die ML-Entwicklung zu vereinfachen und zu beschleunigen, indem wir mehr Interoperabilität zwischen verschiedenen ML-Frameworks (wie TensorFlow, JAX und PyTorch) und ML-Compilern (wie XLA und IREE) schaffen. Zu diesem Zweck enthält dieses Dokument eine Spezifikation für die Programmiersprache StableHLO.

Diese Spezifikation enthält drei Hauptabschnitte. Zuerst wird im Abschnitt Programme die Struktur von StableHLO-Programmen beschrieben, die aus StableHLO-Funktionen, die wiederum aus StableHLO-Vorgängen bestehen, bestehen. Innerhalb dieser Struktur legt der Abschnitt Ops die Semantik einzelner Vorgänge fest. Der Abschnitt Ausführung enthält die Semantik für alle diese Vorgänge, die zusammen innerhalb eines Programms ausgeführt werden. Schließlich wird im Abschnitt Notation die in der Spezifikation verwendete Notation erläutert.

Programme

Program ::= {Func}

StableHLO-Programme bestehen aus einer beliebigen Anzahl von StableHLO-Funktionen. Unten sehen Sie ein Beispielprogramm mit der Funktion @main, die 3 Eingaben (%image, %weights und %bias) und 1 Ausgabe hat. Der Hauptteil der Funktion hat 6 Operationen.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Funktionen

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

StableHLO-Funktionen (auch als benannte Funktionen bezeichnet) haben eine Kennung, Ein-/Ausgaben und einen Textkörper. Für die Zukunft planen wir, zusätzliche Metadaten für Funktionen einzuführen, um die Kompatibilität mit HLO zu verbessern (#425, #626, #740, #744).

IDs

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

StableHLO-Kennungen ähneln Kennungen in vielen Programmiersprachen, mit zwei Besonderheiten: 1) Alle Kennungen haben Siegel, die verschiedene Arten von Kennungen unterscheiden, 2) Wertkennungen können vollständig numerisch sein, um die Generierung von StableHLO-Programmen zu vereinfachen.

Typen

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

StableHLO-Typen werden in Werttypen (auch als erstklassige Typen bezeichnet) kategorisiert, die StableHLO-Werte darstellen und nicht wertbezogene Typen, die andere Programmelemente beschreiben. Stabile HLO-Typen ähneln Typen in vielen Programmiersprachen, wobei die Hauptbesonderheit die domainspezifische Art von StableHLO ist, was zu einigen ungewöhnlichen Ergebnissen führt (z.B. sind skalare Typen keine Werttypen).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}

Tensor-Typen stehen für Tensoren, also mehrdimensionale Arrays. Sie haben eine Form und einen Elementtyp, wobei eine Form nicht negative Dimensionsgrößen in aufsteigender Reihenfolge der entsprechenden Abmessungen darstellt (auch als Achsen bezeichnet), die von 0 bis R-1 nummeriert sind. Die Anzahl der R-Dimensionen wird als Rang bezeichnet. Zum Beispiel ist tensor<2x3xf32> ein Tensortyp mit der Form 2x3 und dem Elementtyp f32. Sie hat zwei Dimensionen (oder mit anderen Worten, zwei Achsen) – 0. Dimension und 1. Dimension – deren Größe 2 und 3 sind. Es hat den Rang 2.

Damit wird die Unterstützung für statische Formen definiert, wenn die Dimensionsgrößen statisch bekannt sind. Für die Zukunft planen wir, auch dynamische Formen zu unterstützen, bei denen die Dimensionsgrößen teilweise oder vollständig unbekannt sind (Nr. 8). Außerdem planen wir, Tensortypen über Dimensionsgrößen und Elementtypen hinaus zu erweitern, um beispielsweise Layouts (#629) und Datendichte (#1078) einzubeziehen.

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Name Typ Einschränkung
storage_type Ganzzahltyp (C1–C4), (C9)
storage_min Ganzzahlkonstante (C2), (C4), (C8)
storage_max Ganzzahlkonstante (C3), (C4), (C8)
expressed_type Gleitkommatyp (C1), (C5)
quantization_dimension Optionale Ganzzahlkonstante (C11–C13)
scales variadische Zahl von Gleitkommakonstanten (C5-C7), (C10), (C11), (C13)
zero_points variadische Zahl ganzzahliger Konstanten (C8–C10)

Quantisierte Elementtypen stellen Ganzzahlwerte eines Speichertyps im Bereich von storage_min bis storage_max (einschließlich) dar, die Gleitkommawerten eines ausgedrückten Typs entsprechen. Für einen gegebenen ganzzahligen Wert i kann der entsprechende Gleitkommawert f als f = (i - zero_point) * scale berechnet werden, wobei scale und zero_point als Quantisierungsparameter bezeichnet werden. storage_min und storage_max sind in der Grammatik optional, haben aber die Standardwerte min_value(storage_type) bzw. max_value(storage_type). Quantisierte Elementtypen unterliegen den folgenden Einschränkungen:

  • (C1) num_bits(storage_type) < num_bits(expressed_type).
  • (C2) type(storage_min) = storage_type.
  • (C3) type(storage_max) = storage_type.
  • (C4) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C5) type(scales...) = expressed_type.
  • (C6) 0 < scales.
  • (C7) is_finite(scales...).
  • (C8) storage_min <= zero_points <= storage_max.
  • (C9) type(zero_points...) = storage_type.
  • (C10) size(scales) = size(zero_points).
  • (C11) Wenn is_empty(quantization_dimension), dann size(scales) = 1.
  • (C12) 0 <= quantization_dimension.

Derzeit ist QuantizationScale eine Gleitkommakonstante. Es besteht jedoch ein großes Interesse an ganzzahligen Skalen, die durch Multiplikatoren und Verschiebungen dargestellt werden. Dies ist jedoch für die Zukunft geplant (#1404).

Es wird derzeit über die Semantik von QuantizationZeroPoint erörtert, einschließlich Typ, Werte und ob es nur einen oder potenziell mehrere Nullpunkte in einem quantisierten Tensortyp geben kann. Basierend auf den Ergebnissen dieser Diskussion kann sich die Angabe um die Nullpunkte in Zukunft ändern (#1405).

In einer weiteren laufenden Diskussion geht es um die Semantik von QuantizationStorageMin und QuantizationStorageMax, um zu bestimmen, ob diese Werte und die Werte quantisierter Tensoren Einschränkungen unterliegen (#1406).

Außerdem planen wir, die Darstellung unbekannter Skalen und Nullpunkte zu untersuchen, ähnlich wie die Darstellung unbekannter Dimensionsgrößen (#1407).

Quantisierte Tensortypen stellen Tensoren mit quantisierten Elementen dar. Diese Tensoren sind genau die gleichen wie reguläre Tensoren, mit der Ausnahme, dass ihre Elemente quantisierte Elementtypen anstelle von regulären Elementtypen haben.

Bei quantisierten Tensoren kann die Quantisierung pro Tensor erfolgen, d. h., sie kann einen scale und zero_point für den gesamten Tensor oder pro Achse haben, d. h. mit mehreren scales und zero_points, einem Paar pro Segment einer bestimmten Dimension quantization_dimension. Formaler gibt es in einem Tensor-t mit Quantisierung pro Achse dim(t, quantization_dimension)-Slices von quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] usw. Alle Elemente im i. Slice verwenden scales[i] und zero_points[i] als Quantisierungsparameter. Quantisierte Tensortypen haben die folgenden Einschränkungen:

  • Für die Quantisierung pro Tensor:
    • Keine zusätzlichen Einschränkungen.
  • Für die Quantisierung pro Achse:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Tokentypen stellen Tokens dar, also intransparente Werte, die von einigen Vorgängen erzeugt und konsumiert werden. Tokens werden verwendet, um die Ausführungsreihenfolge für Vorgänge festzulegen, wie im Abschnitt Ausführung beschrieben.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Tupeltypen stellen Tupel dar, d.h. heterogene Listen. Tupel sind ein Legacy-Feature, das nur für die Kompatibilität mit HLO existiert. In HLO werden Tupel verwendet, um variadische Eingaben und Ausgaben darzustellen. In StableHLO werden variadische Eingaben und Ausgaben nativ unterstützt. In StableHLO werden Tupel nur vollständig verwendet, wobei z. B. T, tuple<T> und tuple<tuple<T>> je nach Implementierung erheblich voneinander abweichen können. Wir planen, in Zukunft Änderungen an HLO ABI vorzunehmen, die es uns ermöglichen könnten, Tupeltypen aus StableHLO zu entfernen (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Elementtypen stellen Elemente von Tensortypen dar. Im Gegensatz zu vielen Programmiersprachen sind diese Typen in StableHLO nicht die erste Klasse. Das bedeutet, dass StableHLO-Programme Werte dieser Typen nicht direkt darstellen können. Daher ist es idiomatisch, skalare Werte vom Typ T mit 0-dimensionalen Tensorwerten vom Typ tensor<T> darzustellen.

  • Boolescher Typ stellt die booleschen Werte true und false dar.
  • Ganzzahltypen können entweder vorzeichenbehaftet (si) oder vorzeichenlos (ui) sein und eine der unterstützten Bitbreiten (4, 8, 16, 32 oder 64) haben. Vorzeichenbehaftete Typen (siN) stellen Ganzzahlwerte von -2^(N-1) bis einschließlich 2^(N-1)-1 dar. Typen ohne Vorzeichen (uiN) stellen Ganzzahlwerte von 0 bis einschließlich 2^N-1 dar.
  • Gleitkommatypen können einer der folgenden sein:
  • Komplexe Typen stellen komplexe Werte dar, die einen reellen Teil und einen imaginären Teil desselben Elementtyps haben. Unterstützte komplexe Typen sind complex<f32> (beide Teile sind vom Typ f32) und complex<f64> (beide Teile sind vom Typ f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Funktionstypen stellen sowohl benannte als auch anonyme Funktionen dar. Sie haben Eingabetypen (die Liste der Typen auf der linken Seite von ->) und Ausgabetypen (die Liste der Typen auf der rechten Seite von ->). In vielen Programmiersprachen sind Funktionstypen erste Klasse, aber nicht in StableHLO.

StringType ::= 'string'

Der Stringtyp stellt Bytesequenzen dar. Im Gegensatz zu vielen Programmiersprachen ist der Stringtyp in StableHLO nicht die erste Klasse, sondern wird nur verwendet, um statische Metadaten für Programmelemente anzugeben.

Operations

StableHLO-Vorgänge (auch Vorgänge genannt) stellen eine geschlossene Gruppe von übergeordneten Vorgängen in Modellen für maschinelles Lernen dar. Wie oben erwähnt, orientiert sich die StableHLO-Syntax stark an MLIR, die nicht unbedingt die ergonomische Alternative ist, aber wohl am besten zum Ziel von StableHLO passt, das Ziel von StableHLO, mehr Interoperabilität zwischen ML-Frameworks und ML-Compilern zu erreichen.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

StableHLO-Vorgänge (die auch als Vorgänge bezeichnet werden) haben einen Namen, Ein-/Ausgaben und eine Signatur. Der Name besteht aus dem Präfix stablehlo. und einer Mnemonic, die eine der unterstützten Vorgänge eindeutig identifiziert. Unten finden Sie eine umfassende Liste aller unterstützten Vorgänge.

Derzeit enthalten StableHLO-Programme in der Wildnis manchmal Vorgänge, die in diesem Dokument nicht beschrieben werden. Für die Zukunft planen wir, diese Vorgänge entweder in das StableHLO-Opset aufzunehmen oder zu verhindern, dass sie in StableHLO-Programmen erscheinen. In der Zwischenzeit finden Sie hier eine Liste dieser Vorgänge:

  • builtin.module, func.func, func.call und func.return (#425).
  • chlo-Vorgänge (#602)
  • „Nicht in HLO“-Kategorie von StableHLO-Vorgängen – sie waren ursprünglich Teil des StableHLO-Vorgangs, wurden aber später als nicht richtig erwiesen: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (Nr. 3).
  • Kategorie „Dynamism“ von StableHLO-Vorgängen – sie wurden von MHLO gestartet, aber wir haben sie noch nicht angegeben: compute_reshape_shape, cstr_reshapable, dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, real_dynamic_slice, set_dimension_size (Nr. 8).
  • Formberechnungen, einschließlich arith-, shape- und tensor-Vorgängen (Nr. 8).
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Ops verarbeiten Eingaben und generieren Ausgaben. Eingaben werden in Eingabewerte (berechnet während der Ausführung), Eingabefunktionen (statisch bereitgestellt, weil in StableHLO bei Funktionen keine Werte der ersten Klasse sind) und Eingabeattribute (ebenfalls statisch bereitgestellt) kategorisiert. Die Art der Ein- und Ausgaben, die von einem Vorgang verbraucht und erzeugt werden, hängt von seiner Gedächtnisstütze ab. Beispielsweise verbraucht der Vorgang add zwei Eingabewerte und erzeugt einen Ausgabewert. Im Gegensatz dazu verbraucht die select_and_scatter-Operation 3 Eingabewerte, 2 Eingabefunktionen und 3 Eingabeattribute.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Eingabefunktionen (auch als anonyme Funktionen bezeichnet) sind den benannten Funktionen sehr ähnlich, mit folgenden Unterschieden: 1) Sie haben keine Kennung (also der Name „anonymous“), 2) sie deklarieren keine Ausgabetypen (Ausgabetypen werden aus dem return-Vorgang in der Funktion abgeleitet).

Die Syntax für Eingabefunktionen enthält einen aktuell nicht verwendeten Teil (siehe Unused-Produktion oben), der aus Gründen der Kompatibilität mit MLIR dient. Beim MLIR gibt es ein allgemeineres Konzept von "Regionen", in denen mehrere "Blocks" von Operationen enthalten sein können, die über Jump-Operationen miteinander verbunden sind. Diese Blöcke haben IDs, die der Unused-Produktion entsprechen, damit sie voneinander unterschieden werden können. StableHLO hat keine Jump-Ops, sodass der entsprechende Teil der MLIR-Syntax nicht verwendet wird (aber immer noch vorhanden ist).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Eingabeattribute haben einen Namen und einen Wert, der eine der unterstützten Konstanten ist. Sie sind die primäre Möglichkeit, statische Metadaten für Programmelemente anzugeben. Der Vorgang concatenate verwendet beispielsweise das Attribut dimension, um die Dimension anzugeben, mit der die Eingabewerte verkettet werden. In ähnlicher Weise verwendet der slice-Vorgang mehrere Attribute wie start_indices und limit_indices, um die Grenzen anzugeben, die zum Aufteilen des Eingabewerts verwendet werden.

Derzeit enthalten StableHLO-Programme in der Wildnis manchmal Attribute, die in diesem Dokument nicht beschrieben werden. Wir planen, diese Attribute in Zukunft entweder in das StableHLO-Opset aufzunehmen oder zu verhindern, dass sie in StableHLO-Programmen erscheinen. In der Zwischenzeit finden Sie hier die Liste dieser Attribute:

  • layout (#629)
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619)
  • output_operand_aliases (#740)
  • Standortmetadaten (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

Die Vorgangssignatur besteht aus den Typen aller Eingabewerte (die Liste der Typen auf der linken Seite von ->) und den Typen aller Ausgabewerte (die Liste der Typen auf der rechten Seite von ->). Streng genommen sind Eingabetypen redundant und Ausgabetypen fast immer redundant (da Ausgabetypen bei den meisten StableHLO-Vorgängen aus Eingaben abgeleitet werden können). Dennoch ist die Vorgangssignatur bewusst Teil der StableHLO-Syntax, da sie mit MLIR kompatibel ist.

Unten sehen Sie ein Beispiel für eine Operation, deren Gedächtnisstütze select_and_scatter ist. Er verbraucht 3 Eingabewerte (%operand, %source und %init_value), 2 Eingabefunktionen und 3 Eingabeattribute (window_dimensions, window_strides und padding). Beachten Sie, dass die Signatur des Vorgangs nur die Typen seiner Eingabewerte enthält (aber nicht die Typen von Eingabefunktionen und -attributen, die inline bereitgestellt werden).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Konstanten

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

StableHLO-Konstanten haben ein Literal und einen Typ, die zusammen einen StableHLO-Wert darstellen. Im Allgemeinen ist der Typ Teil der konstanten Syntax, es sei denn, er ist eindeutig (z.B. hat eine boolesche Konstante eindeutig den Typ i1, während eine Ganzzahlkonstante mehrere mögliche Typen haben kann).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Boolesche Konstanten stellen die booleschen Werte true und false dar. Boolesche Konstanten haben den Typ i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Ganzzahlkonstanten stellen Ganzzahlwerte über Strings in Dezimal- oder Hexadezimalschreibweise dar. Andere Basen, z.B. binäre oder Oktalwerte, werden nicht unterstützt. Ganzzahlkonstanten haben die folgenden Einschränkungen:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Gleitkommakonstanten stellen Gleitkommawerte über Strings in Dezimal oder wissenschaftlicher Notation dar. Außerdem kann die Hexadezimalschreibweise verwendet werden, um die zugrunde liegenden Bits direkt im Gleitkommaformat des entsprechenden Typs anzugeben. Für Gleitkommakonstanten gelten folgende Einschränkungen:

  • (C1) Wenn die nicht-hexadezimale Notation verwendet wird, is_wellformed(float_literal, float_type).
  • (C2) Wenn die Hexadezimalschreibweise verwendet wird: size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Komplexe Konstanten stellen komplexe Werte mithilfe von Listen aus einem reellen Teil (kommt zuerst) und eines imaginären Teils (zweiter) dar. Beispiel: (1.0, 0.0) : complex<f32> steht für 1.0 + 0.0i und (0.0, 1.0) : complex<f32> für 0.0 + 1.0i. Die Reihenfolge, in der diese Teile dann im Speicher gespeichert werden, hängt von der Implementierung ab. Für komplexe Konstanten gelten die folgenden Einschränkungen:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Tensor-Konstanten stellen Tensorwerte mithilfe verschachtelter Listen dar, die über die NumPy-Notation angegeben werden. dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> stellt beispielsweise einen Tensorwert mit der folgenden Zuordnung von Indizes zu Elementen dar: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. Die Reihenfolge, in der diese Elemente dann im Arbeitsspeicher gespeichert werden, ist durch die Implementierung definiert. Für Tensorkonstanten gelten folgende Einschränkungen:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), wobei:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), wobei:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • Andernfalls false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Quantisierte Tensorkonstanten stellen quantisierte Tensorwerte mit derselben Notation wie Tensorkonstanten dar, wobei Elemente als Konstanten ihres Speichertyps angegeben werden. Quantisierte Tensorkonstanten haben die folgenden Einschränkungen:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Stringliterale bestehen aus Byte, die mit ASCII-Zeichen und Escape-Sequenzen angegeben werden. Sie sind codierungsunabhängig, sodass die Interpretation dieser Byte implementierungsspezifisch definiert ist. Stringliterale haben den Typ string.

Operativer Betrieb

abs

Semantik

Führt die elementweise Abs-Operationen für den operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für vorzeichenbehaftete Ganzzahlen: ganzzahliger Modulus.
  • Für Gleitkommazahlen: abs aus IEEE-754.
  • Bei komplexen Zahlen: komplexer Modulus.
  • Für quantisierte Typen: dequantize_op_quantize(abs, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor einer vorzeichenbehafteten Ganzzahl, Gleitkommazahl oder eines komplexen Typs oder eines pro Tensor quantisierten Tensors (C1–C2)

Ausgaben

Name Typ Einschränkung
result Tensor einer vorzeichenbehafteten Ganzzahl oder Gleitkommazahl oder pro Tensor quantisierter Tensor (C1–C2)

Einschränkung

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) ist so definiert:
    • complex_element_type(element_type(operand)), wenn is_complex(operand).
    • Andernfalls baseline_element_type(operand).

Beispiele

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

Weitere Beispiele

Hinzufügen

Semantik

Führt die elementweise Addition der beiden Tensoren lhs und rhs durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei booleschen Werten: logisches ODER
  • Für Ganzzahlen: ganzzahlige Addition
  • Für Gleitkommazahlen: addition aus IEEE-754.
  • Bei komplexen Zahlen: komplexe Addition.
  • Für quantisierte Typen: dequantize_op_quantize(add, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor oder quantisierter Tensor pro Tensor (C1)
(I2) rhs Tensor oder quantisierter Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Weitere Beispiele

after_all

Semantik

Damit wird sichergestellt, dass die Vorgänge, die inputs generieren, vor Vorgängen ausgeführt werden, die von result abhängen. Die Ausführung dieses Vorgangs bewirkt nichts, sondern dient nur zum Einrichten von Datenabhängigkeiten von result bis inputs.

Eingaben

Label Name Typ
(I1) inputs variadische Anzahl von token

Ausgaben

Name Typ
result token

Beispiele

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

Weitere Beispiele

all_gather

Semantik

Verkettet in jeder Prozessgruppe im StableHLO-Prozessraster die Werte des Tensors operand von jedem Prozess entlang von all_gather_dim und erzeugt einen result-Tensor.

Der Vorgang teilt das StableHLO-Prozessraster in process_groups auf, die so definiert ist:

  • cross_replica(replica_groups) wenn channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) wenn channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) wenn channel_id > 0 and use_global_device_ids = true.

Danach geschieht in jedem process_group Folgendes:

  • operands@receiver = [operand@sender for sender in process_group] für alle receiver in process_group.
  • result@process = concatenate(operands@process, all_gather_dim) für alle process in process_group.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1), (C6)
(I2) all_gather_dim Konstante vom Typ si64 (C1), (C6)
(I3) replica_groups 2-dimensionale Tensorkonstante vom Typ si64 (C2–C4)
(I4) channel_id Konstante vom Typ si64 (C5)
(I5) use_global_device_ids Konstante vom Typ i1 (C5)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C6)

Einschränkung

  • (C1) 0 <= all_gather_dim < rank(operand).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) ist so definiert:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_replicas, wenn cross_replica_and_partition verwendet wird.
    • num_processes, wenn flattened_ids verwendet wird.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Wenn use_global_device_ids = true, dann channel_id > 0.
  • (C6) type(result) = type(operand) außer:
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

Beispiele

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

Weitere Beispiele

all_reduce

Semantik

Wendet in jeder Prozessgruppe im StableHLO-Prozessraster die Reduktionsfunktion computation auf die Werte des Tensors operand aus jedem Prozess an und erzeugt einen result-Tensor.

Der Vorgang teilt das StableHLO-Prozessraster in process_groups auf, die so definiert ist:

  • cross_replica(replica_groups) wenn channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) wenn channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) wenn channel_id > 0 and use_global_device_ids = true.

Danach geschieht in jedem process_group Folgendes:

  • result@process[result_index] = exec(schedule) für eine binäre Baumstruktur schedule, wobei Folgendes gilt:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule ist ein implementierungsdefinierter Binärbaum, dessen Durchlauf in der Reihenfolge to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0])) ist.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C5), (C6)
(I2) replica_groups variadische Anzahl eindimensionaler Tensorkonstanten vom Typ si64 (C1–C3)
(I3) channel_id Konstante vom Typ si64 (C4)
(I4) use_global_device_ids Konstante vom Typ i1 (C4)
(I5) computation Funktion (C5)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C6–C7)

Einschränkung

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) ist so definiert:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_replicas, wenn cross_replica_and_partition verwendet wird.
    • num_processes, wenn flattened_ids verwendet wird.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Wenn use_global_device_ids = true, dann channel_id > 0.
  • (C5) computation hat den Typ (tensor<E>, tensor<E>) -> (tensor<E>), wobei is_promotable(element_type(operand), E) ist.
  • (C6) shape(result) = shape(operand).
  • (C7) element_type(result) = E.

Beispiele

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

Weitere Beispiele

all_to_all

Semantik

Teilt in jeder Prozessgruppe im StableHLO-Prozessraster die Werte des Tensors operand entlang von split_dimension in Teile auf, verteilt sie auf die Prozesse, verkettet die verteilten Teile entlang von concat_dimension und erzeugt einen result-Tensor.

Der Vorgang teilt das StableHLO-Prozessraster in process_groups auf, die so definiert ist:

  • cross_replica(replica_groups), wenn channel_id <= 0.
  • cross_partition(replica_groups), wenn channel_id > 0.

Danach geschieht in jedem process_group Folgendes:

  • split_parts@sender = split(operand@sender, split_count, split_dimension) für alle sender in process_group.
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group], wobei receiver_index = process_group.index(receiver).
  • result@process = concatenate(scattered_parts@process, concat_dimension).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1–C3), (C9)
(I2) split_dimension Konstante vom Typ si64 (C1), (C2), (C9)
(I3) concat_dimension Konstante vom Typ si64 (C3), (C9)
(I4) split_count Konstante vom Typ si64 (C2), (C4), (C8), (C9)
(I5) replica_groups 2-dimensionale Tensorkonstante vom Typ si64 (C5–C8)
(I6) channel_id Konstante vom Typ si64

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C9)

Einschränkung

  • (C1) 0 <= split_dimension < rank(operand).
  • (C2) dim(operand, split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operand).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) ist so definiert:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_partitions, wenn cross_partition verwendet wird.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(result) = type(operand) außer:
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

Beispiele

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

Weitere Beispiele

sowie

Semantik

Führt ein elementweises AND von zwei Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei booleschen Werten: logisches UND.
  • Für Ganzzahlen: bitweises UND.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor vom Typ „Boolescher Wert“ oder „Integer“ (C1)
(I2) rhs Tensor vom Typ „Boolescher Wert“ oder „Integer“ (C1)

Ausgaben

Name Typ Einschränkung
result Tensor vom Typ „Boolescher Wert“ oder „Integer“ (C1)

Einschränkung

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

atan2

Semantik

Führt eine elementweise atan2-Operation für den lhs- und rhs-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: atan2 aus IEEE-754.
  • Bei komplexen Zahlen: komplex atan2.
  • Für quantisierte Typen: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)
(I2) rhs Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Weitere Beispiele

batch_norm_grad

Semantik

Berechnet Gradienten mehrerer Eingaben von batch_norm_training, die von grad_output zurückpropagiert werden, und erzeugt die Tensoren grad_operand, grad_scale und grad_offset. Formaler kann dieser Vorgang als Zerlegung vorhandener StableHLO-Vorgänge mithilfe der Python-Syntax wie folgt ausgedrückt werden:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Führt für quantisierte Typen den Befehl dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1–C3), (C5)
(I2) scale Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C2), (C4), (C5)
(I3) mean Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C2), (C4)
(I4) variance Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C2), (C4)
(I5) grad_output Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C2), (C3)
(I6) epsilon Konstante vom Typ f32
(I7) feature_index Konstante vom Typ si64 (C1), (C5)

Ausgaben

Name Typ Einschränkung
grad_operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C2), (C3)
grad_scale Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C2), (C4)
grad_offset Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C2), (C4)

Einschränkung

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale und grad_offset haben dasselbe baseline_element_type.
  • (C3) operand, grad_output und grad_operand haben die gleiche Form.
  • (C4) scale, mean, variance, grad_scale und grad_offset haben die gleiche Form.
  • (C5) size(scale) = dim(operand, feature_index).

Beispiele

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semantik

Normalisiert den operand-Tensor in allen Dimensionen mit Ausnahme der feature_index-Dimension und erzeugt einen result-Tensor. Formaler kann dieser Vorgang als Zerlegung vorhandener StableHLO-Vorgänge mithilfe der Python-Syntax ausgedrückt werden:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Führt für quantisierte Typen den Befehl dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1–C7)
(I2) scale Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C2), (C3)
(I3) offset Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C2), (C4)
(I4) mean Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C5)
(I5) variance Eindimensionaler Tensor vom Gleitkomma- oder Pro-Tensor-quantisierten Typ (C2), (C6)
(I6) epsilon Konstante vom Typ f32
(I7) feature_index Konstante vom Typ si64 (C1), (C3–C6)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C2), (C7)

Einschränkung

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance und result haben denselben baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semantik

Berechnet den Mittelwert und die Varianz über alle Dimensionen mit Ausnahme der Dimension feature_index und normalisiert den Tensor operand, der die Tensoren output, batch_mean und batch_var erzeugt. Formaler kann dieser Vorgang als Zerlegung vorhandener StableHLO-Vorgänge mithilfe der Python-Syntax so ausgedrückt werden:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Führt für quantisierte Typen den Befehl dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)
(I2) scale Eindimensionaler Tensor von Gleitkomma oder pro Tensor quantisiertem Tensor (C2), (C3)
(I3) offset Eindimensionaler Tensor von Gleitkomma oder pro Tensor quantisiertem Tensor (C2), (C4)
(I4) epsilon Konstante vom Typ f32 (C1), (C3–C6)
(I5) feature_index Konstante vom Typ si64 (C1), (C3–C6)

Ausgaben

Name Typ Einschränkung
output Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C7)
batch_mean Eindimensionaler Tensor von Gleitkomma oder pro Tensor quantisiertem Tensor (C2), (C5)
batch_var Eindimensionaler Tensor von Gleitkomma oder pro Tensor quantisiertem Tensor (C2), (C6)

Einschränkung

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var und output haben denselben baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Beispiele

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semantik

Führt eine Bitcast-Operation auf dem operand-Tensor aus und erzeugt einen result-Tensor, bei dem die Bits des gesamten operand-Tensors mit dem Typ des result-Tensors neu interpretiert werden.

Formeller anhand von E = element_type(operand), E' = element_type(result) und R = rank(operand):

  • Wenn num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Wenn num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Wenn num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits gibt die speicherinterne Darstellung eines bestimmten Werts zurück. Sein Verhalten ist implementierungsabhängig, da die genaue Darstellung der Tensoren durch die Implementierung sowie die genaue Darstellung von Elementtypen ebenfalls durch die Implementierung definiert wird.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor (C1–C2)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor (C1–C2)

Einschränkung

  • (C1) Für E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) und R = rank(operand) gilt:
    • Wenn num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Wenn num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) für alle 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Wenn num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) für alle 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Wenn is_complex(operand) or is_complex(result), dann is_complex(operand) and is_complex(result).

Beispiele

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Weitere Beispiele

broadcast_in_dim

Semantik

Erweitert die Dimensionen und/oder den Rang eines Eingabetensors durch Duplizieren der Daten im operand-Tensor und Erzeugt einen result-Tensor. Formaler gesagt: result[result_index] = operand[operand_index], wobei für alle d in axes(operand) Folgendes gilt:

  • operand_index[d] = 0, wenn dim(operand, d) = 1.
  • Andernfalls operand_index[d] = result_index[broadcast_dimensions[d]].

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor (C1–C2), (C5–C6)
(I2) broadcast_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C2–C6)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor (C1), (C3), (C5-C6)

Einschränkung

  • (C1) element_type(result) wird gegeben durch:
    • element_type(operand), wenn !is_per_axis_quantized(operand).
    • element_type(operand), allerdings können sich quantization_dimension(operand), scales(operand) und zero_points(operand) von quantization_dimension(result), scales(result) und zero_points(result) unterscheiden.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Für alle d in axes(operand):
    • dim(operand, d) = 1 oder
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Wenn is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Wenn dim(operand, quantization_dimension(operand)) = 1, dann scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Beispiele

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Weitere Beispiele

Supportanfrage

Semantik

Erzeugt die Ausgabe aus der Ausführung genau einer Funktion in branches in Abhängigkeit vom Wert von index. Formell. Für result = selected_branch() gilt:

  • selected_branch = branches[index], wenn 0 <= index < size(branches).
  • Andernfalls selected_branch = branches[-1].

Eingaben

Label Name Typ Einschränkung
(I1) index 0-dimensionaler Tensor vom Typ si32
(I2) branches variadische Anzahl von Funktionen (C1–C4)

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C4)

Einschränkung

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Beispiele

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Weitere Beispiele

CBRT

Semantik

Führt eine elementweise kubische Wurzeloperation für den operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: rootn(x, 3) aus IEEE-754.
  • Bei komplexen Zahlen: komplexe Kubikwurzel.
  • Für quantisierte Typen: dequantize_op_quantize(cbrt, operand, type(result))

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

Weitere Beispiele

Ceil

Semantik

Führt einen elementweisen Ceil des operand-Tensors aus und erzeugt einen result-Tensor. Implementiert den Vorgang roundToIntegralTowardPositive aus der IEEE-754-Spezifikation. Führt für quantisierte Typen den Befehl dequantize_op_quantize(ceil, operand, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Weitere Beispiele

Cholesky

Semantik

Berechnet die Cholesky-Zersetzung eines Batches von Matrizen.

Formaler gesagt ist result[i0, ..., iR-3, :, :] für alle i in index_space(result) eine Cholesky-Zerlegung von a[i0, ..., iR-3, :, :] in Form einer Matrix für das untere Dreieck (wenn lower true ist) oder eine obere Dreiecksmatrix (wenn lower false ist). Die Ausgabewerte im gegenüberliegenden Dreieck, d. h. das strenge obere Dreieck bzw. das strikte untere Dreieck entsprechend, sind implementierungsdefiniert.

Wenn i vorhanden ist, bei der die Eingabematrix keine positive Definite Matrix des Hermitian ist, ist das Verhalten nicht definiert.

Führt für quantisierte Typen den Befehl dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) a Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1–C3)
(I2) lower 0-dimensionale Tensorkonstante vom Typ i1

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Beispiele

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

einschränken

Semantik

Bindet jedes Element des Tensors operand zwischen einem Mindest- und Höchstwert ein und erzeugt einen result-Tensor. Formaler: result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), wobei min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Führt für quantisierte Typen den Befehl dequantize_op_quantize(clamp, min, operand, max, type(result)) aus.

Die Anordnung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, die Unterstützung komplexer Zahlen für diese Operation (#560) in Zukunft einzustellen.

Eingaben

Label Name Typ Einschränkung
(I1) min Tensor oder quantisierter Tensor pro Tensor (C1), (C3)
(I2) operand Tensor oder quantisierter Tensor pro Tensor (C1–C4)
(I3) max Tensor oder quantisierter Tensor pro Tensor (C2), (C3)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C4)

Einschränkung

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Beispiele

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Weitere Beispiele

collective_broadcast

Semantik

Senden Sie in jeder Prozessgruppe im StableHLO-Prozessraster den Wert des Tensors operand vom Quellprozess an die Zielprozesse und erzeugen Sie einen result-Tensor.

Der Vorgang teilt das StableHLO-Prozessraster in process_groups auf, die so definiert ist:

  • cross_replica(replica_groups), wenn channel_id <= 0.
  • cross_partition(replica_groups), wenn channel_id > 0.

Danach wird result@process angegeben durch:

  • operand@process_groups[i, 0], wenn ein i vorhanden ist, sodass sich der Prozess in process_groups[i] befindet.
  • Andernfalls broadcast_in_dim(constant(0, element_type(result)), [], type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor (C3)
(I2) replica_groups variadische Anzahl eindimensionaler Tensorkonstanten vom Typ si64 (C1), (C2)
(I3) channel_id Konstante vom Typ si64

Ausgaben

Name Typ Einschränkung
result Tensor (C3)

Einschränkung

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, wobei N so definiert ist:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_partitions, wenn cross_partition verwendet wird.
  • (C3) type(result) = type(operand).

Beispiele

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semantik

Sendet in jeder Prozessgruppe im StableHLO-Prozessraster den Wert des Tensors operand vom Quellprozess an den Zielprozess und erzeugt einen result-Tensor.

Der Vorgang teilt das StableHLO-Prozessraster in process_groups auf, die so definiert ist:

  • cross_replica(source_target_pairs), wenn channel_id <= 0.
  • cross_partition(source_target_pairs), wenn channel_id > 0.

Danach wird result@process angegeben durch:

  • operand@process_groups[i, 0], wenn eine i mit process_groups[i, 1] = process vorhanden ist.
  • Andernfalls broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C5)
(I2) source_target_pairs 2-dimensionale Tensorkonstante vom Typ si64 (C1–C4)
(I3) channel_id Konstante vom Typ si64

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1)

Einschränkung

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, wobei N so definiert ist:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_partitions, wenn cross_partition verwendet wird.
  • (C5) type(result) = type(operand).

Beispiele

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Weitere Beispiele

Nutzer*innen

Semantik

Führt einen elementweisen Vergleich der Tensoren lhs und rhs gemäß comparison_direction und compare_type durch und erzeugt einen result-Tensor.

Die Werte von comparison_direction und compare_type haben die folgende Semantik:

Für boolesche und ganzzahlige Elementtypen:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

Für Gleitkomma-Elementtypen mit compare_type = FLOAT implementiert der Vorgang die folgenden IEEE-754-Vorgänge:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

Bei Gleitkomma-Elementtypen mit compare_type = TOTALORDER verwendet der Vorgang die Kombination aus totalOrder- und compareQuietEqual-Vorgängen aus IEEE-754. Diese Funktion scheint ungenutzt zu sein und wird demnächst entfernt (#584).

Bei komplexen Elementtypen wird der lexikografische Vergleich von (real, imag)-Paaren mit den bereitgestellten comparison_direction und compare_type durchgeführt. Die Anordnung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, die Unterstützung für komplexe Zahlen in Zukunft einzustellen, wenn comparison_direction GE, GT, LE oder LT ist (#560).

Für quantisierte Typen. Führt dequantize_compare(lhs, rhs, comparison_direction) aus.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor oder quantisierter Tensor pro Tensor (C1–C3)
(I2) rhs Tensor oder quantisierter Tensor pro Tensor (C1–C2)
(I3) comparison_direction Aufzählung von EQ, NE, GE, GT, LE und LT
(I4) compare_type Aufzählung von FLOAT, TOTALORDER, SIGNED und UNSIGNED (C3)

Ausgaben

Name Typ Einschränkung
result Tensor des booleschen Typs (C2)

Einschränkung

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type ist so definiert:
    • SIGNED, wenn is_signed_integer(element_type(lhs)).
    • UNSIGNED, wenn is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT oder TOTALORDER, wenn is_float(element_type(lhs)).
    • FLOAT, wenn is_complex(element_type(lhs)).

Beispiele

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

Weitere Beispiele

komplex

Semantik

Führt eine elementweise Umwandlung in einen komplexen Wert aus einem Paar reeller und imaginärer Werte, lhs und rhs, durch und erzeugt einen result-Tensor.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor vom Typ f32 oder f64 (C1–C3)
(I2) rhs Tensor vom Typ f32 oder f64 (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des komplexen Typs (C2), (C3)

Einschränkung

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) hat den Typ complex<E>, wobei E = element_type(lhs) ist.

Beispiele

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

Weitere Beispiele

concatenate

Semantik

Verkettet inputs entlang der Dimension dimension in der gleichen Reihenfolge wie die angegebenen Argumente und erzeugt einen result-Tensor. Formeller gesagt, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], wobei:

  1. id = d0 + ... + dk-1 + kd.
  2. d ist gleich dimension und d0, ... sind die d. Dimensionsgrößen von inputs.

Eingaben

Label Name Typ Einschränkung
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C6)
(I2) dimension Konstante vom Typ si64 (C2), (C4), (C6)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C5–C6)

Einschränkung

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) mit Ausnahme von dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) mit Ausnahme von:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Beispiele

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Weitere Beispiele

Konstante

Semantik

Erzeugt einen output-Tensor aus einer konstanten value.

Eingaben

Label Name Typ Einschränkung
(I1) value Konstante (C1)

Ausgaben

Name Typ Einschränkung
output Tensor oder quantisierter Tensor (C1)

Einschränkung

  • (C1) type(value) = type(output).

Beispiele

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Weitere Beispiele

eine Conversion ausführen

Semantik

Führt eine elementweise Umwandlung von einem Elementtyp in einen anderen auf dem operand-Tensor durch und erzeugt einen result-Tensor.

Bei Conversions vom Typ boolean-to-any-supported-type wird der Wert false in null und der Wert true in eins umgewandelt. Bei any-supported-type-to-boolean wird ein Nullwert in false und Werte ungleich null in true umgewandelt. Im Folgenden erfahren Sie, wie dies bei komplexen Typen funktioniert.

Bei Konvertierungen mit integer-to-integer, integer-to-floating-point oder floating-point-to-floating-point und wenn der Quellwert im Zieltyp genau dargestellt werden kann, ist der Ergebniswert genau diese Darstellung. Andernfalls ist das Verhalten noch nicht festgelegt (#180).

Bei Konvertierungen mit floating-point-to-integer wird der Bruchteil abgeschnitten. Wenn der abgeschnittene Wert im Zieltyp nicht dargestellt werden kann, gilt das Verhalten noch offen (#180).

Umwandlungen von komplexen bis komplex folgen demselben Verhalten wie Gleitkomma-zu-Gleitkomma-Konvertierungen zur Konvertierung von Real- und imaginären Teilen.

Bei Konvertierungen vom Typ complex-to-any-other-type und complex-to-any-other-type wird der imaginäre Quellwert ignoriert bzw. der imaginäre Zielwert auf null gesetzt. Die Umwandlung des tatsächlichen Teils folgt der Gleitkommawertkonvertierung.

Im Prinzip könnte dieser Vorgang Dequantisierung (Umwandlung von quantisierten Tensoren in reguläre Tensoren), Quantisierung (Umwandlung von regulären Tensoren in quantisierte Tensoren) und Requantisierung (Umwandlung zwischen quantisierten Tensoren) ausdrücken, aber derzeit haben wir spezielle Operationen dafür: uniform_dequantize für den ersten Anwendungsfall und uniform_quantize für den zweiten und dritten Anwendungsfall. In Zukunft werden diese beiden Vorgänge möglicherweise zu convert zusammengeführt (#1576).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor (C1)

Einschränkung

  • (C1) shape(operand) = shape(result).

Beispiele

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Weitere Beispiele

Faltung

Semantik

Berechnet Punktprodukte zwischen Fenstern von lhs und Segmenten von rhs und erzeugt result. Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result aus lhs und rhs berechnet werden.

Formeller können Sie die folgende Neuformulierung der Eingaben in Bezug auf lhs in Betracht ziehen, um Fenster von lhs auszudrücken:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Dabei werden die folgenden Hilfsfunktionen verwendet:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1], wobei j[d] = i[permutation[d]].

Wenn feature_group_count = 1 und batch_group_count = 1, dann für alle output_spatial_index in index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product, wobei:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]): Diese Funktion scheint ungenutzt zu sein und wird demnächst entfernt (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Wenn feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Wenn batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Führt für quantisierte Typen den Befehl dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor oder quantisierter Tensor pro Tensor (C1), (C10-C11), (C14) (C25), (C27-C30)
(I2) rhs Tensor oder quantisierter Tensor (C1), (C14-C16), (C25), (C27-C32)
(I3) window_strides Eindimensionale Tensorkonstante vom Typ si64 (C2–C3), (C25)
(I4) padding 2-dimensionale Tensorkonstante vom Typ si64 (C4), (C25)
(I5) lhs_dilation Eindimensionale Tensorkonstante vom Typ si64 (C5–C6), (C25)
(I6) rhs_dilation Eindimensionale Tensorkonstante vom Typ si64 (C7–C8), (C25)
(I7) window_reversal Eindimensionale Tensorkonstante vom Typ i1 (C9)
(I8) input_batch_dimension Konstante vom Typ si64 (C10), (C13), (C25)
(I9) input_feature_dimension Konstante vom Typ si64 (C11), (C13–C14)
(I10) input_spatial_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension Konstante vom Typ si64 (C14), (C18)
(I12) kernel_output_feature_dimension Konstante vom Typ si64 (C15-C16), (C18), (C25), (C32)
(I13) kernel_spatial_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C17–C18), (C25)
(I14) output_batch_dimension Konstante vom Typ si64 (C20), (C25)
(I15) output_feature_dimension Konstante vom Typ si64 (C20), (C25), (C33)
(I16) output_spatial_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C19–C20), (C25)
(I17) feature_group_count Konstante vom Typ si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count Konstante vom Typ si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config variadische Anzahl von Enums von DEFAULT, HIGH und HIGHEST (C24)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor (C25–C28), (C30–C31), (C33)

Einschränkung

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Angegebene input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Angegebene kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Angegebene output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) ist so definiert:
    • dim(lhs, input_batch_dimension) / batch_group_count, wenn result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension), wenn result_dim = output_feature_dimension.
    • Andernfalls num_windows. Dabei gilt:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Wenn der Vorgang nicht quantisierte Tensoren verwendet:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Wenn der Vorgang quantisierte Tensoren verwendet:
    • (C28) is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result).
    • (C29) storage_type(lhs) = storage_type(rhs).
    • (C30) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C31) Wenn is_per_tensor_quantized(rhs), dann is_per_tensor_quantized(result).
    • (C32) Wenn is_per_axis_quantized(rhs), dann quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C33) Wenn is_per_axis_quantized(result), dann quantization_dimension(result) = output_feature_dimension.

Beispiele

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs : [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = dense<4> : tensor<2xi64>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = dense<2> : tensor<2xi64>,
  rhs_dilation = dense<1> : tensor<2xi64>,
  window_reversal = dense<false> : tensor<2xi1>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

Kosinus

Semantik

Führt eine elementweise Kosinusoperation für den operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: cos aus IEEE-754.
  • Bei komplexen Zahlen: komplexer Kosinus.
  • Für quantisierte Typen: dequantize_op_quantize(cosine, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Weitere Beispiele

count_leading_zeros

Semantik

Führt eine elementweise Zählung der Anzahl führender Null-Bits im operand-Tensor durch und erzeugt einen result-Tensor.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Ganzzahltyps (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Ganzzahltyps (C1)

Einschränkung

  • (C1) type(operand) = type(result).

Beispiele

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Weitere Beispiele

custom_call

Semantik

Kapselt einen implementierungsdefinierten Vorgang call_target_name ein, der inputs und called_computations annimmt und results erzeugt. Mit has_side_effect, backend_config und api_version können zusätzliche durch die Implementierung definierte Metadaten bereitgestellt werden.

Derzeit enthält dieser Vorgang eine ziemlich unorganisierte Sammlung von Metadaten, die die organische Entwicklung des entsprechenden Vorgangs im XLA-Compiler widerspiegelt. Wir planen, diese Metadaten künftig zu vereinheitlichen (#741).

Eingaben

Label Name Typ
(I1) inputs variadische Anzahl von Werten
(I2) call_target_name Konstante vom Typ string
(I3) has_side_effect Konstante vom Typ i1
(I4) backend_config Konstante vom Typ string
(I5) api_version Konstante vom Typ si32
(I6) called_computations variadische Anzahl von Konstanten vom Typ string

Ausgaben

Name Typ
results variadische Anzahl von Werten

Beispiele

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

Dividieren

Semantik

Führt die elementweise Division der Tensoren lhs und Divisor rhs durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Ganzzahlen: Ganzzahldivision, die den algebraischen Quotienten erzeugt, wobei jeder Bruchteil verworfen wird.
  • Für Gleitkommazahlen: division aus IEEE-754.
  • Bei komplexen Zahlen: komplexe Division.
  • Für quantisierte Typen:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)
(I2) rhs Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Einschränkung

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Weitere Beispiele

dot_general

Semantik

Berechnet Punktprodukte zwischen Segmenten von lhs und Segmenten von rhs und erzeugt einen result-Tensor.

Formell: result[result_index] = dot_product, wobei:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index, wobei size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) und size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Führt für quantisierte Typen den Befehl dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)) aus.

Gibt nur die Semantik für die Quantisierung pro Tensor an. Die Quantisierung pro Achse wird ausgeführt (#1574). Zukünftig werden wir auch erwägen, die Unterstützung für die Hybridquantisierung hinzuzufügen (#1575).

precision_config steuert den Kompromiss zwischen Geschwindigkeit und Genauigkeit bei Berechnungen auf Beschleuniger-Back-Ends. Dabei kann es sich um einen der folgenden Werte handeln (derzeit ist die Semantik dieser enum-Werte zu niedrig, wir planen jedoch, dies in #755 zu beheben):

  • DEFAULT: schnellste Berechnung, aber am wenigsten genaue Näherung an die ursprüngliche Zahl.
  • HIGH: Langsamere Berechnung, aber genauere Annäherung an die ursprüngliche Zahl.
  • HIGHEST: langsamste Berechnung, aber genaueste Näherung an die ursprüngliche Zahl.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor oder quantisierter Tensor pro Tensor (C5-C6), (C9-C10), (C12-C16)
(I2) rhs Tensor oder quantisierter Tensor pro Tensor (C7–C10), (C12)
(I3) lhs_batching_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C2), (C4), (C8), (C10)
(I7) precision_config variadische Anzahl von Enums von DEFAULT, HIGH und HIGHEST (C11)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C12), (C14), (C16)

Einschränkung

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Wenn der Vorgang nicht quantisierte Tensoren verwendet:
    • (C13) element_type(lhs) = element_type(rhs).
  • Wenn der Vorgang quantisierte Tensoren verwendet:
    • (C14) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C15) storage_type(lhs) = storage_type(rhs).
    • (C16) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C17) zero_points(rhs) = 0.

Beispiele

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Weitere Beispiele

dynamic_slice

Semantik

Extrahiert mithilfe von dynamisch berechneten Startindexen ein Slice aus dem operand und erzeugt einen result-Tensor. start_indices enthält die Startindexe des Segments für jede Dimension, für die eine Anpassung möglich ist, und slice_sizes die Größe des Segments für jede Dimension. Formaler gesagt, result[result_index] = operand[operand_index], wobei:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1), (C2), (C4)
(I2) start_indices variadische Anzahl von 0-dimensionalen Tensoren vom Typ „Ganzzahl“ (C2), (C3)
(I3) slice_sizes Eindimensionale Tensorkonstante vom Typ si64 (C2), (C4), (C5)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1), (C5)

Einschränkung

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Beispiele

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Weitere Beispiele

dynamic_update_slice

Semantik

Erzeugt einen result-Tensor, der dem operand-Tensor entspricht, mit der Ausnahme, dass das Slice, das bei start_indices beginnt, mit den Werten in update aktualisiert wird. Formell ist result[result_index] so definiert:

  • update[update_index], wenn 0 <= update_index < shape(update), wobei:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • Andernfalls operand[result_index].

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1–C4), (C6)
(I2) update Tensor oder quantisierter Tensor pro Tensor (C2), (C3), (C6)
(I3) start_indices variadische Anzahl von 0-dimensionalen Tensoren vom Typ „Ganzzahl“ (C4), (C5)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1)

Einschränkung

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Beispiele

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Weitere Beispiele

Exponentialfunktionen

Semantik

Führt eine elementweise exponentielle Operation auf dem operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: exp aus IEEE-754.
  • Bei komplexen Zahlen: komplex exponentiell.
  • Für quantisierte Typen: dequantize_op_quantize(exponential, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Weitere Beispiele

exponential_minus_one

Semantik

Führt eine elementweise exponentielle minus eins Operation auf dem operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: expm1 aus IEEE-754.
  • Bei komplexen Zahlen: komplexes Exponential minus eins.
  • Für quantisierte Typen: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

Weitere Beispiele

fft

Semantik

Führt die Vorwärts- und Inverse Fourier-Transformationen für reale und komplexe Ein-/Ausgaben aus.

fft_type ist einer der folgenden Werte:

  • FFT: Komplexe zu komplexen FFT weiterleiten.
  • IFFT: Umgekehrte komplex-komplexe FFT.
  • RFFT: Eine reelle zu komplexe FFT weiterleiten.
  • IRFFT: Umgekehrte reell-komplexe FFT (nimmt komplex an, gibt reelle Zahlen zurück).

Formeller ausgedrückt, erzeugt die Funktion fft, die eindimensionale Tensoren komplexer Typen als Eingabe verwendet, eindimensionale Tensoren derselben Typen wie die Ausgabe und berechnet die diskrete Fourier-Transformation:

Für fft_type = FFT ist result als Endergebnis einer Reihe von L-Berechnungen definiert, wobei L = size(fft_length) ist. Beispiel für L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Wenn die Funktion ifft die gleiche Typsignatur hat und den Kehrwert von fft berechnet, gilt außerdem Folgendes:

Für fft_type = IFFT ist result als Kehrwert der Berechnungen für fft_type = FFT definiert. Beispiel für L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Wenn die Funktion rfft, die eindimensionale Tensoren von Gleitkommatypen annimmt, eindimensionale Tensoren komplexer Typen mit derselben Gleitkommasemantik erzeugt und so funktioniert:

  • rfft(real_operand) = truncated_result, wobei
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

Wenn die diskrete Fourier-Transformation für reelle Operanden berechnet wird, definieren die ersten N/2 + 1-Elemente des Ergebnisses den Rest des Ergebnisses eindeutig, sodass das Ergebnis von rfft abgeschnitten wird, um die Berechnung redundanter Elemente zu vermeiden.

Für fft_type = RFFT ist result als Endergebnis einer Reihe von L-Berechnungen definiert, wobei L = size(fft_length) ist. Beispiel für L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Wenn die Funktion irfft die gleiche Typsignatur hat und den Kehrwert von rfft berechnet, gilt Folgendes:

Für fft_type = IRFFT ist result als Kehrwert der Berechnungen für fft_type = RFFT definiert. Beispiel für L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkomma- oder komplexen Typs (C1), (C2), (C4), (C5)
(I2) fft_type Aufzählung von FFT, IFFT, RFFT und IRFFT (C2), (C5)
(I3) fft_length Eindimensionale Tensorkonstante vom Typ si64 (C1), (C3), (C4)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkomma- oder komplexen Typs (C2), (C4), (C5)

Einschränkung

  • (C1) size(fft_length) <= rank(operand).
  • (C2) Die Beziehung zwischen den Elementtypen operand und result variiert:
    • Wenn fft_type = FFT, element_type(operand) und element_type(result) denselben komplexen Typ haben.
    • Wenn fft_type = IFFT, element_type(operand) und element_type(result) denselben komplexen Typ haben.
    • Bei fft_type = RFFT ist element_type(operand) ein Gleitkommatyp und element_type(result) ein komplexer Typ derselben Gleitkommasemantik.
    • Bei fft_type = IRFFT ist element_type(operand) ein komplexer Typ und element_type(result) ein Gleitkommatyp mit derselben Gleitkommasemantik.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Wenn zwischen operand und result ein Tensor real eines Gleitkommatyps vorhanden ist, dann shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) mit Ausnahme von:
    • Wenn fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Wenn fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Beispiele

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

Etage

Semantik

Führt das elementweise Stockwerk des operand-Tensors aus und erzeugt einen result-Tensor. Implementiert den Vorgang roundToIntegralTowardNegative aus der IEEE-754-Spezifikation. Führt für quantisierte Typen den Befehl dequantize_op_quantize(floor, operand, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Weitere Beispiele

sammeln

Semantik

Erfasst Segmente des Tensors operand aus den in start_indices angegebenen Offsets und erzeugt einen result-Tensor.

Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result den Elementen in operand zugeordnet werden. Im Diagramm werden einige Beispiele für result-Indexe ausgewählt und erklärt, welchen operand-Indizes sie entsprechen.

Formeller gesagt, result[result_index] = operand[operand_index], wobei:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index ist so definiert:
    • start_indices[bi0, ..., :, ..., biN], wobei bi einzelne Elemente in batch_index sind und : in den Index index_vector_dim eingefügt wird, wenn index_vector_dim < rank(start_indices) ist.
    • Andernfalls [start_indices[batch_index]].
  • Für d_operand in axes(operand):
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) wenn d_operand = start_index_map[d_start].
    • Andernfalls full_start_index[d_operand] = 0.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], wobei oi einzelne Elemente in offset_index sind und 0 an Indizes von collapsed_slice_dims eingefügt wird.
  • operand_index = full_start_index + full_offset_index.

Wenn indices_are_sorted den Wert true hat, kann die Implementierung davon ausgehen, dass start_indices in Bezug auf start_index_map sortiert sind. Andernfalls ist das Verhalten nicht definiert. Formell für alle i1 < i2 aus indices(result), full_start_index(i1) <= full_start_index(i2).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1), (C7), (C10-C12), (C14)
(I2) start_indices Tensor des Ganzzahltyps (C2), (C3), (C13)
(I3) offset_dims Eindimensionale Tensorkonstante vom Typ si64 (C1), (C4-C5), (C13)
(I4) collapsed_slice_dims Eindimensionale Tensorkonstante vom Typ si64 (C1), (C6-C8), (C13)
(I5) start_index_map Eindimensionale Tensorkonstante vom Typ si64 (C3), (C9), (C10)
(I6) index_vector_dim Konstante vom Typ si64 (C2), (C3), (C13)
(I7) slice_sizes Eindimensionale Tensorkonstante vom Typ si64 (C8), (C11–C13)
(I8) indices_are_sorted Konstante vom Typ i1

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C5), (C13–C14)

Einschränkung

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes), wobei:
    • batch_dim_sizes = shape(start_indices), allerdings ist die Dimensionsgröße von start_indices, die index_vector_dim entspricht, nicht enthalten.
    • offset_dim_sizes = shape(slice_sizes), außer dass die Dimensionsgrößen in slice_sizes, die collapsed_slice_dims entsprechen, nicht enthalten sind.
    • combine platziert batch_dim_sizes an den Achsen, die batch_dims entsprechen, und offset_dim_sizes an den Achsen, die offset_dims entsprechen.
  • (C14) element_type(operand) = element_type(result).

Beispiele

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
  indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Weitere Beispiele

get_dimension_size

Semantik

Erzeugt die Größe der angegebenen dimension von operand. Formeller gesagt: result = dim(operand, dimension).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor (C1)
(I2) dimension Konstante vom Typ si64 (C1)

Ausgaben

Name Typ
result 0-dimensionaler Tensor vom Typ si32

Einschränkung

  • (C1) 0 <= dimension < rank(operand).

Beispiele

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Weitere Beispiele

get_tuple_element

Semantik

Extrahiert ein Element an der Position index des Tupels operand und erzeugt eine result. Formeller gesagt: result = operand[index].

Eingaben

Label Name Typ Einschränkung
(I1) operand tuple (C1), (C2)
(I2) index Konstante vom Typ si32 (C1), (C2)

Ausgaben

Name Typ Einschränkung
result alle unterstützten Typen (C2)

Einschränkung

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Beispiele

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

Weitere Beispiele

if

Semantik

Erzeugt die Ausgabe aus der Ausführung genau einer Funktion aus true_branch oder false_branch, abhängig vom Wert von pred. Formeller gesagt: result = pred ? true_branch() : false_branch().

Eingaben

Label Name Typ Einschränkung
(I1) pred 0-dimensionaler Tensor vom Typ i1
(I2) true_branch Funktion (C1–C3)
(I3) false_branch Funktion (C1), (C2)

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C3)

Einschränkung

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Beispiele

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Weitere Beispiele

Bild

Semantik

Extrahiert den Imaginärteil elementweise aus dem operand und erzeugt einen result-Tensor. Formeller für jedes Element x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkomma- oder komplexen Typs (C1), (C2)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkommatyps (C1), (C2)

Einschränkung

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) ist so definiert:
    • complex_element_type(element_type(operand)), wenn is_complex(operand).
    • Andernfalls element_type(operand).

Beispiele

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Weitere Beispiele

Einspeisung

Semantik

Liest Daten aus dem Feed und erzeugt results.

Die Semantik von infeed_config ist implementierungsdefiniert.

results bestehen aus Nutzlastwerten, die an erster Stelle stehen, und einem Token, das zuletzt steht. Zur besseren Verständlichkeit planen wir, die Nutzlast und das Token in Zukunft auf zwei separate Ausgaben aufzuteilen (#670).

Eingaben

Label Name Typ
(I1) token token
(I2) infeed_config Konstante vom Typ string

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C1–C3)

Einschränkung

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) oder is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Beispiele

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Weitere Beispiele

Iota

Semantik

Füllt einen output-Tensor mit Werten in aufsteigender Reihenfolge, beginnend bei null, entlang der iota_dimension-Dimension. Formell

output[result_index] = constant(is_quantized(output) ? quantize(result_index[iota_dimension], element_type(output)) : result_index[iota_dimension], element_type(output)).

Eingaben

Label Name Typ Einschränkung
(I1) iota_dimension si64 (C1)

Ausgaben

Name Typ Einschränkung
output Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Einschränkung

  • (C1) 0 <= iota_dimension < rank(output).

Beispiele

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Weitere Beispiele

is_finite

Semantik

Führt eine elementweise Prüfung des Werts in x durch, d.h. ist weder +Inf, -Inf noch NaN, und erzeugt einen y-Tensor. Implementiert den Vorgang isFinite aus der IEEE-754-Spezifikation. Bei quantisierten Typen ist das Ergebnis immer true.

Eingaben

Label Name Typ Einschränkung
(I1) x Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
y Tensor des booleschen Typs (C1)

Einschränkung

  • (C1) shape(x) = shape(y).

Beispiele

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

Weitere Beispiele

log

Semantik

Führt eine elementweise Logarithmusoperation auf dem Tensor operand durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: log aus IEEE-754.
  • Bei komplexen Zahlen: komplexer Logarithmus.
  • Für quantisierte Typen: dequantize_op_quantize(log, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Weitere Beispiele

log_plus_one

Semantik

Führt einen elementweisen Logarithmus plus eine Operation auf dem operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: logp1 aus IEEE-754.
  • Bei komplexen Zahlen: komplexer Logarithmus plus eins.
  • Für quantisierte Typen: dequantize_op_quantize(log_plus_one, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Weitere Beispiele

Logistik

Semantik

Führt eine elementweise logistische Operation für den operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: division(1, addition(1, exp(-x))) aus IEEE-754.
  • Bei komplexen Zahlen: komplexe logistische Daten.
  • Für quantisierte Typen: dequantize_op_quantize(logistic, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Weitere Beispiele

Karte

Semantik

Wendet eine Kartenfunktion computation auf inputs entlang der dimensions an und erzeugt einen result-Tensor.

Formeller gesagt: result[result_index] = computation(inputs...[result_index]). dimensions werden derzeit nicht verwendet und wahrscheinlich in Zukunft entfernt (#487).

Eingaben

Label Name Typ Einschränkung
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C4)
(I2) dimensions Eindimensionale Tensorkonstante vom Typ si64 (C3)
(I3) computation Funktion (C4)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1), (C4)

Einschränkung

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation hat den Typ (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, wobei Ei = element_type(inputs[i]) und E' = element_type(result) verwendet werden.

Beispiele

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Weitere Beispiele

Maximum

Semantik

Führt eine elementweise maximale Operation für die Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei booleschen Werten: logisches ODER
  • Für Ganzzahlen: ganzzahliges Maximum.
  • Für Gleitkommazahlen: maximum aus IEEE-754.
  • Bei komplexen Zahlen: das lexikografische Maximum für das (real, imaginary)-Paar. Die Anordnung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, die Unterstützung komplexer Zahlen für diese Operation (#560) in Zukunft einzustellen.
  • Für quantisierte Typen:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor oder quantisierter Tensor pro Tensor (C1)
(I2) rhs Tensor oder quantisierter Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

Weitere Beispiele

Minimum

Semantik

Führt eine elementweise minimale Operation für die Tensoren lhs und rhs durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei booleschen Werten: logisches UND.
  • Für Ganzzahlen: minimale Ganzzahl.
  • Für Gleitkommazahlen: minimum aus IEEE-754.
  • Bei komplexen Zahlen: das lexikografische Minimum für das (real, imaginary)-Paar. Die Anordnung komplexer Zahlen erfordert eine überraschende Semantik. Daher planen wir, die Unterstützung komplexer Zahlen für diese Operation (#560) in Zukunft einzustellen.
  • Für quantisierte Typen:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor oder quantisierter Tensor pro Tensor (C1)
(I2) rhs Tensor oder quantisierter Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Weitere Beispiele

Multiplizieren

Semantik

Führt ein elementweises Produkt der beiden Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei booleschen Werten: logisches UND.
  • Für Ganzzahlen: Ganzzahlmultiplikation.
  • Für Gleitkommazahlen: multiplication aus IEEE-754.
  • Bei komplexen Zahlen: komplexe Multiplikation.
  • Für quantisierte Typen:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor oder quantisierter Tensor pro Tensor (C1)
(I2) rhs Tensor oder quantisierter Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

Weitere Beispiele

negate

Semantik

Führt die elementweise Negation des Tensors operand durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei vorzeichenbehafteten Ganzzahlen: Ganzzahlennegation.
  • Für vorzeichenlose Ganzzahlen: Bitcast zu vorzeichenbehaftete Ganzzahl, Ganzzahl-Negation, Bitcast zurück in vorzeichenlose Ganzzahl.
  • Für Gleitkommazahlen: negate aus IEEE-754.
  • Bei komplexen Zahlen: komplexe Negation.
  • Für quantisierte Typen: dequantize_op_quantize(negate, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Weitere Beispiele

nicht

Semantik

Führt ein elementweises NOT des Tensors operand aus und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei booleschen Werten: logisches NOT.
  • Für Ganzzahlen: Das bitweise NOT.

Argumente

Name Typ Einschränkung
operand Tensor vom Typ „Boolescher Wert“ oder „Integer“ (C1)

Ausgaben

Name Typ Einschränkung
result Tensor vom Typ „Boolescher Wert“ oder „Integer“ (C1)

Einschränkung

  • (C1) type(operand) = type(result).

Beispiele

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

optimization_barrier

Semantik

Sorgt dafür, dass die Vorgänge, die das operand erzeugen, vor Vorgängen ausgeführt werden, die von result abhängig sind, und verhindert, dass Compiler-Transformationen Vorgänge über die Barriere verschieben. Ansonsten ist der Vorgang eine Identität, z.B. result = operand.

Argumente

Name Typ Einschränkung
operand variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result variadische Anzahl von Tensoren, quantisierte Tensoren oder Tokens pro Tensor (C1)

Einschränkung

  • (C1) type(operand...) = type(result...).

Beispiele

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Weitere Beispiele

oder

Semantik

Führt ein elementweises OR von zwei Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei booleschen Werten: logisches ODER
  • Für Ganzzahlen: bitweises OR.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor des Typs „Ganzzahl“ oder „Boolescher Wert“ (C1)
(I2) rhs Tensor des Typs „Ganzzahl“ oder „Boolescher Wert“ (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Typs „Ganzzahl“ oder „Boolescher Wert“ (C1)

Einschränkung

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

Outfeed

Semantik

Schreibt inputs in den Outfeed und generiert ein result-Token.

Die Semantik von outfeed_config ist implementierungsdefiniert.

Eingaben

Label Name Typ
(I1) inputs variadische Anzahl von Tensoren oder quantisierten Tensoren
(I2) token token
(I3) outfeed_config Konstante vom Typ string

Ausgaben

Name Typ
result token

Beispiele

%result = "stablehlo.outfeed"(%inputs0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Weitere Beispiele

Feld

Semantik

Maximiert operand durch Auffüllung um den Tensor sowie zwischen die Elemente des Tensors mit dem angegebenen padding_value.

edge_padding_low und edge_padding_high geben den Abstand an, der am unteren Ende (neben Index 0) bzw. am oberen Rand (neben dem höchsten Index) jeder Dimension hinzugefügt wird. Der Grad des Innenabstands kann negativ sein, wobei der absolute Wert des negativen Werts die Anzahl der Elemente angibt, die aus der angegebenen Dimension entfernt werden sollen.

interior_padding gibt den Abstand zwischen zwei Elementen in jeder Dimension an, der nicht negativ sein darf. Das Innen-Padding erfolgt vor dem Rand-Padding. Bei einem negativen Rand-Padding werden Elemente aus dem Operanden mit Auffüllung entfernt.

Formell ist result[result_index] so definiert:

  • operand[operand_index], wenn result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • Andernfalls padding_value.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1), (C2), (C4)
(I2) padding_value 0-dimensionaler Tensor oder quantisierter Tensor pro Tensor (C1)
(I3) edge_padding_low Eindimensionale Tensorkonstante vom Typ si64 (C1), (C4)
(I4) edge_padding_high Eindimensionale Tensorkonstante vom Typ si64 (C1), (C4)
(I5) interior_padding Eindimensionale Tensorkonstante vom Typ si64 (C2–C4)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C3–C6)

Einschränkung

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Beispiele

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
  edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
  interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Weitere Beispiele

partition_id

Semantik

Erzeugt partition_id des aktuellen Prozesses.

Ausgaben

Name Typ
result 0-dimensionaler Tensor vom Typ ui32

Beispiele

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

Weitere Beispiele

Popcnt

Semantik

Führt eine elementweise Zählung der im operand-Tensor festgelegten Anzahl von Bits durch und erzeugt einen result-Tensor.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Ganzzahltyps (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Ganzzahltyps (C1)

Einschränkung

  • (C1) type(operand) = type(result).

Beispiele

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Weitere Beispiele

Leistung

Semantik

Führt die elementweise Exponentiierung des lhs-Tensors mit dem rhs-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Ganzzahlen: ganzzahlige Exponentiation.
  • Für Gleitkommazahlen: pow aus IEEE-754.
  • Bei komplexen Zahlen: komplexe Exponentiation.
  • Für quantisierte Typen: dequantize_op_quantize(power, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)
(I2) rhs Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Weitere Beispiele

real

Semantik

Extrahiert den reellen Teil elementweise aus operand und erzeugt einen result-Tensor. Formeller für jedes Element x: real(x) = is_complex(x) ? real_part(x) : x.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkomma- oder komplexen Typs (C1), (C2)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkommatyps (C1), (C2)

Einschränkung

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) ist so definiert:
    • complex_element_type(element_type(operand)), wenn is_complex(operand).
    • Andernfalls element_type(operand).

Beispiele

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

Weitere Beispiele

Recv

Semantik

Erhält Daten von einem Kanal mit channel_id und erzeugt results.

Wenn is_host_transfer den Wert true hat, überträgt der Vorgang Daten vom Host. Andernfalls werden Daten von einem anderen Gerät übertragen. Was das bedeutet, ist implementierungsdefiniert. Dieses Flag dupliziert die in channel_type bereitgestellten Informationen. Daher planen wir, in Zukunft nur eines davon zu behalten (#666).

results bestehen aus Nutzlastwerten, die an erster Stelle stehen, und einem Token, das zuletzt steht. Zur besseren Verständlichkeit planen wir, die Nutzlast und das Token in Zukunft auf zwei separate Ausgaben aufzuteilen (#670).

Eingaben

Label Name Typ Einschränkung
(I1) token token (C4)
(I2) channel_id Konstante vom Typ si64
(I3) channel_type Aufzählung von DEVICE_TO_DEVICE und HOST_TO_DEVICE (C1)
(I4) is_host_transfer Konstante vom Typ i1 (C1)

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C2–C4)

Einschränkung

  • (C1) channel_type ist so definiert:
    • HOST_TO_DEVICE, wenn is_host_transfer = true,
    • Andernfalls DEVICE_TO_DEVICE.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) oder is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Beispiele

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Weitere Beispiele

reduce

Semantik

Wendet eine Reduktionsfunktion body auf inputs und init_values entlang der dimensions an und erzeugt results-Tensoren.

Die Reihenfolge der Reduzierungen ist implementierungsdefiniert. Das bedeutet, dass body und init_values ein Monoid bilden müssen, um zu garantieren, dass der Vorgang für alle Eingaben und Implementierungen dieselben Ergebnisse liefert. Diese Bedingung gilt jedoch nicht für viele gängige Reduzierungen. Beispielsweise bilden die Gleitkommazahlen für body und Null für init_values kein Monoid, da das Addieren von Gleitkommazahlen nicht assoziativ ist.

Formeller gesagt, results...[j0, ..., jR-1] = reduce(input_slices_converted), wobei:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], wobei : bei dimensions eingefügt wird.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) für eine binäre Baumstruktur schedule, wobei Folgendes gilt:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule ist ein implementierungsdefinierter vollständiger binärer Baum, dessen Durchlauf in einer bestimmten Reihenfolge aus folgenden Elementen besteht:
    • input_slices_converted...[index]-Werte für alle index in index_space(input_slices_converted) in aufsteigender lexikografischer Reihenfolge von index.
    • Wird mit einer implementierungsdefinierten Menge von init_values_converted an implementierungsdefinierten Positionen eingefügt.

Eingaben

Label Name Typ Einschränkung
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C4), (C6), (C7)
(I2) init_values variadische Anzahl von 0-dimensionalen Tensoren oder quantisierten Tensoren pro Tensor (C2), (C3)
(I3) dimensions Eindimensionale Tensorkonstante vom Typ si64 (C4), (C5), (C7)
(I4) body Funktion (C6)

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C3), (C7), (C8)

Einschränkung

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body hat den Typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), wobei is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), außer dass die Dimensionsgrößen von inputs..., die dimensions entsprechen, nicht enthalten sind.
  • (C8) element_type(results[i]) = Ei für alle i in [0,N).

Beispiele

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

Weitere Beispiele

reduce_precision

Semantik

Führt die elementweise Umwandlung von operand in einen anderen Gleitkommatyp, der exponent_bits und mantissa_bits verwendet, sowie zurück zum ursprünglichen Gleitkommatyp durch und erzeugt einen output-Tensor.

Formeller:

  • Die Mantissen-Bits des ursprünglichen Werts werden aktualisiert, um den ursprünglichen Wert auf den nächsten Wert zu runden, der mit mantissa_bits unter Verwendung der roundToIntegralTiesToEven-Semantik dargestellt werden kann.
  • Wenn mantissa_bits dann kleiner als die Anzahl der Mantissen-Bits des ursprünglichen Werts ist, werden die Mantissen-Bits auf mantissa_bits gekürzt.
  • Wenn dann die Exponentenbits des Zwischenergebnisses nicht in den von exponent_bits bereitgestellten Bereich passen, läuft das Zwischenergebnis mit dem ursprünglichen Vorzeichen ins Unendlichkeit über oder geht mit dem ursprünglichen Vorzeichen auf null über.
  • Führt für quantisierte Typen den Befehl dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)
(I2) exponent_bits Konstante vom Typ si32 (C2)
(I3) mantissa_bits Konstante vom Typ si32 (C3)

Ausgaben

Name Typ Einschränkung
output Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Beispiele

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Weitere Beispiele

reduce_scatter

Semantik

Führt in jeder Prozessgruppe im StableHLO-Prozessraster die Reduktion mit computations über die Werte des Tensors operand der einzelnen Prozesse durch, teilt das Reduktionsergebnis entlang von scatter_dimension in Teile auf und verteilt die Teile auf die Prozesse, um die result zu erzeugen.

Der Vorgang teilt das StableHLO-Prozessraster in process_groups auf, die so definiert ist:

  • cross_replica(replica_groups) wenn channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) wenn channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) wenn channel_id > 0 and use_global_device_ids = true.

Danach geschieht in jedem process_group Folgendes:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] für alle sender in process_group, wobei receiver_index = process_group.index(receiver).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1), (C2), (C7), (C8)
(I2) scatter_dimension Konstante vom Typ si64 (C1), (C2), (C8)
(I3) replica_groups 2-dimensionale Tensorkonstante vom Typ si64 (C3–C5)
(I4) channel_id Konstante vom Typ si64 (C6)
(I5) use_global_device_ids Konstante vom Typ i1 (C6)
(I6) computation Funktion (C7)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C8–C9)

Einschränkung

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) ist so definiert:
    • num_replicas, wenn cross_replica verwendet wird.
    • num_replicas, wenn cross_replica_and_partition verwendet wird.
    • num_processes, wenn flattened_ids verwendet wird.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Wenn use_global_device_ids = true, dann channel_id > 0.
  • (C7) computation hat den Typ (tensor<E>, tensor<E>) -> (tensor<E>), wobei is_promotable(element_type(operand), E) ist.
  • (C8) shape(result) = shape(operand) außer:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Beispiele

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Weitere Beispiele

reduce_window

Semantik

Wendet die Reduzierungsfunktion body auf die Fenster von inputs und init_values an und erzeugt results.

Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in results... aus inputs... berechnet werden.

Formaler gilt für results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (siehe reduzieren). Dabei gilt:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Eingaben

Label Name Typ Einschränkung
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values variadische Anzahl von 0-dimensionalen Tensoren oder quantisierten Tensoren pro Tensor (C1), (C13)
(I3) window_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C4), (C5), (C15)
(I4) window_strides Eindimensionale Tensorkonstante vom Typ si64 (C6), (C7), (C15)
(I5) base_dilations Eindimensionale Tensorkonstante vom Typ si64 (C8), (C9), (C15)
(I6) window_dilations Eindimensionale Tensorkonstante vom Typ si64 (C10), (C11), (C15)
(I7) padding 2-dimensionale Tensorkonstante vom Typ si64 (C12), (C15)
(I8) body Funktion (C13)

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1), (C14–C16)

Einschränkung

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body hat den Typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), wobei is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows, wobei:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei für alle i in [0,N).

Beispiele

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[2, 1]> : tensor<2xi64>,
  window_strides = dense<[4, 1]> : tensor<2xi64>,
  base_dilations = dense<[2, 1]> : tensor<2xi64>,
  window_dilations = dense<[3, 1]> : tensor<2xi64>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

Weitere Beispiele

Rest

Semantik

Führt den elementweisen Rest der Tensoren lhs und Divisor rhs aus und erzeugt einen result-Tensor.

Formaler wird das Vorzeichen des Ergebnisses vom Dividenden entnommen, und der absolute Wert des Ergebnisses ist immer kleiner als der absolute Wert des Divisors. Der Rest wird als lhs - d * rhs berechnet, wobei d durch Folgendes gegeben ist:

  • Für Ganzzahlen: stablehlo.divide(lhs, rhs).
  • Für Gleitkommazahlen: division(lhs, rhs) aus IEEE-754 mit Rundungsattribut roundTowardZero.
  • Für komplexe Zahlen: TBD (#997).
  • Für quantisierte Typen:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Bei Gleitkomma-Elementtypen steht dieser Vorgang im Gegensatz zum remainder-Vorgang aus der IEEE-754-Spezifikation, bei dem d ein ganzzahliger Wert ist, der dem exakten Wert von lhs/rhs am nächsten ist, aber gleichbedeutend mit einem geraden Wert ist.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)
(I2) rhs Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Weitere Beispiele

replica_id

Semantik

Erzeugt replica_id des aktuellen Prozesses.

Ausgaben

Name Typ
result 0-dimensionaler Tensor vom Typ ui32

Beispiele

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

Weitere Beispiele

Form ändern

Semantik

Führt die Umformung des operand-Tensors in einen result-Tensor durch. Konzeptionell geht es darum, dieselbe kanonische Darstellung beizubehalten, aber möglicherweise die Form zu ändern, z.B. von tensor<2x3xf32> in tensor<3x2xf32> oder tensor<6xf32>.

Formell weiter ist result[result_index] = operand[operand_index], wobei result_index und operand_index dieselbe Position in der lexikografischen Reihenfolge von index_space(result) und index_space(operand) haben.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor (C1–C3)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor (C1–C3)

Einschränkung

  • (C1) element_type(result) wird gegeben durch:
    • element_type(operand), wenn !is_per_axis_quantized(operand).
    • element_type(operand), allerdings können sich quantization_dimension(operand) und quantization_dimension(result) unterscheiden.
  • (C2) size(operand) = size(result).
  • (C3) Wenn is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Beispiele

// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

Weitere Beispiele

reverse

Semantik

Kehrt die Reihenfolge der Elemente in operand entlang der angegebenen dimensions um und erzeugt einen result-Tensor. Formaler gesagt, result[result_index] = operand[operand_index], wobei:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 wenn d in dimensions ist.
  • Andernfalls operand_index[d] = result_index[d].

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1), (C3)
(I2) dimensions Eindimensionale Tensorkonstante vom Typ si64 (C2), (C3)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1), (C3)

Einschränkung

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Beispiele

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

Weitere Beispiele

RNG

Semantik

Erzeugt Zufallszahlen mit dem rng_distribution-Algorithmus und erzeugt einen result-Tensor der Form shape.

Bei rng_distribution = UNIFORM werden die Zufallszahlen entsprechend der gleichmäßigen Verteilung über das Intervall [a, b) generiert. Wenn a >= b, ist das Verhalten nicht definiert.

Bei rng_distribution = NORMAL werden die Zufallszahlen entsprechend der Normalverteilung mit einem Mittelwert = a und der Standardabweichung = b generiert. Wenn b < 0, ist das Verhalten nicht definiert.

Die genaue Generierung von Zufallszahlen ist von der Implementierung definiert. Sie können beispielsweise deterministisch sein oder nicht, und sie können einen verborgenen Status verwenden.

In Gesprächen mit vielen Stakeholdern wurde festgestellt, dass diese Operation faktisch verworfen wurde. Daher planen wir, sie in Zukunft zu entfernen (#597).

Eingaben

Label Name Typ Einschränkung
(I1) a 0-dimensionaler Tensor vom Typ „Ganzzahl“, „Boolesch“ oder „Gleitkomma“ (C1), (C2)
(I2) b 0-dimensionaler Tensor vom Typ „Ganzzahl“, „Boolesch“ oder „Gleitkomma“ (C1), (C2)
(I3) shape Eindimensionale Tensorkonstante vom Typ si64 (C3)
(I4) rng_distribution Aufzählung von UNIFORM und NORMAL (C2)

Ausgaben

Name Typ Einschränkung
result Tensor eines Ganzzahl-, booleschen oder Gleitkommatyps (C1–C3)

Einschränkung

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Wenn rng_distribution = NORMAL, dann is_float(a).
  • (C3) shape(result) = shape.

Beispiele

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semantik

Gibt ein output mit einheitlichen Zufallsbits und einem aktualisierten Ausgabestatus output_state unter Verwendung des Pseudozufallszahlengenerator-Algorithmus rng_algorithm bei einem Ausgangszustand initial_state zurück. Es wird garantiert, dass die Ausgabe eine deterministische Funktion von initial_state ist. Es ist jedoch nicht garantiert, dass sie zwischen Implementierungen deterministisch ist.

rng_algorithm ist einer der folgenden Werte:

  • DEFAULT: Implementierungsdefinierter Algorithmus.
  • THREE_FRY: Implementierungsdefinierte Variante des Threefry-Algorithmus.*
  • PHILOX: Implementierungsdefinierte Variante des Philox-Algorithmus.*

* Siehe Salmon et al. SC 2011. Parallele Zufallszahlen – so einfach wie 1, 2, 3.

Eingaben

Label Name Typ Einschränkung
(I1) rng_algorithm Aufzählung von DEFAULT, THREE_FRY und PHILOX (C2)
(I2) initial_state Eindimensionaler Tensor vom Typ ui64 (C1), (C2)

Ausgaben

Name Typ Einschränkung
output_state Eindimensionaler Tensor vom Typ ui64 (C1)
output Tensor des Ganzzahl- oder Gleitkommatyps

Einschränkung

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) ist so definiert:
    • Implementierung definiert, wenn rng_algorithm = DEFAULT.
    • 2, wenn rng_algorithm = THREE_FRY.
    • 2 oder 3, wenn rng_algorithm = PHILOX.

Beispiele

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semantik

Führt eine elementweise Rundung zur nächsten Ganzzahl auf dem Tensor operand durch und hebt die Verknüpfungen von Null auf, um einen result-Tensor zu erzeugen. Implementiert den Vorgang roundToIntegralTiesToAway aus der IEEE-754-Spezifikation. Führt für quantisierte Typen den Befehl dequantize_op_quantize(round_nearest_afz, operand, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Weitere Beispiele

round_nearest_even

Semantik

Führt eine elementweise Rundung zur nächsten Ganzzahl auf und löst die Verbindungen zur geraden Ganzzahl auf dem Tensor operand aus und erzeugt einen result-Tensor. Implementiert den Vorgang roundToIntegralTiesToEven aus der IEEE-754-Spezifikation. Führt für quantisierte Typen den Befehl dequantize_op_quantize(round_nearest_even, operand, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkommatyps oder quantisierten Tensor pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Weitere Beispiele

RSS

Semantik

Führt eine elementweise reziproke Quadratwurzeloperation auf dem operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: rSqrt aus IEEE-754.
  • Bei komplexen Zahlen: komplexe reziproke Quadratwurzel.
  • Für quantisierte Typen: dequantize_op_quantize(rsqrt, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Weitere Beispiele

scatter

Semantik

Erzeugt results-Tensoren, die den Tensoren inputs entsprechen, mit der Ausnahme, dass mehrere von scatter_indices angegebene Segmente mithilfe von update_computation mit den Werten updates aktualisiert werden.

Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in updates... den Elementen in results... zugeordnet werden. Im Diagramm werden einige Beispiele für updates...-Indizes ausgewählt und detailliert erläutert, welchen results...-Indizes sie entsprechen.

Formell für alle update_index in index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index ist so definiert:
    • scatter_indices[si0, ..., :, ..., siN], wobei si einzelne Elemente in update_scatter_index sind und : in den Index index_vector_dim eingefügt wird, wenn index_vector_dim < rank(scatter_indices) ist.
    • Andernfalls [scatter_indices[update_scatter_index]].
  • Für d_input in axes(inputs[0]):
    • full_start_index[d_input] = start_index[d_start], wenn d_input = scatter_dims_to_operand_dims[d_start].
    • Andernfalls full_start_index[d_input] = 0.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], wobei wi einzelne Elemente in update_window_index sind und 0 an Indizes von inserted_window_dims eingefügt wird.
  • result_index = full_start_index + full_window_index.

Daher gilt results = exec(schedule, inputs), wobei Folgendes gilt:

  • schedule ist eine implementierungsdefinierte Permutation von index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results), wobei:
    • Wenn result_index innerhalb der Grenzen für shape(results...) liegt
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results ist eine Kopie von results, wobei results...[result_index] auf updated_values... festgelegt ist.
    • Andernfalls:
    • updated_results = results.
  • exec([], results) = results.

Wenn indices_are_sorted den Wert true hat, kann die Implementierung davon ausgehen, dass scatter_indices in Bezug auf scatter_dims_to_operand_dims sortiert sind. Andernfalls ist das Verhalten nicht definiert. Formell: Für alle i1 < i2 aus indices(result) gilt: full_start_index(i1) <= full_start_index(i2).

Wenn unique_indices den Wert true hat, kann die Implementierung davon ausgehen, dass alle result_index-Indizes, auf die verstreut sind, eindeutig sind. Wenn unique_indices den Wert true hat, die Indizes jedoch nicht eindeutig sind, ist das Verhalten nicht definiert.

Eingaben

Label Name Typ Einschränkung
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1), (C2), (C4-C6), (C10), (C13), (C15-C16)
(I2) scatter_indices Tensor des Ganzzahltyps (C4), (C11), (C14)
(I3) updates variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C3–C6), (C8)
(I4) update_window_dims Eindimensionale Tensorkonstante vom Typ si64 (C2), (C4), (C7), (C8)
(I5) inserted_window_dims Eindimensionale Tensorkonstante vom Typ si64 (C2), (C4), (C9), (C10)
(I6) scatter_dims_to_operand_dims Eindimensionale Tensorkonstante vom Typ si64 (C11–C13)
(I7) index_vector_dim Konstante vom Typ si64 (C4), (C11), (C14)
(I8) indices_are_sorted Konstante vom Typ i1
(I9) unique_indices Konstante vom Typ i1
(I10) update_computation Funktion (C15)

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C15–C17)

Einschränkung

  • (C1) same(shape(inputs...)).
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes), wobei:
    • update_scatter_dim_sizes = shape(scatter_indices), allerdings ist die Dimensionsgröße von scatter_indices, die index_vector_dim entspricht, nicht enthalten.
    • update_window_dim_sizes <= shape(inputs[0]), außer dass die Dimensionsgrößen in inputs[0], die inserted_window_dims entsprechen, nicht enthalten sind.
    • combine setzt update_scatter_dim_sizes an den Achsen, die update_scatter_dims entsprechen, und update_window_dim_sizes an den Achsen, die update_window_dims entsprechen.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(inserted_window_dims) and is_sorted(update_window_dims).
  • (C10) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C11) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C12) is_unique(scatter_dims_to_operand_dims).
  • (C13) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C14) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C15) update_computation hat den Typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), wobei is_promotable(element_type(inputs[i]), Ei) ist.
  • (C16) shape(inputs...) = shape(results...).
  • (C17) element_type(results[i]) = Ei für alle i in [0,N).

Beispiele

// %input: [
//          [[1, 2], [3, 4], [5, 6], [7, 8]],
//          [[9, 10], [11, 12], [13, 14], [15, 16]],
//          [[17, 18], [19, 20], [21, 22], [23, 24]]
//         ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [2, 3],
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
//           [[1, 2], [5, 6], [7, 8], [7, 8]],
//           [[10, 11], [12, 13], [14, 15], [16, 17]],
//           [[18, 19], [20, 21], [21, 22], [23, 24]]
//          ]

Weitere Beispiele

auswählen

Semantik

Erzeugt einen result-Tensor, bei dem jedes Element aus dem on_true- oder on_false-Tensor basierend auf dem Wert des entsprechenden Elements von pred ausgewählt wird. Formeller: result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], wobei pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Führt für quantisierte Typen den Befehl dequantize_select_quantize(pred, on_true, on_false, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) pred Tensor vom Typ i1 (C1)
(I2) on_true Tensor oder quantisierter Tensor pro Tensor (C1–C2)
(I3) on_false Tensor oder quantisierter Tensor pro Tensor (C2)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C2)

Einschränkung

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Beispiele

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

Weitere Beispiele

select_and_scatter

Semantik

Streut die Werte vom Tensor source mit scatter anhand des Ergebnisses von reduce_window des Tensors input unter Verwendung von select und erzeugt einen result-Tensor.

Das folgende Diagramm zeigt anhand eines konkreten Beispiels, wie Elemente in result aus operand und source berechnet werden.

Formeller:

  • selected_values = reduce_window_without_init(...) mit den folgenden Eingaben:

    • `inputs = [operand].
    • window_dimensions, window_strides und padding, die unverändert verwendet werden.
    • base_dilations = windows_dilations = 1.
    • body ist definiert als:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    wobei E = element_type(operand) und reduce_window_without_init genau wie reduce_window funktionieren, mit der Ausnahme, dass der schedule der zugrunde liegenden reduce (siehe reduzieren) keine Initialisierungswerte enthält. Derzeit ist nicht festgelegt, was geschieht, wenn das entsprechende Fenster keine Werte enthält (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter), wobei:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index, wenn selected_values[source_index] das Element operand aus operand_index enthält.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1–C4), (C6), (C8–C11)
(I2) source Tensor oder quantisierter Tensor pro Tensor (C1), (C2)
(I3) init_value 0-dimensionaler Tensor oder quantisierter Tensor pro Tensor (C3)
(I4) window_dimensions Eindimensionale Tensorkonstante vom Typ si64 (C2), (C4), (C5)
(I5) window_strides Eindimensionale Tensorkonstante vom Typ si64 (C2), (C6), (C7)
(I6) padding 2-dimensionale Tensorkonstante vom Typ si64 (C2), (C8)
(I7) select Funktion (C9)
(I8) scatter Funktion (C10)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C11–C12)

Einschränkung

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows, wobei:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select hat den Typ (tensor<E>, tensor<E>) -> tensor<i1>, wobei E = element_type(operand) ist.
  • (C10) scatter hat den Typ (tensor<E>, tensor<E>) -> tensor<E>, wobei is_promotable(element_type(operand), E) ist.
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Beispiele

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Weitere Beispiele

Senden

Semantik

Sendet inputs an einen Kanal channel_id und erzeugt ein result-Token.

Wenn is_host_transfer den Wert true hat, werden durch den Vorgang Daten an den Host übertragen. Andernfalls werden die Daten auf ein anderes Gerät übertragen. Was das bedeutet, ist implementierungsdefiniert. Dieses Flag dupliziert die in channel_type bereitgestellten Informationen. Daher planen wir, in Zukunft nur eines davon zu behalten (#666).

Eingaben

Label Name Typ Einschränkung
(I1) inputs variadische Anzahl von Tensoren oder quantisierten Tensoren
(I2) token token
(I3) channel_id Konstante vom Typ si64
(I4) channel_type Aufzählung von DEVICE_TO_DEVICE und DEVICE_TO_HOST (C1)
(I5) is_host_transfer Konstante vom Typ i1 (C1)

Ausgaben

Name Typ
result token

Einschränkung

  • (C1) channel_type ist so definiert:
    • DEVICE_TO_HOST, wenn is_host_transfer = true,
    • Andernfalls DEVICE_TO_DEVICE.

Beispiele

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

Weitere Beispiele

shift_left

Semantik

Führt eine elementweise Linksverschiebungsoperation auf dem lhs-Tensor mit der Anzahl von rhs Bits durch und erzeugt einen result-Tensor.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor des Ganzzahltyps (C1)
(I2) rhs Tensor des Ganzzahltyps (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Ganzzahltyps (C1)

Einschränkung

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

Weitere Beispiele

shift_right_arithmetic

Semantik

Führt eine elementweise arithmetische Rechtsverschiebungsoperation auf dem lhs-Tensor mit der Anzahl von rhs Bits durch und erzeugt einen result-Tensor.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor des Ganzzahltyps (C1)
(I2) rhs Tensor des Ganzzahltyps (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Ganzzahltyps (C1)

Einschränkung

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

Weitere Beispiele

shift_right_logical

Semantik

Führt eine elementweise logische Rechtsverschiebungsoperation am lhs-Tensor mit der Anzahl von rhs Bits durch und erzeugt einen result-Tensor.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor des Ganzzahltyps (C1)
(I2) rhs Tensor des Ganzzahltyps (C1)

Ausgaben

Name Typ Einschränkung
result Tensor des Ganzzahltyps (C1)

Einschränkung

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Weitere Beispiele

Signieren

Semantik

Gibt das Vorzeichen von operand elementweise zurück und erzeugt einen result-Tensor. Formaler kann die Semantik für jedes Element x mit der Python-Syntax so ausgedrückt werden:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Führt für quantisierte Typen den Befehl dequantize_op_quantize(sign, operand, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor einer vorzeichenbehafteten Ganzzahl, Gleitkommazahl oder eines komplexen Typs oder eines pro Tensor quantisierten Tensors (C1)

Ausgaben

Name Typ Einschränkung
result Tensor einer vorzeichenbehafteten Ganzzahl, Gleitkommazahl oder eines komplexen Typs oder eines pro Tensor quantisierten Tensors (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Weitere Beispiele

Sinus

Semantik

Führt eine elementweise Sinusoperation für den operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: sin aus IEEE-754.
  • Bei komplexen Zahlen: komplexer Sinus.
  • Für quantisierte Typen: dequantize_op_quantize(sine, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

Weitere Beispiele

Slice

Semantik

Extrahiert mithilfe von statisch berechneten Startindexen ein Slice aus dem operand und erzeugt einen result-Tensor. start_indices enthält die Startindexe des Segments für jede Dimension, limit_indices die Endindexe (ausschließlich) für das Segment für jede Dimension und strides die Schritte für die einzelnen Dimensionen.

Formeller gesagt: result[result_index] = operand[operand_index], wobei operand_index = start_indices + result_index * strides.

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor pro Tensor (C1–C3), (C5)
(I2) start_indices Eindimensionale Tensorkonstante vom Typ si64 (C2), (C3), (C5)
(I3) limit_indices Eindimensionale Tensorkonstante vom Typ si64 (C2), (C3), (C5)
(I4) strides Eindimensionale Tensorkonstante vom Typ si64 (C2), (C4)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor pro Tensor (C1), (C5)

Einschränkung

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Beispiele

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = dense<[1, 2]> : tensor<2xi64>,
  limit_indices = dense<[3, 4]> : tensor<2xi64>,
  strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Weitere Beispiele

sort

Semantik

Sortiert eindimensionale Segmente von inputs entlang der Dimension dimension nach einem comparator zusammen und erstellt results.

Im Gegensatz zu ähnlichen Eingaben in anderen Vorgängen erlaubt dimension mit der unten beschriebenen Semantik negative Werte. Zukünftig kann dies aus Konsistenzgründen nicht zulässig sein (#1377).

Wenn is_stable auf „true“ gesetzt ist, ist die Sortierung stabil, d. h., die relative Reihenfolge der Elemente, die vom Vergleichsoperator als gleich betrachtet werden, wird beibehalten. Wenn eine einzelne Eingabe vorhanden ist, werden die beiden Elemente e1 und e2 vom Vergleichsoperator als gleichbedeutend angesehen, wenn und nur dann comparator(e1, e2) = comparator(e2, e1) = false. In der folgenden Formalisierung wird gezeigt, wie sich dies auf mehrere Eingaben verallgemeinern lässt.

Formell für alle result_index in index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1], wobei riN einzelne Elemente in result_index sind und : bei adjusted_dimension eingefügt wird.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • Dabei sortiert sort ein eindimensionales Segment in nicht absteigender Reihenfolge und erwartet, dass comparator_together true zurückgibt, wenn das Argument auf der linken Seite kleiner als das zweite Argument ist.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Eingaben

Label Name Typ Einschränkung
(I1) inputs variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C1–C5)
(I2) dimension Konstante vom Typ si64 (C4)
(I3) is_stable Konstante vom Typ i1
(I4) comparator Funktion (C5)

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren oder quantisierte Tensoren pro Tensor (C2), (C3)

Einschränkung

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, wobei R = rank(inputs[0]).
  • (C5) comparator hat den Typ (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, wobei Ei = element_type(inputs[i]) ist.

Beispiele

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Weitere Beispiele

sqrt

Semantik

Führt eine elementweise Quadratwurzeloperation auf dem operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: squareRoot aus IEEE-754.
  • Bei komplexen Zahlen: komplexe Quadratwurzel.
  • Für quantisierte Typen: dequantize_op_quantize(sqrt, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Weitere Beispiele

subtract

Semantik

Führt die elementweise Subtraktion der beiden Tensoren lhs und rhs durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Ganzzahlen: Subtraktion von Ganzzahlen.
  • Für Gleitkommazahlen: subtraction aus IEEE-754.
  • Bei komplexen Zahlen: komplexe Subtraktion.
  • Für quantisierte Typen:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)
(I2) rhs Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Ganzzahl-, Gleitkomma- oder komplexen Typs oder pro Tensor quantisierter Tensor (C1)

Einschränkung

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Beispiele

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Weitere Beispiele

Tanh

Semantik

Führt eine elementweise hyperbolische Tangensoperation am operand-Tensor durch und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Für Gleitkommazahlen: tanh aus IEEE-754.
  • Bei komplexen Zahlen: komplexer hyperbolischer Tangens
  • Für quantisierte Typen:
    • dequantize_op_quantize(tanh, operand, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_type(operand) = baseline_type(result).

Beispiele

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

Weitere Beispiele

Transponieren

Semantik

Permutet die Dimensionen des Tensors operand mit permutation und erzeugt einen result-Tensor. Formeller gesagt: result[result_index] = operand[operand_index], wobei result_index[d] = operand_index[permutation[d]].

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor oder quantisierter Tensor (C1–C4)
(I2) permutation Eindimensionale Tensorkonstante vom Typ si64 (C2–C4)

Ausgaben

Name Typ Einschränkung
result Tensor oder quantisierter Tensor (C1), (C3–C4)

Einschränkung

  • (C1) element_type(result) wird gegeben durch:
    • element_type(operand), wenn !is_per_axis_quantized(operand).
    • element_type(operand), allerdings können sich quantization_dimension(operand) und quantization_dimension(result) unterscheiden.
  • (C2) permutation ist eine Permutation von range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Wenn is_per_axis_quantized(result), dann quantization_dimension(operand) = permutation(quantization_dimension(result)).

Beispiele

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Weitere Beispiele

triangular_solve

Semantik

Löst Reihen von linearen Gleichungssystemen mit unteren oder oberen Dreiecksmatrizen.

Formell ist bei a und b result[i0, ..., iR-3, :, :] die Lösung für op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :], wenn left_side true oder x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] ist, wenn left_side false ist. Dabei wird die Variable x ermittelt, wobei op(a) durch transpose_a bestimmt wird. Dabei kann es sich um einen der folgenden Werte handeln:

  • NO_TRANSPOSE: Vorgang mit a unverändert ausführen.
  • TRANSPOSE: Vorgang beim Transponieren von a ausführen.
  • ADJOINT: Eine Operation beim konjugierten Transponieren von a durchführen.

Eingabedaten werden nur aus dem unteren Dreieck von a gelesen, wenn lower gleich true oder dem oberen Dreieck von a ist. Andernfalls werden die Eingabedaten gelesen. Die Ausgabedaten werden im selben Dreieck zurückgegeben. Die Werte im anderen Dreieck sind implementierungsdefiniert.

Wenn unit_diagonal „true“ ist, kann die Implementierung davon ausgehen, dass die diagonalen Elemente von a gleich 1 sind. Andernfalls ist das Verhalten nicht definiert.

Führt für quantisierte Typen den Befehl dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)) aus.

Eingaben

Label Name Typ Einschränkung
(I1) a Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1–C3)
(I2) b Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1–C4)
(I3) left_side Konstante vom Typ i1 (C3)
(I4) lower Konstante vom Typ i1
(I5) unit_diagonal Konstante vom Typ i1
(I6) transpose_a Aufzählung von NO_TRANSPOSE, TRANSPOSE und ADJOINT

Ausgaben

Name Typ Einschränkung
result Tensor eines Gleitkomma- oder komplexen Typs oder quantisierten Tensors pro Tensor (C1)

Einschränkung

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) Die Beziehung zwischen shape(a) und shape(b) ist so definiert:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Beispiele

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Semantik

Erzeugt ein result-Tupel aus den Werten val.

Eingaben

Label Name Typ Einschränkung
(I1) val variadische Anzahl von Werten (C1)

Ausgaben

Name Typ Einschränkung
result tuple (C1)

Einschränkung

  • (C1) result hat den Typ tuple<E0, ..., EN-1>, wobei Ei = type(val[i]).

Beispiele

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Weitere Beispiele

uniform_dequantize

Semantik

Führt die elementweise Konvertierung des quantisierten Tensors operand in einen Gleitkomma-Tensor result gemäß den durch den Typ operand definierten Quantisierungsparametern durch.

Formeller gesagt: result = dequantize(operand).

Eingaben

Label Name Typ Einschränkung
(I1) operand quantisierter Tensor (C1), (C2)

Ausgaben

Name Typ Einschränkung
result Tensor des Gleitkommatyps (C1), (C2)

Einschränkung

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Beispiele

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Semantik

Führt eine elementweise Umwandlung des Gleitkommatensors oder des quantisierten Tensors operand in einen quantisierten Tensor result gemäß den durch den Typ result definierten Quantisierungsparametern durch.

Formell

  • Wenn is_float(operand):
    • result = quantize(operand, type(result)).
  • Wenn is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Eingaben

Label Name Typ Einschränkung
(I1) operand Tensor des Gleitkomma- oder quantisierten Typs (C1), (C2)

Ausgaben

Name Typ Einschränkung
result quantisierter Tensor (C1), (C2)

Einschränkung

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Beispiele

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

während

Semantik

Erzeugt die Ausgabe, indem die Funktion body null oder mehr Mal ausgeführt wird, während die Funktion cond den Wert true ausgibt. Formell kann die Semantik mit der Python-Syntax folgendermaßen ausgedrückt werden:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Das Verhalten einer Endlosschleife ist noch nicht festgelegt (#383).

Eingaben

Label Name Typ Einschränkung
(I1) operand variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C1–C3)
(I2) cond Funktion (C1)
(I3) body Funktion (C2)

Ausgaben

Name Typ Einschränkung
results variadische Anzahl von Tensoren, quantisierten Tensoren oder Tokens (C3)

Einschränkung

  • (C1) cond hat den Typ (T0, ..., TN-1) -> tensor<i1>, wobei Ti = type(operand[i]).
  • (C2) body hat den Typ (T0, ..., TN-1) -> (T0, ..., TN-1), wobei Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Beispiele

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Weitere Beispiele

Xor

Semantik

Führt ein elementweises XOR der beiden Tensoren lhs und rhs aus und erzeugt einen result-Tensor. Folgendes wird je nach Elementtyp ausgeführt:

  • Bei booleschen Werten: logisches XOR.
  • Für Ganzzahlen: bitweises XOR.

Eingaben

Label Name Typ Einschränkung
(I1) lhs Tensor vom Typ „Boolescher Wert“ oder „Integer“ (C1)
(I2) rhs Tensor vom Typ „Boolescher Wert“ oder „Integer“ (C1)

Ausgaben

Name Typ Einschränkung
result Tensor vom Typ „Boolescher Wert“ oder „Integer“ (C1)

Einschränkung

  • (C1) type(lhs) = type(rhs) = type(result).

Beispiele

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Umsetzung

Sequenzielle Ausführung

Ein StableHLO-Programm wird ausgeführt, indem Eingabewerte für die main-Funktion bereitgestellt und Ausgabewerte berechnet werden. Ausgabewerte einer Funktion werden durch Ausführen des Graphen der Vorgänge berechnet, die im entsprechenden return-Vorgang verwurzelt sind.

Die Ausführungsreihenfolge ist implementierungsdefiniert, solange sie am Datenfluss ausgerichtet ist, d.h. wenn Vorgänge vor ihrer Verwendung ausgeführt werden. In StableHLO verbrauchen alle Nebenwirkungen ein Token und erzeugen ein Token (mehrere Tokens können über after_all zu einem Token zusammengefasst werden), sodass die Ausführungsreihenfolge von Nebeneffekten ebenfalls auf Dataflow abgestimmt ist. Mögliche Ausführungsreihenfolgen für das obige Beispielprogramm sind %0%1%2%3%4return oder %3%0%1%2%4return.

Formal ist ein StableHLO-Prozess eine Kombination aus: 1) einem StableHLO-Programm, 2) Vorgangsstatus (noch nicht ausgeführt, bereits ausgeführt) und 3) Zwischenwerten, an denen der Prozess arbeitet. Der Prozess beginnt mit Eingabewerten für die Funktion main, durchläuft die Grafik der Vorgänge, die Vorgangsstatus und Zwischenwerte aktualisiert, und endet mit Ausgabewerten. Die weitere Formalisierung steht noch nicht fest (#484).

Parallele Ausführung

StableHLO-Programme können parallel ausgeführt werden, organisiert in einem 2D-Prozessraster aus num_replicas nach num_partitions, die beide den Typ ui32 haben.

Im StableHLO-Prozessraster werden num_replicas * num_partitions der StableHLO-Prozesse gleichzeitig ausgeführt. Jeder Prozess hat eine eindeutige process_id = (replica_id, partition_id), wobei replica_id in replica_ids = range(num_replicas) und partition_id in partition_ids = range(num_partitions) den Typ ui32 haben.

Die Größe des Prozessrasters ist für jedes Programm statisch bekannt (in Zukunft planen wir, es zu einem expliziten Teil der StableHLO-Programme #650 zu machen) und die Position innerhalb des Prozessrasters ist für jeden Prozess statisch bekannt. Jeder Prozess hat über die Operationen replica_id und partition_id Zugriff auf seine Position im Prozessraster.

Innerhalb des Prozessrasters können die Programme gleich sein (im Stil „Einzelnes Programm, mehrere Daten“), unterschiedlich sein (im Stil „Mehrere Programme, mehrere Daten“) oder etwas dazwischen. Für die Zukunft planen wir, Unterstützung für andere Redewendungen zur Definition paralleler StableHLO-Programme anzubieten, einschließlich GSPMD (#619).

Innerhalb des Prozessrasters sind die Prozesse größtenteils unabhängig voneinander – sie haben unterschiedliche Betriebsstatus sowie separate Eingabe-/Zwischen-/Ausgabewerte und die meisten Vorgänge werden getrennt zwischen Prozessen ausgeführt, mit Ausnahme einer kleinen Anzahl von kollektiven Vorgängen, die unten beschrieben werden.

Da die Ausführung der meisten Vorgänge nur Werte aus demselben Prozess verwendet, ist es normalerweise eindeutig, mit ihren Namen auf diese Werte zu verweisen. Bei der Beschreibung der Semantik kollektiver Vorgänge ist dies jedoch nicht ausreichend. Daher kann die Notation name@process_id in einem bestimmten Prozess auf den Wert name verweisen. Aus dieser Sicht kann nicht qualifizierte name als Abkürzung für name@(replica_id(), partition_id()) betrachtet werden.

Die prozessübergreifende Ausführungsreihenfolge ist von der Implementierung definiert, mit Ausnahme der Synchronisierung, die durch die Punkt-zu-Punkt-Kommunikation und kollektive Vorgänge wie unten beschrieben eingeführt wird.

Punkt-zu-Punkt-Kommunikation

Stabile HLO-Prozesse können über StableHLO-Kanäle miteinander kommunizieren. Ein Kanal wird durch eine positive ID vom Typ si64 dargestellt. Über verschiedene Vorgänge ist es möglich, Werte an Kanäle zu senden und von Kanälen zu empfangen.

Eine weitere Formulierung, z.B. woher diese Kanal-IDs stammen, wie Prozesse von Programmen erkannt werden und welche Art von Synchronisierung durch sie eingeleitet wird, ist noch nicht festgelegt (#484).

Streamingkommunikation

Jeder StableHLO-Prozess hat Zugriff auf zwei Streaming-Schnittstellen:

  • InFeed, aus dem gelesen werden kann.
  • Outfeed (Ausgang), in den geschrieben werden kann.

Im Gegensatz zu Channels, die für die Kommunikation zwischen Prozessen verwendet werden und daher Prozesse an beiden Enden enthalten, ist bei Infeeds und Outfeeds die jeweils andere Implementierung definiert.

Eine weitere Formalisierung, z.B. wie sich die Streamingkommunikation auf die Ausführungsreihenfolge auswirkt und welche Art von Synchronisierung damit eingeleitet wird, steht noch nicht fest (#484).

Kollektive Vorgänge

In StableHLO gibt es sechs gemeinsame Vorgänge: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute und reduce_scatter. Alle diese Vorgänge teilen die Prozesse im StableHLO-Prozessraster in StableHLO-Prozessgruppen auf und führen eine gemeinsame Berechnung innerhalb jeder Prozessgruppe aus, unabhängig von anderen Prozessgruppen.

Innerhalb jeder Prozessgruppe können kollektive Vorgänge eine Synchronisierungsbarriere darstellen. Eine weitere Formulierung, die z.B. erläutert wird, wann genau diese Synchronisierung stattfindet, wie genau die Prozesse zu dieser Barriere gelangen und was passiert, wenn dies nicht der Fall ist, steht noch nicht fest (#484).

Wenn die Prozessgruppe partitionübergreifende Kommunikation umfasst, d.h., wenn sich in der Prozessgruppe Prozesse mit unterschiedlichen Partitions-IDs befinden, benötigt die Ausführung des gemeinsamen Vorgangs einen Kanal und der kollektive Vorgang muss eine positive channel_id vom Typ si64 bereitstellen. Für die replizübergreifende Kommunikation sind keine Kanäle erforderlich.

Die von den kollektiven Vorgängen durchgeführten Berechnungen sind für einzelne Operationen spezifisch und werden in den einzelnen Vorgangsabschnitten oben beschrieben. Die Strategien, nach denen das Prozessraster in Prozessgruppen aufgeteilt wird, sind jedoch auf diese Vorgänge verteilt und werden in diesem Abschnitt beschrieben. Formal unterstützt StableHLO die folgenden vier Strategien.

cross_replica

Nur die replizübergreifende Kommunikation findet innerhalb jeder Prozessgruppe statt. Diese Strategie verwendet replica_groups – eine Liste von Listen von Replikat-IDs – und berechnet ein kartesisches Produkt aus replica_groups nach partition_ids. replica_groups muss eindeutige Elemente haben und alle replica_ids abdecken. Formeller mithilfe der Python-Syntax:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Für replica_groups = [[0, 1], [2, 3]] und num_partitions = 2 erzeugt cross_replica beispielsweise [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Nur die partitionübergreifende Kommunikation innerhalb jeder Prozessgruppe. Diese Strategie verwendet partition_groups – eine Liste mit Listen von Partitions-IDs – und berechnet ein kartesisches Produkt aus partition_groups nach replica_ids. partition_groups muss eindeutige Elemente haben und alle partition_ids abdecken. Formell unter Verwendung der Python-Syntax:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Für partition_groups = [[0, 1]] und num_replicas = 4 erzeugt cross_partition beispielsweise [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

Sowohl die Replikat- als auch die partitionübergreifende Kommunikation kann innerhalb jeder Prozessgruppe erfolgen. Diese Strategie verwendet replica_groups (eine Liste mit Listen von Replikat-IDs) und berechnet kartesische Produkte jeder replica_group nach partition_ids. replica_groups muss eindeutige Elemente haben und alle replica_ids abdecken. Formell unter Verwendung der Python-Syntax:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Für replica_groups = [[0, 1], [2, 3]] und num_partitions = 2 erzeugt cross_replica_and_partition beispielsweise [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Diese Strategie nimmt flattened_id_groups – eine Liste von Listen mit vereinfachten Prozess-IDs im Format replica_id * num_partitions + partition_id – und wandelt sie in Prozess-IDs um. flattened_id_groups muss eindeutige Elemente haben und alle process_ids abdecken. Formell unter Verwendung der Python-Syntax:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Für flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 und num_partitions = 2 erzeugt flattened_ids beispielsweise [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Genauigkeit

Derzeit bietet StableHLO keine Garantien für die numerische Genauigkeit. Dies kann sich jedoch in Zukunft ändern (#1156).

Fehler

StableHLO-Programme werden durch eine umfassende Reihe von Einschränkungen für einzelne Vorgänge validiert, wodurch viele Fehlerklassen vor der Laufzeit ausgeschlossen werden. Fehlerbedingungen sind jedoch weiterhin möglich, z.B. durch Ganzzahlüberläufe, Zugriffe außerhalb des Bereichs usw. Sofern nicht ausdrücklich anders angegeben, führen alle diese Fehler zu einem implementierungsdefinierten Verhalten, das sich in Zukunft jedoch ändern kann (#1157).

Als Ausnahme von dieser Regel haben Gleitkommaausnahmen in StableHLO-Programmen ein klar definiertes Verhalten. Vorgänge, die zu Ausnahmen führen, die durch den IEEE-754-Standard definiert sind (ungültige Vorgänge, Division durch Null, Überlauf, Unterlauf oder ungenaue Ausnahmen), erzeugen Standardergebnisse (wie im Standard definiert) und werden fortgesetzt, ohne das entsprechende Status-Flag zu setzen, ähnlich wie bei der raiseNoFlag-Ausnahmebehandlung des Standards. Ausnahmen für nicht standardmäßige Operationen (z.B. komplexe arithmetische und bestimmte transzendentale Funktionen) sind implementierungsdefiniert.

Notation

Zur Beschreibung der Syntax wird in diesem Dokument die modifizierte ISO-Variante der EBNF-Syntax (ISO/IEC 14977:1996, Wikipedia) mit zwei Änderungen verwendet: 1) Regeln werden mit ::= statt mit = definiert.

2) Die Verkettung wird durch Gegenüberstellung anstelle von , ausgedrückt.

Zur Beschreibung der Semantik (d.h. in den Abschnitten „Typen“, „Konstanten“ und „Ops“) verwenden wir Formeln, die auf Python-Syntax basieren und mit Unterstützung für das prägnante Ausdrucken von Arrayvorgängen wie unten beschrieben unterstützt werden. Dies funktioniert gut bei kleinen Code-Snippets. Wenn jedoch größere Code-Snippets erforderlich sind, verwenden wir in seltenen Fällen eine einfache Python-Syntax, die immer explizit eingeführt wird.

Formeln

Sehen wir uns anhand eines Beispiels aus der Spezifikation dot_general an, wie Formeln funktionieren. Eine der Einschränkungen für diesen Vorgang sieht so aus: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Die in dieser Formel verwendeten Namen stammen aus zwei Quellen: 1) globale Funktionen, d.h. dim, 2) Mitgliedsdefinitionen des entsprechenden Programmelements, d.h. Eingaben lhs, lhs_batching_dimensions, rhs und rhs_batching_dimensions, die im Abschnitt „Eingaben“ von dot_general definiert sind.

Wie oben erwähnt, basiert die Syntax dieser Formel auf Python und bietet einige Erweiterungen, die auf Genauigkeit ausgerichtet sind. Um die Formel verständlich zu machen, wandeln wir sie in einfache Python-Syntax um.

A) In diesen Formeln verwenden wir =, um Gleichheit darzustellen. Der erste Schritt zum Abrufen der Python-Syntax besteht also darin, = durch == zu ersetzen: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Diese Formeln unterstützen außerdem Ellipsen (...), die skalare Ausdrücke in Tensorausdrücke umwandeln. Kurz gesagt bedeutet f(xs...) ungefähr „für jede skalare x im Tensor xs, einen skalaren f(x) zu berechnen und dann alle diese skalaren Ergebnisse zusammen als Tensorergebnis zurückzugeben“. In der einfachen Python-Syntax wird aus unserer Beispielformel [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Dank Ellipsen ist es oft möglich, die Arbeit auf der Ebene einzelner Skalare zu vermeiden. In einigen kniffligen Fällen kann jedoch eine halbinformelle Syntax mit niedrigerer Stufe wie in der start_indices[bi0, ..., :, ..., biN]-Formel aus der gather-Spezifikation verwendet werden. Im Sinne der Präzision bieten wir keinen exakten Formalismus für die Übersetzung einer solchen Syntax in einfaches Python, in der Hoffnung, dass sie von Fall zu Fall dennoch intuitiv verständlich ist. Bitte teilen Sie uns mit, wenn bestimmte Formeln undurchsichtig erscheinen. Wir werden versuchen, sie zu verbessern.

Sie werden auch feststellen, dass Formeln mithilfe von Ellipsen alle Arten von Listen erweitern, einschließlich Tensoren, Listen von Tensoren (die z.B. aus einer variierenden Anzahl von Tensoren entstehen können) usw. Dies ist ein weiterer Bereich, in dem wir keine exakte Formalität angeben (z.B. Listen sind nicht einmal Teil des StableHLO-Typsystems) und basieren stattdessen auf intuitiver Verständlichkeit.

C) Das letzte nennenswerte Instrument, das wir verwenden, ist die implizite Übertragung. Das StableHLO-Opset unterstützt zwar kein implizites Broadcasting, die Formeln unterstützen dies aber auch im Rahmen der Präzision. Kurz gesagt: Wenn ein Skalar in einem Kontext verwendet wird, in dem ein Tensor erwartet wird, wird der Skalar in der erwarteten Form übertragen.

Um das Beispiel dot_general fortzusetzen, gibt es eine weitere Einschränkung: 0 <= lhs_batching_dimensions < rank(lhs). Wie in der dot_general-Spezifikation definiert, ist lhs_batching_dimensions ein Tensor. Sowohl 0 als auch rank(lhs) sind jedoch Skalare. Nach Anwendung des impliziten Broadcasts wird die Formel zu [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Bei Anwendung auf einen bestimmten dot_general-Vorgang wird diese Formel als boolescher Tensor ausgewertet. Wenn Formeln als Einschränkungen verwendet werden, gilt die Einschränkung, ob die Formel entweder true oder einen Tensor auswertet, der nur true-Elemente hat.

Namen

In Formeln umfasst der lexikalische Geltungsbereich: 1) globale Funktionen, 2) Mitgliederdefinitionen,

3) lokale Definitionen. Die Liste der globalen Funktionen finden Sie unten. Die Liste der Elementdefinitionen hängt vom Programmelement ab, auf das die Notation angewendet wird:

  • Bei Vorgängen umfassen Mitgliederdefinitionen Namen, die in den Abschnitten "Eingaben" und "Ausgaben" eingeführt wurden.
  • Für alles andere enthalten Mitgliedsdefinitionen strukturelle Teile des Programmelements, die nach den entsprechenden EBNF-Nicht-Terminals benannt sind. In den meisten Fällen werden die Namen dieser Strukturteile durch Konvertierung der Namen der Nicht-Terminals in Snake Case (z. B. IntegerLiteral => integer_literal) abgerufen. Manchmal werden Namen jedoch während des Prozesses abgekürzt (z. B. QuantizationStorageType => storage_type). In diesem Fall werden die Namen in den Vorgangsspezifikationen ausdrücklich ähnlich wie die Abschnitte „Inputs“ und „Outputs“ eingeführt.
  • Außerdem enthalten Mitgliedsdefinitionen immer self, um auf das entsprechende Programmelement zu verweisen.

Werte

Wenn Formeln ausgewertet werden, funktionieren sie mit den folgenden Wertetypen: 1) Value (tatsächliche Werte, z.B. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; sie kennen immer ihren Typ), 2) Placeholder (zukünftige Werte, z.B. lhs, rhs oder result; ihre tatsächlichen Werte sind noch nicht bekannt, nur ihre Typen sind bekannt), 3) Type (Typen wie im Abschnitt „Typen“ definiert), 4) Function (globale Funktionen wie im Abschnitt „Funktionen“ definiert).

Je nach Kontext können sich Namen auf verschiedene Werte beziehen. Genauer gesagt definiert der Abschnitt „Semantik“ für Vorgänge (und Äquivalente für andere Programmelemente) Laufzeitlogik, sodass alle Eingaben als Value verfügbar sind. Im Gegensatz dazu wird im Abschnitt „Einschränkungen“ für Vorgänge (und Äquivalente) die Logik für die Kompilierungszeit definiert, d.h. etwas, das normalerweise vor der Laufzeit ausgeführt wird, sodass nur konstante Eingaben als Value und andere Eingaben nur als Placeholder verfügbar sind.

Namen In „Semantik“ Unter „Einschränkungen“
Globale Funktionen Function Function
Konstante Eingaben Value Value
Nicht konstante Eingaben Value Placeholder
Ausgaben Value Placeholder
Lokale Definitionen Hängt von der Definition ab Hängt von der Definition ab

Sehen wir uns einen transpose-Beispielvorgang an:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Für diesen Vorgang ist permutation eine Konstante und somit sowohl in Bezug auf Semantik als auch Einschränkungen als Value verfügbar. Im Gegensatz dazu sind operand und result in der Semantik als Value verfügbar, in Einschränkungen jedoch nur als Placeholder.

Funktionen

Konstruktion von Typen

Es gibt keine Funktionen, die zum Erstellen von Typen verwendet werden können. Stattdessen verwenden wir direkt die Typsyntax, da sie normalerweise prägnanter ist. Beispiel: (tensor<E>, tensor<E>) -> (tensor<E>) statt function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Funktionen für Typen

  • element_type ist für Tensortypen und quantisierte Tensortypen definiert und gibt jeweils den TensorElementType- oder QuantizedTensorElementType-Teil der entsprechenden TensorType oder QuantizedTensorType zurück.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value ist ein Kurzbefehl für is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value ist ein Kürzel für is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool prüft, ob der Typ x zum Typ y hochgestuft werden kann. Wenn x und y den Wert QuantizedTensorElementType haben, wird das Angebot nur auf storage_type angewendet. Diese spezielle Version der Hochstufung wird derzeit im Zusammenhang mit der Berechnung von Reduzierungen verwendet (weitere Informationen finden Sie unter RFC).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value ist eine Kurzform für is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Verfügbar für alle Typen. Beispiel: is_float(x) gibt true zurück, wenn x ein FloatType ist. Wenn x ein Wert oder Platzhalter ist, ist diese Funktion eine Kurzform für is_type_name(type(x)).

  • max_value(x: Type) -> Value gibt den Maximalwert von TensorElementType zurück. Wenn x kein TensorElementType ist, wird None zurückgegeben.

  • min_value(x: Type) -> Value gibt den kleinstmöglichen Wert von TensorElementType zurück. Wenn x kein TensorElementType ist, wird None zurückgegeben.

  • member_name(x: Value | Placeholder | Type) -> Any: Verfügbar für alle Mitgliedsdefinitionen member_name aller Typen. tensor_element_type(x) gibt beispielsweise den TensorElementType-Teil einer entsprechenden TensorType zurück. Wenn x ein Wert oder Platzhalter ist, ist diese Funktion eine Kurzform für member_name(type(x)). Wenn x kein Typ mit einem geeigneten Mitglied oder einem Wert oder Platzhalter dieses Typs ist, wird None zurückgegeben.

Konstruktion von Werten

  • operation_name(*xs: Value | Type) -> Value. Für alle Vorgänge verfügbar. add(lhs, rhs) nimmt beispielsweise die beiden Tensorwerte lhs und rhs und gibt die Ausgabe der Auswertung des Vorgangs add mit diesen Eingaben zurück. Bei einigen Vorgängen, z.B. broadcast_in_dim, sind die Ausgabentypen „lastig“, d.h. sie werden benötigt, um einen Vorgang zu bewerten. In diesem Fall übernimmt die Funktion diese Typen als Argumente.

Funktion für Werte

  • Alle Operatoren und Funktionen von Python sind verfügbar. Beispielsweise stehen sowohl subscription- als auch Slicing-Notationen aus Python zur Indexierung in Tensoren, quantisierten Tensoren und Tupeln zur Verfügung.

  • to_destination_type(x: Value, destination_type: Type) -> Value ist auf Tensoren definiert und gibt den konvertierten Wert von x basierend auf type(x) und destination_type so zurück:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Es gibt bereits eine Diskussion über das Zusammenführen von convert-, uniform_quantize- und uniform_dequantize-Vorgängen (#1576). Nach der Zusammenführung benötigen wir die obige Funktion nicht und können stattdessen den Vorgangsnamen für convert verwenden.

  • is_nan(x: Value) -> Value ist auf Tensoren definiert und gibt true zurück, wenn alle Elemente von x NaN sind, andernfalls false. Wenn x kein Tensor ist, wird None zurückgegeben.

  • is_sorted(x: Value) -> Value ist auf Tensoren definiert und gibt true zurück, wenn Elemente von x in aufsteigender Reihenfolge in Bezug auf die aufsteigende lexikografische Reihenfolge ihrer Indexe sortiert werden, andernfalls nach false. Wenn x kein Tensor ist, wird None zurückgegeben.

  • is_unique(x: Value) -> Value ist auf Tensoren definiert und gibt true zurück, wenn x keine doppelten Elemente oder sonst false hat. Wenn x kein Tensor ist, wird None zurückgegeben.

  • member_name(x: Value) -> Any ist für alle Mitgliedsdefinitionen member_name aller Werte definiert. real_part(x) gibt beispielsweise den RealPart-Teil einer entsprechenden ComplexConstant zurück. Wenn x kein Wert für ein geeignetes Mitglied ist, wird None zurückgegeben.

  • same(x: Value) -> Value ist auf Tensoren definiert und gibt true zurück, wenn die Elemente von x alle gleich sind, oder andernfalls false. Wenn der Tensor keine Elemente hat, zählt dies als „alle gleich“, d.h., die Funktion gibt true zurück. Wenn x kein Tensor ist, wird None zurückgegeben.

  • split(x: Value, num_results: Value, axis: Value) -> Value ist auf Tensoren definiert und gibt num_results-Segmente von x entlang der axis-Achse zurück. Wenn x weder ein Tensor noch dim(x, axis) % num_results != 0 ist, wird None zurückgegeben.

Formberechnungen

  • axes(x: Value | Placeholder | Type) -> Value ist eine Kurzform für range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value ist eine Kurzform für shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List ist eine Kurzform für list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value ist auf Tensoren definiert und gibt size(x)-Indizes für die entsprechende TensorType zurück, sortiert in aufsteigender lexikografischer Reihenfolge, z.B. [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Wenn x kein Tensortyp, kein quantisierter Tensortyp oder ein Wert oder Platzhalter eines dieser Typen ist, wird None zurückgegeben.

  • rank(x: Value | Placeholder | Type) -> Value ist eine Kurzform für size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value wird im Abschnitt „Funktionen für Typen“ über member_name definiert.

  • size(x: Value | Placeholder | Type) -> Value ist eine Kurzform für reduce(lambda x, y: x * y, shape(x)).

Quantisierungsberechnungen

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type ist ein Kürzel für element_type(baseline_type(x)).

  • baseline_type wird für Tensortypen und quantisierte Tensortypen definiert und transformiert sie in eine "Basislinie", d.h. einen Typ mit der gleichen Form, aber mit den Quantisierungsparametern des Elementtyps auf Standardwerte zurückgesetzt. Dies wird als praktischer Trick verwendet, um sowohl Tensor- als auch quantisierte Tensortypen einheitlich zu vergleichen, was recht häufig erforderlich ist. Bei quantisierten Typen ermöglicht dies, dass Vergleichstypen die Quantisierungsparameter ignorieren. Das heißt, shape, storage_type, expressed_type, storage_min, storage_max und quantization_dimension (für quantisierten Typ pro Achse) müssen alle übereinstimmen, scales und zero points können jedoch abweichen.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize wird für quantisierte Tensortypen definiert und wandelt sie in Gleitkomma-Tensortypen um. Dazu werden quantisierte Elemente, die Ganzzahlwerte des Speichertyps darstellen, in entsprechende Gleitkommawerte des Ausdruckstyps konvertiert. Dabei werden der Nullpunkt und die Skalierung verwendet, die dem quantisierten Elementtyp zugeordnet sind.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize wird für Gleitkomma-Tensortypen definiert und wandelt sie in quantisierte Tensortypen um. Dazu werden Gleitkommawerte des ausgedrückten Typs mithilfe des Nullpunkts und der Skalierung, die dem quantisierten Elementtyp zugeordnet sind, in entsprechende Ganzzahlwerte des Speichertyps konvertiert.
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
  • dequantize_op_quantize wird verwendet, um elementweise Berechnungen für quantisierte Tensoren anzugeben. Sie dequantisiert, d.h. quantisiert quantisierte Elemente, wandelt sie also in ihre Ausdruckstypen um, führt dann einen Vorgang aus und quantisiert dann die Ergebnisse, d.h., wandelt die Ergebnisse wieder in ihren Speichertyp um. Derzeit funktioniert diese Funktion nur für die Quantisierung pro Tensor. Die Quantisierung pro Achse wird ausgeführt (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)

Rasterberechnungen

  • cross_partition(replica_groups: Value) -> Value. Weitere Informationen finden Sie oben im Abschnitt "cross_ Nachbau".

  • cross_replica(replica_groups: Value) -> Value. Weitere Informationen finden Sie oben im Abschnitt "cross_ Nachbau".

  • cross_replica_and_partition(replica_groups: Value) -> Value. Weitere Informationen finden Sie oben im Abschnitt "cross_replica_and_partition".

  • flattened_ids(replica_groups: Value) -> Value. Weitere Informationen finden Sie oben im Abschnitt "flattened_ids".