Specyfikacja StableHLO

StableHLO to zestaw operacji dla operacji wysokiego poziomu (HLO) w modelach uczenia maszynowego. StableHLO działa jako warstwa przenośności między różnymi platformami ML i kompilatorami ML: platformy ML, które generują programy StableHLO, są zgodne z kompilatorami ML, które je wykorzystują.

Naszym celem jest uproszczenie i przyspieszenie rozwoju ML poprzez zwiększenie interoperacyjności między różnymi platformami ML (takimi jak TensorFlow, JAX i PyTorch) oraz kompilatorami ML (takimi jak XLA i IREE). W tym celu w tym dokumencie podajemy specyfikację języka programowania StableHLO.

Specyfikacja ta składa się z 3 głównych sekcji. Po pierwsze, sekcja Programy opisuje strukturę programów StableHLO, które składają się z funkcji StableHLO, a te z kolei z operacji StableHLO. W tej strukturze sekcja Ops określa semantykę poszczególnych operacji. Sekcja Wykonanie zawiera semantykę wszystkich tych operacji wykonywanych razem w programie. W sekcji Notacja omówiono notację używaną w całej specyfikacji.

Aby wyświetlić specyfikację z poprzedniej wersji StableHLO, otwórz repozytorium w interesującej Cię wersji z tagiem. Na przykład specyfikacja StableHLO w wersji 0.19.0. Aby zobaczyć zmiany, które nastąpiły w poszczególnych wersjach StableHLO, zapoznaj się z dziennikiem wersji w pliku VhloDialect.td.

Programy

Program ::= {Func}

Programy StableHLO składają się z dowolnej liczby funkcji StableHLO. Poniżej znajduje się przykładowy program z funkcją @main, która ma 3 dane wejściowe (%image, %weights%bias) i 1 dane wyjściowe. Treść funkcji zawiera 6 operacji.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Funkcje

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Funkcje StableHLO (nazywane też funkcjami nazwanymi) mają identyfikator, dane wejściowe i wyjściowe oraz treść. W przyszłości planujemy wprowadzić dodatkowe metadane funkcji, aby zapewnić lepszą zgodność z HLO (#425, #626, #740, #744).

Identyfikatory

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Identyfikatory StableHLO są podobne do identyfikatorów w wielu językach programowania, ale mają 2 cechy szczególne: 1) wszystkie identyfikatory mają symbole, które odróżniają różne rodzaje identyfikatorów, 2) identyfikatory wartości mogą być w całości numeryczne, aby uprościć generowanie programów StableHLO.

Typy

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Typy StableHLO dzielą się na typy wartości (nazywane też typami pierwszego rzędu), które reprezentują wartości StableHLO, oraz typy inne niż wartości, które opisują inne elementy programu. Typy StableHLO są podobne do typów w wielu językach programowania, a ich główną cechą jest specyfika StableHLO, która powoduje pewne nietypowe wyniki (np. typy skalarne nie są typami wartości).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

Typy tensorów reprezentują tensory, czyli wielowymiarowe tablice. Mają one kształt i typ elementu, gdzie kształt reprezentuje nieujemne lub nieznane rozmiary wymiarów w kolejności rosnącej odpowiednich wymiarów (zwanych też osiami) numerowanych od 0 do R-1. Liczba wymiarów R jest nazywana rangą. Na przykład tensor<2x3xf32> to typ tensora o kształcie 2x3 i typie elementu f32. Ma 2 wymiary (czyli 2 osie) – wymiar 0 i wymiar 1 – o rozmiarach 2 i 3. Jego pozycja to 2.

Kształty mogą być częściowo lub całkowicie nieznane (dynamiczne), np. tensor<?x2xf64> jest częściowo nieznany, a tensor<?x?xf64> jest całkowicie nieznany. Rozmiary wymiarów dynamicznych są oznaczone symbolem ?. Kształtów nie można usunąć z rankingu.

W przyszłości planujemy rozszerzyć typy tensorów poza rozmiary wymiarów i typy elementów, np. o układy (#629) i rzadkość (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nazwa Typ Ograniczenia
storage_type typ liczby całkowitej, (C1-C3), (C8)
storage_min stała całkowita (C1), (C3), (C7)
storage_max stała całkowita (C2), (C3), (C7)
expressed_type typ zmiennoprzecinkowy (C4)
quantization_dimension opcjonalna stała całkowita (C10-C12)
scales zmienna liczba stałych zmiennoprzecinkowych (C4-C6), (C9), (C10), (C13)
zero_points zmienna liczba stałych całkowitych, (C7-C9)

Skwantowane typy elementów reprezentują wartości całkowite typu pamięci w zakresie od storage_min do storage_max (włącznie), które odpowiadają wartościom zmiennoprzecinkowym typu wyrażonego. Dla danej wartości całkowitej i odpowiadającą jej wartość zmiennoprzecinkową f można obliczyć jako f = (i - zero_point) * scale, gdzie scalezero_point to parametry kwantyzacji. storage_minstorage_max są opcjonalne w gramatyce, ale mają wartości domyślne min_value(storage_type)max_value(storage_type). Typy elementów skwantowanych mają te ograniczenia:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9)size(scales) = size(zero_points)
  • (C10) Jeśli is_empty(quantization_dimension), to size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Obecnie QuantizationScale jest stałą zmiennoprzecinkową, ale istnieje duże zainteresowanie skalami opartymi na liczbach całkowitych, reprezentowanymi za pomocą mnożników i przesunięć. Planujemy to sprawdzić w najbliższej przyszłości (#1404).

Trwa dyskusja na temat semantyki QuantizationZeroPoint, w tym typu, wartości i tego, czy w skwantyzowanym typie tensora może być tylko jeden punkt zerowy, czy potencjalnie wiele. Na podstawie wyników tej dyskusji specyfikacja dotycząca zera punktów może w przyszłości ulec zmianie (#1405).

Kolejna dyskusja dotyczy semantyki QuantizationStorageMinQuantizationStorageMax, aby określić, czy należy nałożyć jakiekolwiek ograniczenia na te wartości i wartości skwantyzowanych tensorów (#1406).

Planujemy też zbadać możliwość reprezentowania nieznanych skal i punktów zerowych, podobnie jak w przypadku nieznanych rozmiarów wymiarów (#1407).

Skwantowane typy tensorów reprezentują tensory ze skwantowanymi elementami. Te tensory są dokładnie takie same jak zwykłe tensory, z tym wyjątkiem, że ich elementy mają skwantowane typy elementów zamiast zwykłych typów elementów.

W przypadku skwantyzowanych tensorów kwantyzacja może być na tensor, co oznacza, że dla całego tensora jest jedna para wartości scalezero_point, lub na oś, co oznacza, że jest wiele par wartości scaleszero_points, po jednej parze na wycinek w określonym wymiarze quantization_dimension. Bardziej formalnie, w tensorze t z kwantyzacją na osi znajdują się dim(t, quantization_dimension) wycinki quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] itd. Wszystkie elementy w i-tym wycinku używają scales[i]zero_points[i] jako parametrów kwantyzacji. Typy tensorów skwantowanych mają te ograniczenia:

  • W przypadku kwantyzacji na poziomie tensora:
    • Brak dodatkowych ograniczeń.
  • W przypadku kwantyzacji na osi:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Typy tokenów reprezentują tokeny, czyli nieprzejrzyste wartości generowane i wykorzystywane przez niektóre operacje. Tokeny służą do narzucania kolejności wykonywania operacji, jak opisano w sekcji Wykonywanie.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Typy buforów reprezentują bufory. Na przykład w XLA bufory to wielowymiarowe tablice o spójnym sposobie przechowywania. Podobnie jak typy tensorów, typy buforów mają kształt i typ elementu, przy czym kształt reprezentuje nieujemne lub nieznane rozmiary wymiarów w kolejności rosnącej odpowiednich wymiarów (zwanych też osiami) numerowanych od 0 do R-1. Liczba wymiarów R jest nazywana rangą. Na przykład memref<2x3xf32> to typ bufora o kształcie 2x3 i typie elementu f32. Ma 2 wymiary (czyli 2 osie) – wymiar 0 i wymiar 1 – o rozmiarach 2 i 3. Jego pozycja to 2.

Bufory można przydzielać za pomocą funkcji custom_call do CreateBuffer lub Pin i zwalniać za pomocą funkcji custom_call do Unpin. Tylko operacje custom_call mogą odczytywać i zapisywać zawartość buforów. Więcej informacji znajdziesz w sekcji custom_call.

Typy krotek reprezentują krotki, czyli listy niejednorodne. Krotki to starsza funkcja, która istnieje tylko ze względu na zgodność z HLO. W HLO krotki służą do reprezentowania wejść i wyjść o zmiennej liczbie argumentów. W StableHLO dane wejściowe i wyjściowe o zmiennej liczbie argumentów są obsługiwane natywnie, a krotki są używane tylko do kompleksowego reprezentowania interfejsu ABI HLO, w którym np. T, tuple<T>tuple<tuple<T>> mogą się znacznie różnić w zależności od konkretnej implementacji. W przyszłości planujemy wprowadzić zmiany w interfejsie HLO ABI, które mogą nam umożliwić usunięcie typów krotek z StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Typy elementów reprezentują elementy typów tensorów. W przeciwieństwie do wielu języków programowania te typy nie są w StableHLO typami pierwszej klasy. Oznacza to, że programy StableHLO nie mogą bezpośrednio reprezentować wartości tych typów (w rezultacie wartości skalarne typu T są zwykle reprezentowane za pomocą 0-wymiarowych wartości tensorowych typu tensor<T>).

  • Typ logiczny reprezentuje wartości logiczne truefalse.
  • Typy liczb całkowitych mogą być ze znakiem (si) lub bez znaku (ui) i mieć jedną z obsługiwanych szerokości bitowych (2, 4, 8, 16, 32 lub 64). Typy siN ze znakiem reprezentują wartości całkowite od -2^(N-1) do 2^(N-1)-1 włącznie, a typy uiN bez znaku reprezentują wartości całkowite od 0 do 2^N-1 włącznie.
  • Typy zmiennoprzecinkowe mogą być jednymi z tych typów:
  • Typy złożone reprezentują wartości złożone, które mają część rzeczywistą i część urojoną tego samego typu elementu. Obsługiwane typy złożone to complex<f32> (obie części są typu f32) i complex<f64> (obie części są typu f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Typy funkcji reprezentują zarówno funkcje nazwane, jak i anonimowe. Mają one typy wejściowe (lista typów po lewej stronie symbolu ->) i typy wyjściowe (lista typów po prawej stronie symbolu ->). W wielu językach programowania typy funkcji są typami pierwszej klasy, ale nie w StableHLO.

StringType ::= 'string'

Typ String reprezentuje sekwencje bajtów. W przeciwieństwie do wielu języków programowania typ string nie jest w StableHLO typem pierwszej klasy i jest używany tylko do określania statycznych metadanych elementów programu.

Operacje

Operacje StableHLO (nazywane też operacjami) stanowią zamknięty zbiór operacji wysokiego poziomu w modelach uczenia maszynowego. Jak wspomnieliśmy powyżej, składnia StableHLO jest w dużej mierze inspirowana MLIR, co niekoniecznie jest najbardziej ergonomiczną alternatywą, ale prawdopodobnie najlepiej pasuje do celu StableHLO, jakim jest zwiększenie interoperacyjności między platformami ML a kompilatorami ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Operacje StableHLO (nazywane też operacjami) mają nazwę, dane wejściowe i wyjściowe oraz sygnaturę. Nazwa składa się z stablehlo. prefiksu i mnemonika, który jednoznacznie identyfikuje jedną z obsługiwanych operacji. Poniżej znajdziesz pełną listę obsługiwanych operacji.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Operacje wykorzystują dane wejściowe i generują dane wyjściowe. Dane wejściowe są podzielone na wartości wejściowe (obliczane podczas wykonywania), funkcje wejściowe (dostarczane statycznie, ponieważ w StableHLO funkcje nie są wartościami pierwszej klasy) i atrybuty wejściowe (również dostarczane statycznie). Rodzaj danych wejściowych i wyjściowych używanych i generowanych przez operację zależy od jej mnemonika. Na przykład operacja add op pobiera 2 wartości wejściowe i zwraca 1 wartość wyjściową. Dla porównania operacja select_and_scatter zużywa 3 wartości wejściowe, 2 funkcje wejściowe i 3 atrybuty wejściowe.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Funkcje wejściowe (nazywane też funkcjami anonimowymi) są bardzo podobne do funkcji nazwanych, z tą różnicą, że: 1) nie mają identyfikatora (stąd nazwa „anonimowe”), 2) nie deklarują typów danych wyjściowych (typy danych wyjściowych są wywnioskowane z operacji return w funkcji).

Składnia funkcji wejściowych zawiera obecnie nieużywaną część (patrz Unusedprodukcja powyżej), która jest tam umieszczona ze względu na zgodność z MLIR. W MLIR istnieje bardziej ogólna koncepcja „regionów”, które mogą mieć wiele „bloków” operacji połączonych ze sobą za pomocą operacji skoku. Bloki te mają identyfikatory odpowiadające Unused produkcji, dzięki czemu można je od siebie odróżnić. StableHLO nie ma instrukcji skoku, więc odpowiednia część składni MLIR jest nieużywana (ale nadal istnieje).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Atrybuty wejściowe mają nazwę i wartość, która jest jedną z obsługiwanych stałych. Są one podstawowym sposobem określania statycznych metadanych elementów programu. Na przykład operacja concatenate op używa atrybutu dimension, aby określić wymiar, wzdłuż którego łączone są wartości wejściowe. Podobnie operator slice używa wielu atrybutów, takich jak start_indices i limit_indices, aby określić granice używane do dzielenia wartości wejściowej.

Obecnie programy StableHLO w praktyce czasami zawierają atrybuty, które nie są opisane w tym dokumencie. W przyszłości planujemy włączyć te atrybuty do zestawu operacji StableHLO lub zabronić ich pojawiania się w programach StableHLO. Oto lista tych atrybutów:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Metadane lokalizacji (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

Sygnatura operacji składa się z typów wszystkich wartości wejściowych (lista typów po lewej stronie znaku ->) i typów wszystkich wartości wyjściowych (lista typów po prawej stronie znaku ->). Ściśle mówiąc, typy wejściowe są zbędne, a typy wyjściowe są prawie zawsze zbędne (ponieważ w przypadku większości operacji StableHLO typy wyjściowe można wywnioskować z danych wejściowych). Mimo to sygnatura operacji jest celowo częścią składni StableHLO, aby zapewnić zgodność z MLIR.

Poniżej znajdziesz przykład operacji, której kod mnemoniczny to select_and_scatter. Pobiera 3 wartości wejściowe (%operand, %source%init_value), 2 funkcje wejściowe i 3 atrybuty wejściowe (window_dimensions, window_stridespadding). Zwróć uwagę, że sygnatura operacji zawiera tylko typy wartości wejściowych (ale nie typy funkcji wejściowych i atrybutów, które są podawane w tekście).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Stałe

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Stałe StableHLO mają literał i typ, które razem reprezentują wartość StableHLO. Zwykle typ jest częścią składni stałej, z wyjątkiem sytuacji, gdy jest jednoznaczny (np. stała logiczna ma jednoznacznie typ i1, podczas gdy stała całkowita może mieć wiele możliwych typów).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Stałe logiczne reprezentują wartości logiczne truefalse. Stałe logiczne mają typ i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Stałe całkowite reprezentują wartości całkowite za pomocą ciągów tekstowych, które używają notacji dziesiętnej lub szesnastkowej. Inne systemy liczbowe, np. dwójkowy czy ósemkowy, nie są obsługiwane. Stałe całkowite mają te ograniczenia:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Stałe zmiennoprzecinkowe reprezentują wartości zmiennoprzecinkowe za pomocą ciągów znaków, które używają notacji dziesiętnej lub naukowej. Dodatkowo notacji szesnastkowej można używać do bezpośredniego określania podstawowych bitów w formacie zmiennoprzecinkowym odpowiedniego typu. Stałe zmiennoprzecinkowe mają te ograniczenia:

  • (C1) Jeśli używana jest notacja inna niż szesnastkowa, is_wellformed(float_literal, float_type).
  • (C2) Jeśli używana jest notacja szesnastkowa, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Stałe zespolone reprezentują wartości zespolone za pomocą list części rzeczywistej (pierwsza) i urojonej (druga). Na przykład (1.0, 0.0) : complex<f32> oznacza 1.0 + 0.0i, a (0.0, 1.0) : complex<f32> oznacza 0.0 + 1.0i. Kolejność, w jakiej te części są następnie przechowywane w pamięci, jest określona przez implementację. Stałe złożone mają te ograniczenia:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Stałe tensora reprezentują wartości tensora za pomocą zagnieżdżonych list określonych za pomocą notacji NumPy. Na przykład dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> reprezentuje wartość tensora z tym mapowaniem indeksów na elementy: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. Kolejność, w jakiej te elementy są następnie przechowywane w pamięci, jest określona przez implementację. Stałe tensora podlegają tym ograniczeniom:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), gdzie:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), gdzie:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • w przeciwnym razie false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Skwantowane stałe tensora reprezentują skwantowane wartości tensora przy użyciu tej samej notacji co stałe tensora, przy czym elementy są określone jako stałe typu pamięci. Stałe tensory skwantowane mają te ograniczenia:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Literały ciągów znaków składają się z bajtów określonych za pomocą znaków ASCII i sekwencji ucieczki. Nie zależą od kodowania, więc interpretacja tych bajtów jest zdefiniowana przez implementację. Literały ciągów znaków mają typ string.

Operacje

abs

Semantyka

Wykonuje operację wartości bezwzględnej na każdym elemencie tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb całkowitych ze znakiem: moduł liczby całkowitej.
  • W przypadku liczb zmiennoprzecinkowych: abs z IEEE-754.
  • W przypadku liczb zespolonych: moduł liczby zespolonej.
  • W przypadku typów skwantowanych: dequantize_op_quantize(abs, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor ze znakiem liczby całkowitej, zmiennoprzecinkowej lub zespolonej albo tensor skwantowany na poziomie tensora (C1-C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor ze znakiem liczby całkowitej lub zmiennoprzecinkowej albo tensor skwantowany na poziomie tensora (C1-C2)

Ograniczenia

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) jest zdefiniowane jako:
    • complex_element_type(element_type(operand)) jeżeli is_complex(operand).
    • baseline_element_type(operand) w przeciwnym razie.

Przykłady

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]

Więcej przykładów

dodaj

Semantyka

Wykonuje dodawanie elementów dwóch tensorów lhsrhs i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku wartości logicznych: operator logiczny LUB.
  • W przypadku liczb całkowitych: dodawanie liczb całkowitych.
  • W przypadku liczb zmiennoprzecinkowych: addition z IEEE-754.
  • W przypadku liczb zespolonych: dodawanie liczb zespolonych.
  • W przypadku typów skwantowanych: dequantize_op_quantize(add, lhs, rhs, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor skwantyzowany, (C1-C6)
(I2) rhs tensor lub tensor skwantyzowany, (C1-C5), (C7)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C1-C7)

Ograniczenia

  • Jeśli operacja używa tensorów niekwantyzowanych:
    • (C1) type(lhs) = type(rhs) = type(result).
  • Jeśli operacja używa skwantyzowanych tensorów:
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Jeśli is_per_axis_quantized(lhs), to quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Jeśli is_per_axis_quantized(rhs), to quantization_dimension(rhs) = quantization_dimension(result).

Przykłady

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]

Więcej przykładów

after_all

Semantyka

Zapewnia, że operacje generujące inputs są wykonywane przed operacjami zależnymi od result. Wykonanie tej operacji nie powoduje żadnych zmian. Istnieje ona tylko po to, aby ustanowić zależności danych od result do inputs.

Wejścia

Etykieta Nazwa Typ
(I1) inputs zmienna liczba token

Wyniki

Nazwa Typ
result token

Przykłady

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token

Więcej przykładów

all_gather

Semantyka

W każdej grupie procesów w siatce procesów StableHLO łączy wartości tensorów operands z każdego procesu wzdłuż wymiaru all_gather_dim i tworzy tensory results.

Operacja dzieli siatkę procesów StableHLO na process_groups, co jest zdefiniowane w ten sposób:

  • cross_replica(replica_groups) jeśli channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jeśli channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jeśli channel_id > 0 and use_global_device_ids = true.

Następnie w każdym process_group:

  • operands...@receiver = [operand@sender for sender in process_group] dla wszystkichreceiverprocess_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) dla wszystkichprocessprocess_group.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operands zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1), (C6)
(I2) all_gather_dim stała typu si64 (C1), (C6)
(I3) replica_groups Stała tensora dwuwymiarowego typu si64 (C2-C4)
(I4) channel_id stała typu si64 (C5)
(I5) use_global_device_ids stała typu i1 (C5)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C6)

Ograniczenia

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) jest zdefiniowane jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_replicas, jeśli używana jest właściwość cross_replica_and_partition.
    • num_processes, jeśli używana jest właściwość flattened_ids.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Jeśli use_global_device_ids = true, to channel_id > 0.
  • (C6) type(results...) = type(operands...) z wyjątkiem:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

Przykłady

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
  // channel_id = 0
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
  // use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

Więcej przykładów

all_reduce

Semantyka

W każdej grupie procesów w siatce procesów StableHLO stosuje funkcję redukcji computation do wartości tensorów operands z każdego procesu i tworzy tensory results.

Operacja dzieli siatkę procesów StableHLO na process_groups, co jest zdefiniowane w ten sposób:

  • cross_replica(replica_groups) jeśli channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jeśli channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jeśli channel_id > 0 and use_global_device_ids = true.

Następnie w każdym process_group:

  • results...@process[result_index] = exec(schedule) dla pewnego drzewa binarnegoschedule, gdzie:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule to zdefiniowane przez implementację drzewo binarne, którego przejście w porządku środkowym to to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operands zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C5), (C6)
(I2) replica_groups zmienna liczba stałych tensorów 1-wymiarowych typu si64 (C1-C3)
(I3) channel_id stała typu si64 (C4)
(I4) use_global_device_ids stała typu i1 (C4)
(I5) computation funkcja (C5)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C6-C7)

Ograniczenia

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) jest zdefiniowane jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_replicas, jeśli używana jest właściwość cross_replica_and_partition.
    • num_processes, jeśli używana jest właściwość flattened_ids.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Jeśli use_global_device_ids = true, to channel_id > 0.
  • (C5) computation ma typ (tensor<E>, tensor<E>) -> (tensor<E>), gdzie is_promotable(element_type(operand), E).
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

Przykłady

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
  >replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
  // channel_id = 0
  channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
  // use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

Więcej przykładów

all_to_all

Semantyka

all_to_all

W każdej grupie procesów w siatce procesów StableHLO dzieli wartości tensorów operands wzdłuż osi split_dimension na części, rozprasza podzielone części między procesy, łączy rozproszone części wzdłuż osi concat_dimension i tworzy tensory results. Operacja dzieli siatkę procesów StableHLO na process_groups, co jest zdefiniowane w ten sposób:

  • cross_replica(replica_groups) jeżeli channel_id <= 0.
  • cross_partition(replica_groups) jeżeli channel_id > 0.

Następnie w każdym process_group:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension) dla wszystkich senderprocess_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group], gdziereceiver_index = process_group.index(receiver)
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operands zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1-C3), (C9)
(I2) split_dimension stała typu si64 (C1), (C2), (C9)
(I3) concat_dimension stała typu si64 (C3), (C9)
(I4) split_count stała typu si64 (C2), (C4), (C8), (C9)
(I5) replica_groups Stała tensora dwuwymiarowego typu si64 (C5-C8)
(I6) channel_id stała typu si64

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C9)

Ograniczenia

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) jest zdefiniowane jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_partitions, jeśli używana jest właściwość cross_partition.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...) z wyjątkiem sytuacji, gdy split_dimension != concat_dimension:
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

Przykłady

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
  // channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

Więcej przykładów

i

Semantyka

Wykonuje operację AND na poszczególnych elementach 2 tensorów lhsrhs oraz tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku wartości logicznych: logiczne I.
  • W przypadku liczb całkowitych: bitowe AND.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu logicznego lub całkowitego, (C1)
(I2) rhs tensor typu logicznego lub całkowitego, (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu logicznego lub całkowitego, (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]

Więcej przykładów

atan2

Semantyka

Wykonuje operację atan2 na poszczególnych elementach tensorów lhsrhs, a wynikiem jest tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: atan2 z IEEE-754.
  • W przypadku liczb zespolonych: złożona funkcja atan2.
  • W przypadku typów skwantowanych: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)
(I2) rhs tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Więcej przykładów

batch_norm_grad

Semantyka

Oblicza gradienty kilku danych wejściowych batch_norm_training przez propagację wsteczną z grad_output i generuje tensory grad_operand, grad_scalegrad_offset. Bardziej formalnie tę operację można wyrazić jako dekompozycję na istniejące operacje StableHLO za pomocą składni Pythona w ten sposób:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

W przypadku typów skwantyzowanych wykonuje działanie dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1-C3), (C5)
(I2) scale Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C2), (C4), (C5)
(I3) mean Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C2), (C4)
(I4) variance Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C2), (C4)
(I5) grad_output tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C2), (C3)
(I6) epsilon stała typu f32
(I7) feature_index stała typu si64 (C1), (C5)

Wyniki

Nazwa Typ Ograniczenia
grad_operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C2), (C3)
grad_scale Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C2), (C4)
grad_offset Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C2), (C4)

Ograniczenia

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scalegrad_offset mają ten sam baseline_element_type.
  • (C3) operand, grad_outputgrad_operand mają ten sam kształt.
  • (C4) Znaki scale, mean, variance, grad_scalegrad_offset mają ten sam kształt.
  • (C5) size(scale) = dim(operand, feature_index).

Przykłady

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
 <    tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semantyka

Normalizuje tensor operand we wszystkich wymiarach z wyjątkiem wymiaru feature_index i tworzy tensor result. Bardziej formalnie tę operację można wyrazić jako dekompozycję na istniejące operacje StableHLO za pomocą składni Pythona w ten sposób:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

W przypadku typów skwantyzowanych wykonuje działanie dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1-C7)
(I2) scale Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C2), (C3)
(I3) offset Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C2), (C4)
(I4) mean Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C5)
(I5) variance Tensor 1-wymiarowy typu zmiennoprzecinkowego lub skwantowanego na poziomie tensora (C2), (C6)
(I6) epsilon stała typu f32
(I7) feature_index stała typu si64 (C1), (C3-C6)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C2), (C7)

Ograniczenia

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, varianceresult mają ten sam baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semantyka

Oblicza średnią i wariancję we wszystkich wymiarach z wyjątkiem wymiaru feature_index i normalizuje tensor operand, tworząc tensory output, batch_meanbatch_var. Bardziej formalnie tę operację można wyrazić jako dekompozycję na istniejące operacje StableHLO za pomocą składni Pythona w ten sposób:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

W przypadku typów skwantyzowanych wykonuje działanie dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)
(I2) scale 1-wymiarowy tensor liczb zmiennoprzecinkowych lub tensor kwantyzowany (C2), (C3)
(I3) offset 1-wymiarowy tensor liczb zmiennoprzecinkowych lub tensor kwantyzowany (C2), (C4)
(I4) epsilon stała typu f32 (C1), (C3-C6)
(I5) feature_index stała typu si64 (C1), (C3-C6)

Wyniki

Nazwa Typ Ograniczenia
output tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C7)
batch_mean 1-wymiarowy tensor liczb zmiennoprzecinkowych lub tensor kwantyzowany (C2), (C5)
batch_var 1-wymiarowy tensor liczb zmiennoprzecinkowych lub tensor kwantyzowany (C2), (C6)

Ograniczenia

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_varoutput mają ten sam baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Przykłady

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
 <   (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semantyka

Wykonuje operację bitcast na tensorze operand i tworzy tensor result, w którym bity całego tensora operand są interpretowane na nowo przy użyciu typu tensora result.

Bardziej formalnie, jeśli mamy dane E = element_type(operand), E' = element_type(result)R = rank(operand):

  • Jeśli num_bits(E') < num_bits(E),bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Jeśli num_bits(E') > num_bits(E),bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Jeśli num_bits(E') = num_bits(E),bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits zwraca reprezentację w pamięci danej wartości, a jej działanie jest zdefiniowane przez implementację, ponieważ dokładna reprezentacja tensorów jest zdefiniowana przez implementację, a dokładna reprezentacja typów elementów jest również zdefiniowana przez implementację.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantyzowany, (C1-C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C1-C2)

Ograniczenia

  • (C1) Dane: E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) i R = rank(operand):
    • Jeśli num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Jeśli num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) za wszystkie 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Jeśli num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) za wszystkie 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Jeśli is_complex(operand) or is_complex(result), to is_complex(operand) and is_complex(result).

Przykłady

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Więcej przykładów

broadcast_in_dim

Semantyka

broadcast_in_dim

Rozszerza wymiary lub rangę tensora wejściowego przez zduplikowanie danych w tensorze operand i tworzy tensor result. Bardziej formalnie:result[result_index] = operand[operand_index] gdzie dla wszystkich daxes(operand):

  • operand_index[d] = 0 jeżeli dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] w przeciwnym razie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantyzowany, (C1-C2), (C5-C6)
(I2) broadcast_dimensions 1-wymiarowa stała tensora typu si64 (C2-C6)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C1), (C3), (C5-C6)

Ograniczenia

  • (C1) element_type(result) jest określone wzorem:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand), z wyjątkiem tego, że quantization_dimension(operand), scales(operand)zero_points(operand) mogą różnić się od quantization_dimension(result), scales(result)zero_points(result) odpowiednio, w przeciwnym razie.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Dla wszystkich daxes(operand):
    • dim(operand, d) = 1 lub
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jeśli is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jeśli dim(operand, quantization_dimension(operand)) = 1, to scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Przykłady

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Więcej przykładów

etui

Semantyka

Zwraca wynik wykonania dokładnie 1 funkcji z branches w zależności od wartości index. Bardziej formalnie:result = selected_branch() gdzie:

  • selected_branch = branches[index] jeżeli 0 <= index < size(branches).
  • selected_branch = branches[-1] w przeciwnym razie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) index Tensor 0-wymiarowy typu si32
(I2) branches zmienna liczba funkcji, (C1-C4)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, skwantowanych tensorów lub tokenów; (C4)

Ograniczenia

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Przykłady

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
  "stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]

Więcej przykładów

cbrt

Semantyka

Wykonuje operację pierwiastka sześciennego na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: rootn(x, 3) z IEEE-754.
  • W przypadku liczb zespolonych: zespolony pierwiastek sześcienny.
  • W przypadku typów skwantyzowanych: dequantize_op_quantize(cbrt, operand, type(result))

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]

Więcej przykładów

ceil

Semantyka

Wykonuje operację zaokrąglania w górę na poszczególnych elementach tensora operand i tworzy tensor result. Implementuje operację roundToIntegralTowardPositive ze specyfikacji IEEE-754. W przypadku typów skwantyzowanych wykonuje działanie dequantize_op_quantize(ceil, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Więcej przykładów

cholesky

Semantyka

Oblicza rozkład Cholesky’ego dla partii macierzy.

Bardziej formalnie, dla wszystkich iindex_space(result), result[i0, ..., iR-3, :, :] jest rozkładem Cholesky’ego macierzy a[i0, ..., iR-3, :, :] w postaci macierzy dolnotrójkątnej (jeśli lower to true) lub górnotrójkątnej (jeśli lower to false). Wartości wyjściowe w przeciwnym trójkącie, czyli odpowiednio w trójkącie górnym lub dolnym, są zdefiniowane przez implementację.

Jeśli istnieje i, gdzie macierz wejściowa nie jest hermitowską macierzą dodatnio określoną, działanie jest nieokreślone.

W przypadku typów skwantyzowanych wykonuje działanie dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) a tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1-C3)
(I2) lower stała typu i1

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Przykłady

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

ograniczać (zakres)

Semantyka

Ogranicza każdy element tensora operand do wartości minimalnej i maksymalnej oraz tworzy tensor result. Formalnie: result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), gdzie min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. W przypadku typów skwantyzowanych wykonuje działanie dequantize_op_quantize(clamp, min, operand, max, type(result)).

Wprowadzenie porządku w liczbach zespolonych wiąże się z nieoczekiwaną semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) min tensor lub tensor skwantowany na poziomie tensora; (C1), (C3)
(I2) operand tensor lub tensor skwantowany na poziomie tensora; (C1-C4)
(I3) max tensor lub tensor skwantowany na poziomie tensora; (C2), (C3)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C4)

Ograniczenia

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Przykłady

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]

Więcej przykładów

collective_broadcast

Semantyka

W każdej grupie procesów w siatce procesów StableHLO wyślij wartość tensora operand z procesu źródłowego do procesów docelowych i utwórz tensor result.

Operacja dzieli siatkę procesów StableHLO na process_groups, co jest zdefiniowane w ten sposób:

  • cross_replica(replica_groups) jeżeli channel_id <= 0.
  • cross_partition(replica_groups) jeżeli channel_id > 0.

Wartość result@process jest obliczana w ten sposób:

  • operand@process_groups[i, 0] jeśli istnieje i, taki że proces znajduje się w process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) w przeciwnym razie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C3)
(I2) replica_groups zmienna liczba stałych tensorów 1-wymiarowych typu si64 (C1), (C2)
(I3) channel_id stała typu si64

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C3)

Ograniczenia

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, gdzie N jest zdefiniowane jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_partitions, jeśli używana jest właściwość cross_partition.
  • (C3) type(result) = type(operand).

Przykłady

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semantyka

W każdej grupie procesów w siatce procesów StableHLO wysyła wartość tensora operand z procesu źródłowego do procesu docelowego i tworzy tensor result.

Operacja dzieli siatkę procesów StableHLO na process_groups, co jest zdefiniowane w ten sposób:

  • cross_replica(source_target_pairs) jeżeli channel_id <= 0.
  • cross_partition(source_target_pairs) jeżeli channel_id > 0.

Wartość result@process jest obliczana w ten sposób:

  • operand@process_groups[i, 0], jeśli istnieje i takie, że process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)) w przeciwnym razie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C5)
(I2) source_target_pairs Stała tensora dwuwymiarowego typu si64 (C1-C4)
(I3) channel_id stała typu si64

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1)

Ograniczenia

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, gdzie N jest zdefiniowane jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_partitions, jeśli używana jest właściwość cross_partition.
  • (C5) type(result) = type(operand).

Przykłady

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Więcej przykładów

porównaj

Semantyka

Wykonuje porównanie elementów tensorów lhsrhs zgodnie z warunkami comparison_directioncompare_type oraz tworzy tensor result.

Wartości comparison_directioncompare_type mają następującą semantykę:

W przypadku typów elementów logicznych i całkowitych:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

W przypadku typów elementów zmiennoprzecinkowych z wartością compare_type = FLOAT operacja implementuje te działania zgodne ze standardem IEEE-754:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

W przypadku typów elementów zmiennoprzecinkowych z compare_type = TOTALORDER operacja używa kombinacji operacji totalOrdercompareQuietEqual z IEEE-754.

W przypadku złożonych typów elementów porównanie leksykograficzne par (real, imag) jest przeprowadzane przy użyciu podanych wartości comparison_directioncompare_type. Wprowadzenie porządku w liczbach zespolonych wiąże się z nieoczekiwaną semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych, gdy comparison_direction ma wartość GE, GT, LE lub LT (#560).

W przypadku typów poddanych kwantyzacji wykonuje działanie dequantize_compare(lhs, rhs, comparison_direction).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor skwantowany na poziomie tensora; (C1-C3)
(I2) rhs tensor lub tensor skwantowany na poziomie tensora; (C1-C2)
(I3) comparison_direction wyliczenie EQ, NE, GE, GT, LE i LT
(I4) compare_type wyliczenie FLOAT, TOTALORDER, SIGNED i UNSIGNED (C3)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu logicznego (C2)

Ograniczenia

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type jest zdefiniowane jako:
    • SIGNED jeżeli is_signed_integer(element_type(lhs)).
    • UNSIGNED jeżeli is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT lub TOTALORDER, jeśli is_float(element_type(lhs)).
    • FLOAT jeżeli is_complex(element_type(lhs)).

Przykłady

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = <#stablehlocomparison_di>rection LT,
  compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]

Więcej przykładów

złożone,

Semantyka

Wykonuje konwersję na poziomie elementów na wartość zespoloną z pary wartości rzeczywistych i urojonych, lhs i rhs, i tworzy tensor result.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu f32 lub f64 (C1-C3)
(I2) rhs tensor typu f32 lub f64 (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu złożonego, (C2), (C3)

Ograniczenia

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) ma typ complex<E>, gdzie E = element_type(lhs).

Przykłady

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]

Więcej przykładów

wieloskładnikowa

Semantyka

Zawiera operację składającą się z innych operacji StableHLO, która przyjmuje inputscomposite_attributes, a zwraca results. Semantyka operacji jest implementowana przez atrybut decomposition. Operację composite można zastąpić jej dekompozycją bez zmiany semantyki programu. Jeśli wstawienie dekompozycji nie zapewnia tej samej semantyki operacji, preferuj użycie custom_call.

Pole version (domyślnie 0) służy do oznaczania zmiany semantyki elementu złożonego.

Wejścia

Etykieta Nazwa Typ
(I1) inputs zmienna liczba wartości,
(I2) name stała typu string
(I3) composite_attributes słownik atrybutów,
(I4) decomposition stała typu string
(I5) version stała typu si32

Wyniki

Nazwa Typ
results zmienna liczba wartości,

Ograniczenia

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Przykłady

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
 < ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32

Więcej przykładów

łączyć,

Semantyka

Łączy tensory inputs wzdłuż wymiaru dimension w tej samej kolejności co podane argumenty i tworzy tensor result. Bardziej formalnie: result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], gdzie:

  1. id = d0 + ... + dk-1 + kd.
  2. d jest równe dimension, a d0, ... to rozmiary d-tego wymiaru inputs.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1-C6)
(I2) dimension stała typu si64 (C2), (C4), (C6)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C5-C6)

Ograniczenia

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) z wyjątkiem dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) z wyjątkiem:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Przykłady

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Więcej przykładów

stała

Semantyka

Tworzy tensor output ze stałej value.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) value stała (C1)

Wyniki

Nazwa Typ Ograniczenia
output tensor lub tensor skwantyzowany, (C1)

Ograniczenia

  • (C1) type(value) = type(output).

Przykłady

%output = "stablehlo.constant"() {
  val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]

Więcej przykładów

dokonają konwersji

Semantyka

Wykonuje konwersję elementów tensora operand z jednego typu na inny i tworzy tensor result.

W przypadku konwersji boolean-to-any-supported-type wartość false jest konwertowana na zero, a wartość true – na jeden. W przypadku konwersji any-supported-type-to-boolean wartość zero jest konwertowana na false, a wartości niezerowe są konwertowane na true. Poniżej dowiesz się, jak to działa w przypadku typów złożonych.

W przypadku konwersji liczby całkowitej na liczbę całkowitą, liczby całkowitej na liczbę zmiennoprzecinkową lub liczby zmiennoprzecinkowej na liczbę zmiennoprzecinkową, jeśli wartość źródłową można dokładnie przedstawić w typie docelowym, wartość wynikowa jest tym dokładnym przedstawieniem. W przeciwnym razie działanie jest do ustalenia(#180).

W przypadku konwersji z floating-point-to-integer część ułamkowa jest obcinana. Jeśli skróconej wartości nie można przedstawić w typie miejsca docelowego, sposób działania zostanie określony (#180).

Konwersje z liczby zespolonej na liczbę zespoloną działają tak samo jak konwersje z liczby zmiennoprzecinkowej na liczbę zmiennoprzecinkową w przypadku konwertowania części rzeczywistych i urojonych.

W przypadku konwersji z typu złożonego na dowolny inny typz dowolnego innego typu na typ złożony źródłowa wartość urojona jest ignorowana lub docelowa wartość urojona jest zerowana. Konwersja części rzeczywistej jest zgodna z konwersjami liczb zmiennoprzecinkowych.

W zasadzie ta operacja może wyrażać dekwantyzację (konwersję z tensorów skwantyzowanych na zwykłe), kwantyzację (konwersję ze zwykłych tensorów na skwantyzowane) i ponowną kwantyzację (konwersję między tensorami skwantyzowanymi), ale obecnie mamy do tego specjalne operacje – uniform_dequantize w pierwszym przypadku i uniform_quantize w drugim i trzecim. W przyszłości te 2 operacje mogą zostać połączone w convert (#1576).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor (C1)

Ograniczenia

  • (C1) shape(operand) = shape(result).

Przykłady

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Więcej przykładów

splot

Semantyka

Oblicza iloczyny skalarne między oknami lhs a wycinkami rhs i generuje result. Ten diagram pokazuje, jak elementy w result są obliczane na podstawie lhsrhs na konkretnym przykładzie.

splot

Bardziej formalnie, rozważmy następujące przekształcenie danych wejściowych w lhs, aby móc wyrazić okna lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Do zmiany ramki używane są te funkcje pomocnicze:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1], gdzie j[d] = i[permutation[d]].

Jeśli feature_group_count = 1 i batch_group_count = 1, to dla wszystkich output_spatial_index w index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product, gdzie:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Ta funkcja wydaje się nieużywana, więc w przyszłości planujemy ją usunąć (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Jeśli feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Jeśli batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

W przypadku typów skwantyzowanych wykonuje operację dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

W przypadku hybrydowych typów skwantyzowanych wykonuje hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor skwantowany na poziomie tensora; (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs tensor lub tensor skwantyzowany, (C1), (C14-C16), (C25), (C27-C29), (C31-C34)
(I3) window_strides 1-wymiarowa stała tensora typu si64 (C2-C3), (C25)
(I4) padding Stała tensora dwuwymiarowego typu si64 (C4), (C25)
(I5) lhs_dilation 1-wymiarowa stała tensora typu si64 (C5-C6), (C25)
(I6) rhs_dilation 1-wymiarowa stała tensora typu si64 (C7-C8), (C25)
(I7) window_reversal 1-wymiarowa stała tensora typu i1 (C9)
(I8) input_batch_dimension stała typu si64 (C10), (C13), (C25)
(I9) input_feature_dimension stała typu si64 (C11), (C13-C14)
(I10) input_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension stała typu si64 (C14), (C18)
(I12) kernel_output_feature_dimension stała typu si64 (C15-C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C17-C18), (C25)
(I14) output_batch_dimension stała typu si64 (C20), (C25)
(I15) output_feature_dimension stała typu si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C19-C20), (C25)
(I17) feature_group_count stała typu si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count stała typu si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config zmienna liczba wyliczeń DEFAULT, HIGH i HIGHEST (C24)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C25-C28), (C30), (C32-34)

Ograniczenia

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9)size(window_reversal) = N - 2
  • (C10)dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dane input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dane kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Dane output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) jest zdefiniowane jako:
    • dim(lhs, input_batch_dimension) / batch_group_count jeżeli result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jeżeli result_dim = output_feature_dimension.
    • num_windows w innym przypadku, gdzie:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jeśli operacja używa tensorów niekwantyzowanych:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jeśli operacja używa skwantyzowanych tensorów:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Jeśli is_per_axis_quantized(rhs), to quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Jeśli is_per_axis_quantized(result), to quantization_dimension(result) = output_feature_dimension.
    • Jeśli is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33) Jeśli is_per_tensor_quantized(rhs), to is_per_tensor_quantized(result).
    • Jeśli !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Przykłady

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strid<es = arra>yi64: 4, 4,
  paddi<n>g = dense<0 : ten>sor2x2xi64,
  lhs_dilati<on = arra>yi64: 2, 2,
  rhs_dilati<on = arra>yi64: 1, 1,
  window_revers<al = arrayi1: fa>lse, false,
  // In the StableHLO dialect, dimension numbers are encoded vi<a:
  // `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" a<re spatial dimensions.
  d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
  batch_group_count = 1 : i64,
  fea<ture_group_count >= 1 : i64,
 < precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

Więcej przykładów

cosinus

Semantyka

Wykonuje operację cosinusową na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: cos z IEEE-754.
  • W przypadku liczb zespolonych: cosinus zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(cosine, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Więcej przykładów

count_leading_zeros

Semantyka

Wykonuje zliczanie elementów liczby początkowych bitów zerowych w tensorze operand i tworzy tensor result.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu całkowitego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego (C1)

Ograniczenia

  • (C1) type(operand) = type(result).

Przykłady

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]

Więcej przykładów

custom_call

Semantyka

Zawiera zdefiniowaną przez implementację operację call_target_name, która przyjmuje inputscalled_computations oraz zwraca results. Właściwości has_side_effect, backend_configapi_version mogą służyć do podawania dodatkowych metadanych zdefiniowanych przez implementację.

Obecnie ta operacja zawiera dość nieuporządkowany zbiór metadanych, który odzwierciedla organiczną ewolucję jej odpowiednika w kompilatorze XLA. W przyszłości planujemy ujednolicić te metadane (#741).

Wejścia

Etykieta Nazwa Typ
(I1) inputs zmienna liczba wartości,
(I2) call_target_name stała typu string
(I3) has_side_effect stała typu i1
(I4) backend_config stała typu string lub słownik atrybutów
(I5) api_version stała typu si32
(I6) called_computations zmienna liczba stałych typu string
(I7) output_operand_aliases określ części aliasowania w danych wyjściowych i operandach;

Wyniki

Nazwa Typ
results zmienna liczba wartości,

(Obsługa GPU XLA) Specjalne cele custom_call

Istnieją 3 specjalne call_target_name związane z typami buffer: CreateBuffer tworzy niezainicjowany typ buffer, Pin tworzy zainicjowany typ buffer, a Unpin zwalnia pamięć typu buffer i zwraca zawartość typu buffer.

%uninitialized_buffer = "stablehlo.custom_call"() {
  call_target_name = "CreateBuffer",
  api_version> = 4 : <i32,
>} : () - memref4xf64

%initialized_buffer = "stablehlo.custom_call"(%init_value) {
  call_target_name = "Pin&quo<t;,
 > ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64

%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
  call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
  api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64

Alias

Niektóre operacje custom_call mogą wymagać, aby część danych wyjściowych i część operandów współdzieliły tę samą pamięć. Możesz to wyrazić za pomocą atrybutu output_operand_aliases. Reprezentacja pary aliasów składa się z listy indeksów krotek wyjściowych reprezentujących część wyjściową oraz indeksu operandu wraz z listą indeksów krotek operandów reprezentujących część operandu. Lista indeksów krotek wyjściowych lub argumentów jest pusta, jeśli odpowiedni typ nie jest typem tuple, i może być dowolnie długa w przypadku dowolnie zagnieżdżonego typu krotki. Jest to podobne do reprezentacji aliasu XLA.

Część wyjściowa i część wejściowa w parze aliasów muszą być tego samego typu. W przypadku operacji custom_call, które nie są wywołaniami funkcji CreateBuffer, PinUnpin, operand buffer może występować w co najwyżej 1 parze aliasów, a dane wyjściowe buffer muszą występować w 1 parze aliasów.

Przykłady

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64

%updated_buffer = "stablehlo.custom_call"(%buffer) {
  call_target_name = "Update",
  api_version = 4 : i32,
  output_operand_aliases< = [
    #stablehlo.output_operand_aliasoutput_tuple_indices = [],
      operand_ind>ex = 0,
     < oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64

dzielenie

Semantyka

Wykonuje dzielenie elementów tensorów dzielnej lhs i dzielnika rhs oraz tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb całkowitych: dzielenie całkowite, które daje iloraz algebraiczny bez części ułamkowej.
  • W przypadku liczb zmiennoprzecinkowych: division z IEEE-754.
  • W przypadku liczb zespolonych: dzielenie liczb zespolonych.
  • W przypadku typów skwantyzowanych:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)
(I2) rhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Więcej przykładów

dot_general

Semantyka

Oblicza iloczyny skalarne między wycinkami tensora lhs i wycinkami tensora rhs oraz zwraca tensor result.

Bardziej formalnie: result[result_index] = dot_product, gdzie:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index, gdzie size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions)size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

W przypadku typów skwantyzowanych wykonuje operację dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

W przypadku hybrydowych typów skwantyzowanych wykonuje hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config kontroluje kompromis między szybkością a dokładnością obliczeń na backendach akceleratorów. Może to być jedna z tych wartości (obecnie semantyka tych wartości wyliczeniowych jest niedostatecznie określona, ale planujemy to zmienić w ramach #755):

  • DEFAULT: Najszybsze obliczenia, ale najmniej dokładne przybliżenie oryginalnej liczby.
  • HIGH: Wolniejsze obliczanie, ale dokładniejsze przybliżenie pierwotnej liczby.
  • HIGHEST: najwolniejsze obliczenia, ale najdokładniejsze przybliżenie liczby pierwotnej.

DotAlgorithm określa główne właściwości algorytmu używanego do implementacji operacji iloczynu skalarnego, która określa też precyzję. Jeśli pola atrybutu algorytmu są ustawione, wartość precision_config musi być równa DEFAULT. DotAlgorithmsnie mają wartości domyślnej, ponieważ domyślne parametry są zdefiniowane przez implementację. Wszystkie pola algorytmu kropkowego można ustawić na None, aby określić pusty algorytm kropkowy, który zamiast tego będzie używać wartości precision_config.

Pola DotAlgorithm obejmują:

  • lhs_precision_typerhs_precision_type, precyzje, do których zaokrąglane są lewa i prawa strona operacji. Typy precyzji są niezależne od typów pamięci danych wejściowych i wyjściowych.
  • accumulation_type precyzję używaną do akumulacji;
  • lhs_component_count, rhs_component_countnum_primitive_operations są stosowane, gdy algorytm rozkłada lewą lub prawą stronę na wiele komponentów i wykonuje na tych wartościach wiele „podstawowych” operacji iloczynu skalarnego – zwykle w celu emulowania wyższej precyzji (np.Wykorzystanie typu danych bfloat16 w sztucznej inteligencji do obliczeń o wyższej precyzji: bf16_6x tf32_3x itp.). W przypadku algorytmów bez dekompozycji te wartości powinny być ustawione na 1.
  • allow_imprecise_accumulation, aby określić, czy w przypadku niektórych kroków dozwolone jest gromadzenie danych z mniejszą precyzją (np. CUBLASLT_MATMUL_DESC_FAST_ACCUM).

Przykładowe atrybuty DotAlgorithm:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

To implementacje decydują, które kombinacje są obsługiwane. Ogólnie rzecz biorąc, nie ma gwarancji, że każdy algorytm jest obsługiwany na każdym typie akceleratora przez konsumenta StableHLO. Jeśli dany algorytm nie jest obsługiwany, należy zgłosić błąd, a nie używać alternatywnego algorytmu. Weryfikacja StableHLO będzie zapewniać weryfikację w największym możliwym stopniu, zapobiegając algorytmom, które nie są obsługiwane na żadnym sprzęcie.

Niektóre obsługiwane wartości algorytmu znajdziesz w sekcji xla_data.proto > Algorithm. Zgłoszenie nr 2483 zawiera plan utworzenia centralnego dokumentu dotyczącego obsługiwanych algorytmów według backendu.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor skwantowany na poziomie tensora; (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs tensor lub tensor skwantyzowany, (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions 1-wymiarowa stała tensora typu si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions 1-wymiarowa stała tensora typu si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions 1-wymiarowa stała tensora typu si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions 1-wymiarowa stała tensora typu si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config zmienna liczba wyliczeń DEFAULT, HIGH i HIGHEST (C11), (C21)
(I8) lhs_precision_type FloatType lub TensorFloat32 (C21)
(I9) rhs_precision_type FloatType lub TensorFloat32 (C21)
(I10) accumulation_type FloatType lub TensorFloat32 (C21)
(I11) lhs_component_count stała typu si32 (C21), (C22)
(I12) rhs_component_count stała typu si32 (C21), (C23)
(I13) num_primitive_operations stała typu si32 (C21), (C24)
(I14) allow_imprecise_accumulation stała typu bool (C21)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C12), (C14), (C18-C20)

Ograniczenia

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9)dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
  • (C10)dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Jeśli operacja używa tensorów niekwantyzowanych:
    • (C13) element_type(lhs) = element_type(rhs).
  • Jeśli operacja używa skwantyzowanych tensorów:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Jeśli is_per_axis_quantized(rhs), to quantization_dimension(rhs) nie może być w rhs_contracting_dimensions.
    • Jeśli is_quantized(lhs):
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C19) Jeśli is_per_tensor_quantized(rhs), to is_per_tensor_quantized(result).
    • Jeśli !is_quantized(lhs):
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result)
  • Jeśli !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

Przykłady

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #sta<blehlo.dot
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimension>s = [1]
  ,
  precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
  algorithm = #stablehlo.dot<_algorithm
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation >= false
  
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Więcej przykładów

dynamic_broadcast_in_dim

Semantyka

Ta operacja jest funkcjonalnie identyczna z operacją broadcast_in_dim, ale kształt wyniku jest określany dynamicznie za pomocą output_dimensions.

Operacja akceptuje też opcjonalne atrybuty known_expanding_dimensions, known_nonexpanding_dimensions, które wyrażają statyczną wiedzę o zachowaniu wymiarów podczas rozszerzania. Jeśli nie zostanie określony, zakłada się, że wszystkie wymiary mogą się rozszerzać.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantyzowany, (C1-C2), (C5-C6), (C9)
(I2) output_dimensions Tensor 1-wymiarowy typu całkowitego (C7)
(I3) broadcast_dimensions 1-wymiarowy tensor stały typu całkowitego (C2-C6)
(I4) known_expanding_dimensions 1-wymiarowy tensor stały typu całkowitego (C8-C9)
(I5) known_nonexpanding_dimensions 1-wymiarowy tensor stały typu całkowitego (C8-C9)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C1), (C3), (C5-C7)

Ograniczenia

  • (C1) element_type(result) jest określone wzorem:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand), z wyjątkiem tego, że quantization_dimension(operand), scales(operand)zero_points(operand) mogą różnić się od quantization_dimension(result), scales(result)zero_points(result) odpowiednio, w przeciwnym razie.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Dla wszystkich daxes(operand):
    • dim(operand, d) = 1 lub
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jeśli is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jeśli dim(operand, quantization_dimension(operand)) = 1, to scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions).
  • (C9)0 <= known_expanding_dimensions < rank(operand)
  • (C10)0 <= known_nonexpanding_dimensions < rank(operand)

Przykłady

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensio<ns = arra>yi64: 2, 1,
  known_expanding_dimensio<ns = a>rrayi64: 0,
  known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Więcej przykładów

dynamic_conv

Semantyka

Ta operacja jest funkcjonalnie identyczna z operacją convolution, ale dopełnienie jest określane dynamicznie za pomocą padding.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor skwantowany na poziomie tensora; (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs tensor lub tensor skwantyzowany, (C1), (C14-C16), (C26-C28), (C30-C33)
(I3) padding 2-wymiarowy tensor typu całkowitego (C4)
(I4) window_strides 1-wymiarowa stała tensora typu si64 (C2-C3)
(I5) lhs_dilation 1-wymiarowa stała tensora typu si64 (C5-C6)
(I6) rhs_dilation 1-wymiarowa stała tensora typu si64 (C7-C8)
(I7) window_reversal 1-wymiarowa stała tensora typu i1 (C9)
(I8) input_batch_dimension stała typu si64 (C10), (C13)
(I9) input_feature_dimension stała typu si64 (C11), (C13-C14)
(I10) input_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C12), (C13)
(I11) kernel_input_feature_dimension stała typu si64 (C14), (C18)
(I12) kernel_output_feature_dimension stała typu si64 (C15-C16), (C18), (C28)
(I13) kernel_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C17-C18)
(I14) output_batch_dimension stała typu si64 (C20)
(I15) output_feature_dimension stała typu si64 (C20), (C29)
(I16) output_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C19-C20)
(I17) feature_group_count stała typu si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count stała typu si64 (C10), (C15), (C22), (C23)
(I19) precision_config zmienna liczba wyliczeń DEFAULT, HIGH i HIGHEST (C24)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C25-C27), (C29), (C31-C33)

Ograniczenia

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9)size(window_reversal) = N - 2
  • (C10)dim(lhs, input_batch_dimension) % batch_group_count = 0
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dane input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dane kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Dane output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) jest zdefiniowane jako:
    • dim(lhs, input_batch_dimension) / batch_group_count jeżeli result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jeżeli result_dim = output_feature_dimension.
    • num_windows w innym przypadku, gdzie:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jeśli operacja używa tensorów niekwantyzowanych:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jeśli operacja używa skwantyzowanych tensorów:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Jeśli is_per_axis_quantized(rhs), to quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Jeśli is_per_axis_quantized(result), to quantization_dimension(result) = output_feature_dimension.
    • Jeśli is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32)expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
    • (C33) Jeśli is_per_tensor_quantized(rhs), to is_per_tensor_quantized(result).
    • Jeśli !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Przykłady

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strid<es = arra>yi64: 4, 4,
  lhs_dilati<on = arra>yi64: 2, 2,
  rhs_dilati<on = arra>yi64: 1, 1,
  window_revers<al = arrayi1: fa>lse, false,
  dimension_numbers = #stab<lehlo.convraw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions => [1, 2]
  ,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

Więcej przykładów

dynamic_gather

Semantyka

Ta operacja jest funkcjonalnie identyczna z operacją gather, przy czym slice_sizes jest określany dynamicznie jako wartość.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1), (C7), (C10-C12), (C14)
(I2) start_indices tensor typu całkowitego (C2), (C3), (C13)
(I3) slice_sizes Tensor 1-wymiarowy typu całkowitego (C8), (C11-C13)
(I4) offset_dims 1-wymiarowa stała tensora typu si64 (C1), (C4-C5), (C13)
(I5) collapsed_slice_dims 1-wymiarowa stała tensora typu si64 (C1), (C6-C8), (C13)
(I6) start_index_map 1-wymiarowa stała tensora typu si64 (C3), (C9), (C10)
(I7) index_vector_dim stała typu si64 (C2), (C3), (C13)
(I8) indices_are_sorted stała typu i1

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C5), (C13-C14)

Ograniczenia

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9)is_unique(start_index_map)
  • (C10)0 <= start_index_map < rank(operand)
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes), gdzie:
    • batch_dim_sizes = shape(start_indices) z wyjątkiem rozmiaru wymiaru start_indices odpowiadającego index_vector_dim, który nie jest uwzględniony.
    • offset_dim_sizes = shape(slice_sizes) z wyjątkiem rozmiarów wymiarów w slice_sizes odpowiadających collapsed_slice_dims, które nie są uwzględnione.
    • combine umieszcza batch_dim_sizes na osiach odpowiadających batch_dims, a offset_dim_sizes na osiach odpowiadających offset_dims.
  • (C14) element_type(operand) = element_type(result).

Przykłady

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stable<hlo.gather
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vect>or_dim = 2,
  indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Więcej przykładów

dynamic_iota

Semantyka

Ta operacja jest funkcjonalnie identyczna z operacją iota, ale kształt wyniku jest określany dynamicznie za pomocą output_shape.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) output_shape Tensor 1-wymiarowy typu całkowitego (C1), (C2)
(I2) iota_dimension si64 (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C2)

Ograniczenia

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Przykłady

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

Więcej przykładów

dynamic_pad

Semantyka

Ta operacja jest funkcjonalnie identyczna z operacją pad, ale wartości edge_padding_low, edge_padding_highinterior_padding są określane dynamicznie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1), (C2), (C4)
(I2) padding_value Tensor 0-wymiarowy lub tensor skwantyzowany na poziomie tensora (C1)
(I3) edge_padding_low Tensor 1-wymiarowy typu całkowitego (C1), (C4)
(I4) edge_padding_high Tensor 1-wymiarowy typu całkowitego (C1), (C4)
(I5) interior_padding Tensor 1-wymiarowy typu całkowitego (C2-C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C3-C6)

Ograniczenia

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Przykłady

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Więcej przykładów

dynamic_reshape

Semantyka

Ta operacja jest funkcjonalnie identyczna z operacją reshape, ale kształt wyniku jest określany dynamicznie za pomocą output_shape.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantyzowany, (C1-C3)
(I2) output_shape Tensor 1-wymiarowy typu całkowitego (C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C1-C4)

Ograniczenia

  • (C1) element_type(result) jest określone wzorem:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand) z wyjątkiem tego, że quantization_dimension(operand)quantization_dimension(result) mogą się różnić.
  • (C2) size(operand) = size(result).
  • (C3) Jeśli is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Przykłady

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]

Więcej przykładów

dynamic_slice

Semantyka

Wyodrębnia wycinek z tensora operand przy użyciu dynamicznie obliczanych indeksów początkowych i tworzy tensor result. start_indices zawierają indeksy początkowe wycinka dla każdego wymiaru, które mogą ulec zmianie, a slice_sizes zawierają rozmiary wycinka dla każdego wymiaru. Bardziej formalnie: result[result_index] = operand[operand_index] gdzie:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1), (C2), (C4)
(I2) start_indices zmienna liczba tensorów 0-wymiarowych typu całkowitego (C2), (C3)
(I3) slice_sizes 1-wymiarowa stała tensora typu si64 (C2), (C4), (C5)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1), (C5)

Ograniczenia

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Przykłady

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Więcej przykładów

dynamic_update_slice

Semantyka

Zwraca tensor result, który jest równy tensorowi operand, z wyjątkiem tego, że wycinek rozpoczynający się od start_indices jest aktualizowany za pomocą wartości z tensora update. Bardziej formalnie result[result_index] jest zdefiniowane jako:

  • update[update_index] jeśli 0 <= update_index < shape(update), gdzie:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • operand[result_index] w przeciwnym razie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1-C4), (C6)
(I2) update tensor lub tensor skwantowany na poziomie tensora; (C2), (C3), (C6)
(I3) start_indices zmienna liczba tensorów 0-wymiarowych typu całkowitego (C4), (C5)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1)

Ograniczenia

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Przykłady

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
 < : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Więcej przykładów

wykładniczo

Semantyka

Wykonuje operację wykładniczą na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: exp z IEEE-754.
  • W przypadku liczb zespolonych: wykładnik zespolony.
  • W przypadku typów skwantowanych:dequantize_op_quantize(exponential, operand, type(result))

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Więcej przykładów

exponential_minus_one

Semantyka

Wykonuje operację wykładniczą minus jeden na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: expm1 z IEEE-754.
  • W przypadku liczb zespolonych: złożona funkcja wykładnicza minus jeden.
  • W przypadku typów skwantyzowanych:dequantize_op_quantize(exponential_minus_one, operand, type(result))

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]

Więcej przykładów

fft

Semantyka

Wykonuje transformaty Fouriera w przód i w tył dla rzeczywistych i zespolonych danych wejściowych/wyjściowych.

fft_type to jedna z tych wartości:

  • FFT: FFT z liczb zespolonych na liczby zespolone.
  • IFFT: Odwrotna złożona transformata FFT.
  • RFFT: przekazywanie rzeczywistej do złożonej szybkiej transformaty Fouriera.
  • IRFFT: odwrotna transformata FFT z liczb rzeczywistych na zespolone (pobiera liczby zespolone, zwraca rzeczywiste).

Bardziej formalnie, jeśli mamy funkcję fft, która przyjmuje jako dane wejściowe 1-wymiarowe tensory typów złożonych, generuje jako dane wyjściowe 1-wymiarowe tensory tych samych typów i oblicza dyskretną transformatę Fouriera:

W przypadku fft_type = FFT wartość result jest zdefiniowana jako ostateczny wynik serii obliczeń L, gdzie L = size(fft_length). Na przykład w przypadku L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Ponadto, biorąc pod uwagę funkcję ifft, która ma ten sam typ sygnatury i oblicza odwrotność funkcji fft:

W przypadku fft_type = IFFT wartość result jest odwrotnością obliczeń dla fft_type = FFT. Na przykład w przypadku L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Ponadto funkcja rfft, która przyjmuje 1-wymiarowe tensory typów zmiennoprzecinkowych, tworzy 1-wymiarowe tensory typów złożonych o tej samej semantyce zmiennoprzecinkowej i działa w ten sposób:

  • rfft(real_operand) = truncated_result, gdzie
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Gdy dyskretna transformata Fouriera jest obliczana dla operandów rzeczywistych, pierwsze N/2 + 1 elementów wyniku jednoznacznie definiuje resztę wyniku, więc wynik rfft jest obcinany, aby uniknąć obliczania zbędnych elementów).

W przypadku fft_type = RFFT wartość result jest zdefiniowana jako ostateczny wynik serii obliczeń L, gdzie L = size(fft_length). Na przykład w przypadku L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Na koniec, biorąc pod uwagę funkcję irfft, która ma taką samą sygnaturę typu i oblicza odwrotność funkcji rfft:

W przypadku fft_type = IRFFT wartość result jest odwrotnością obliczeń dla fft_type = RFFT. Na przykład w przypadku L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego (C1), (C2), (C4), (C5)
(I2) fft_type wyliczenie FFT, IFFT, RFFT i IRFFT (C2), (C5)
(I3) fft_length 1-wymiarowa stała tensora typu si64 (C1), (C3), (C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego (C2), (C4), (C5)

Ograniczenia

  • (C1) size(fft_length) <= rank(operand).
  • (C2) Relacja między typami elementów operandresult jest różna:
    • Jeśli fft_type = FFT, element_type(operand)element_type(result) mają ten sam typ złożony.
    • Jeśli fft_type = IFFT, element_type(operand)element_type(result) mają ten sam typ złożony.
    • Jeśli fft_type = RFFT, element_type(operand) jest typem zmiennoprzecinkowym, a element_type(result) jest typem złożonym o takiej samej semantyce zmiennoprzecinkowej.
    • Jeśli fft_type = IRFFT, element_type(operand) jest typem złożonym, a element_type(result) jest typem zmiennoprzecinkowym o takiej samej semantyce zmiennoprzecinkowej.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Jeśli wśród tensorów operandresult znajduje się tensor real typu zmiennoprzecinkowego, to shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) z wyjątkiem:
    • Jeśli fft_type = RFFT,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Jeśli fft_type = IRFFT,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Przykłady

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = <#stablehloff>t_type FFT,
  fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

piętro

Semantyka

Wykonuje operację zaokrąglania w dół na każdym elemencie tensora operand i tworzy tensor result. Implementuje operację roundToIntegralTowardNegative ze specyfikacji IEEE-754. W przypadku typów skwantyzowanych wykonuje działanie dequantize_op_quantize(floor, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Więcej przykładów

zbierać,

Semantyka

Zbiera wycinki z tensora operand na podstawie przesunięć określonych w tensorze start_indices i tworzy tensor result.

Ten diagram pokazuje, jak elementy w result są mapowane na elementy w operand na konkretnym przykładzie. Na diagramie wybrano kilka przykładowych resultindeksów i szczegółowo wyjaśniono, którym operandindeksom odpowiadają.

zbierać,

Bardziej formalnie: result[result_index] = operand[operand_index], gdzie:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index jest zdefiniowane jako:
    • start_indices[bi0, ..., :, ..., biN], gdzie bi to poszczególne elementy w batch_index, a : jest wstawiany w indeksie index_vector_dim, jeśli index_vector_dimrank(start_indices).
    • [start_indices[batch_index]] w przeciwnym razie.
  • Dla d_operandaxes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) jeśli d_operand = start_index_map[d_start].
    • full_start_index[d_operand] = 0 w przeciwnym razie.
  • Dla d_operandaxes(operand),
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] jeśli d_operand = operand_batching_dims[i_batching]d_start = start_indices_batching_dims[i_batching].
    • full_batching_index[d_operand] = 0 w przeciwnym razie.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], gdzie oi to poszczególne elementy w offset_index, a 0 jest wstawiany w indeksach od collapsed_slice_dims do operand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

Jeśli indices_are_sorted jest równe true, implementacja może założyć, że start_indices są posortowane względem start_index_map. W przeciwnym razie działanie jest niezdefiniowane. Bardziej formalnie: dla wszystkich i1 < i2 z zakresu indices(result), full_start_index(i1) <= full_start_index(i2).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices tensor typu całkowitego (C2-C3), (C14), (C17), (C22)
(I3) offset_dims 1-wymiarowa stała tensora typu si64 (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims 1-wymiarowa stała tensora typu si64 (C1), (C6-C9), (C22)
(I5) operand_batching_dims 1-wymiarowa stała tensora typu si64 (C1), (C6), (C10-C12), (C16-C18), (C22)
(I6) start_indices_batching_dims 1-wymiarowa stała tensora typu si64 (C13-C17)
(I7) start_index_map 1-wymiarowa stała tensora typu si64 (C3), (C18-C19)
(I8) index_vector_dim stała typu si64 (C2-C3), (C15), (C22)
(I9) slice_sizes 1-wymiarowa stała tensora typu si64 (C9), (C12), (C20-C22)
(I10) indices_are_sorted stała typu i1

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C5), (C22-C23)

Ograniczenia

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9)slice_sizes[collapsed_slice_dims...] <= 1
  • (C10)is_sorted(operand_batching_dims)
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims))
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand)
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes), gdzie:
    • batch_dim_sizes = shape(start_indices) z wyjątkiem rozmiaru wymiaru start_indices odpowiadającego index_vector_dim, który nie jest uwzględniony.
    • offset_dim_sizes = slice_sizes z wyjątkiem rozmiarów wymiarów w slice_sizes odpowiadających collapsed_slice_dimsoperand_batching_dims.
    • combine umieszcza batch_dim_sizes na osiach odpowiadających batch_dims, a offset_dim_sizes na osiach odpowiadających offset_dims.
  • (C23) element_type(operand) = element_type(result).

Przykłady

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stable<hlo.gather
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vect>or_dim = 3,
  slice_siz<es = arrayi64: >1, 1, 2, 2,
  indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

Więcej przykładów

get_dimension_size

Semantyka

Zwraca rozmiar podanego dimension w operand. Bardziej formalnie:result = dim(operand, dimension). Semantyka dotyczy tylko komponentu kształtu typu. Typ elementu może być dowolny.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantyzowany, (C1)
(I2) dimension stała typu si64 (C1)

Wyniki

Nazwa Typ
result Tensor 0-wymiarowy typu si32

Ograniczenia

  • (C1) 0 <= dimension < rank(operand).

Przykłady

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3

Więcej przykładów

get_tuple_element

Semantyka

Wyodrębnia element na pozycji index w krotce operand i tworzy result. Bardziej formalnie: result = operand[index].

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand krotka (C1), (C2)
(I2) index stała typu si32 (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result dowolna wartość (C2)

Ograniczenia

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Przykłady

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]

Więcej przykładów

jeśli

Semantyka

Zwraca wynik wykonania dokładnie 1 funkcji z true_branch lub false_branch w zależności od wartości pred. Bardziej formalnie: result = pred ? true_branch() : false_branch().

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) pred Tensor 0-wymiarowy typu i1
(I2) true_branch funkcja (C1-C3)
(I3) false_branch funkcja (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, skwantowanych tensorów lub tokenów; (C3)

Ograniczenia

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Przykłady

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
  "stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10

Więcej przykładów

imag

Semantyka

Wyodrębnia część urojoną z każdego elementu tensora operand i tworzy tensor result. Bardziej formalnie: dla każdego elementu x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego (C1), (C2)

Ograniczenia

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) jest zdefiniowane jako:
    • complex_element_type(element_type(operand)) jeżeli is_complex(operand).
    • element_type(operand) w przeciwnym razie.

Przykłady

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]

Więcej przykładów

infeed

Semantyka

Odczytuje dane z pliku danych i generuje results.

Semantyka infeed_config jest zdefiniowana przez implementację.

results składają się z wartości ładunku, które są na początku, oraz tokena, który jest na końcu. W przyszłości planujemy podzielić ładunek i token na 2 osobne dane wyjściowe, aby zwiększyć przejrzystość (#670).

Wejścia

Etykieta Nazwa Typ
(I1) token token
(I2) infeed_config stała typu string

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, skwantowanych tensorów lub tokenów; (C1-C3)

Ograniczenia

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) lub is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Przykłady

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Więcej przykładów

iota

Semantyka

Wypełnia tensor output wartościami w kolejności rosnącej, zaczynając od zera, wzdłuż wymiaru iota_dimension. Bardziej formalnie:

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) iota_dimension si64 (C1)

Wyniki

Nazwa Typ Ograniczenia
output tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) 0 <= iota_dimension < rank(output).

Przykłady

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Więcej przykładów

is_finite

Semantyka

Sprawdza element po elemencie, czy wartość w x jest skończona (tzn. nie jest ani +Inf, ani -Inf, ani NaN), i tworzy tensor y. Implementuje operację isFinite z specyfikacji IEEE-754. W przypadku typów skwantyzowanych wynik jest zawsze true.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) x tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
y tensor typu logicznego (C1)

Ograniczenia

  • (C1) shape(x) = shape(y).

Przykłady

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]

Więcej przykładów

log

Semantyka

Wykonuje operację logarytmiczną na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: log z IEEE-754.
  • W przypadku liczb zespolonych: logarytm zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(log, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Więcej przykładów

log_plus_one

Semantyka

Wykonuje operację logarytm plus jeden na każdym elemencie tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: logp1 z IEEE-754.
  • W przypadku liczb zespolonych:complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))
  • W przypadku typów skwantyzowanych:dequantize_op_quantize(log_plus_one, operand, type(result))

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Więcej przykładów

logistyczny

Semantyka

Wykonuje operację logistyczną na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: division(1, addition(1, exp(-x))) z IEEE-754.
  • W przypadku liczb zespolonych: złożona funkcja logistyczna.
  • W przypadku typów skwantyzowanych:dequantize_op_quantize(logistic, operand, type(result))

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Więcej przykładów

mapa

Semantyka

Stosuje funkcję mapowania computation do inputs wzdłuż osi dimensions i tworzy tensor result.

Bardziej formalnie: result[result_index] = computation(inputs...[result_index]).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1-C4)
(I2) dimensions 1-wymiarowa stała tensora typu si64 (C3)
(I3) computation funkcja (C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1), (C4)

Ograniczenia

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation jest typu (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, gdzie Ei = element_type(inputs[i]) i E' = element_type(result).

Przykłady

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
    stablehlo.return %<0 :> tensori64
}) {
  dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]

Więcej przykładów

maksimum

Semantyka

Wykonuje operację max na poszczególnych elementach tensorów lhsrhs oraz tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku wartości logicznych: operator logiczny LUB.
  • W przypadku liczb całkowitych: maksymalna liczba całkowita.
  • W przypadku liczb zmiennoprzecinkowych: maximum z IEEE-754.
  • W przypadku liczb zespolonych: maksimum leksykograficzne dla pary (real, imaginary). Wprowadzenie porządku w liczbach zespolonych wiąże się z nieoczekiwaną semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560).
  • W przypadku typów skwantyzowanych:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor skwantowany na poziomie tensora; (C1)
(I2) rhs tensor lub tensor skwantowany na poziomie tensora; (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]

Więcej przykładów

minimum

Semantyka

Wykonuje operację minimum na elementach tensorów lhsrhs oraz tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku wartości logicznych: operator logiczny AND.
  • W przypadku liczb całkowitych: minimalna liczba całkowita.
  • W przypadku liczb zmiennoprzecinkowych: minimum z IEEE-754.
  • W przypadku liczb zespolonych: minimum leksykograficzne dla pary (real, imaginary). Wprowadzenie porządku w liczbach zespolonych wiąże się z nieoczekiwaną semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560).
  • W przypadku typów skwantyzowanych:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor skwantowany na poziomie tensora; (C1)
(I2) rhs tensor lub tensor skwantowany na poziomie tensora; (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]

Więcej przykładów

pomnóż

Semantyka

Wykonuje iloczyn elementów 2 tensorów lhsrhs i generuje tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku wartości logicznych: operator logiczny AND.
  • W przypadku liczb całkowitych: mnożenie liczb całkowitych.
  • W przypadku liczb zmiennoprzecinkowych: multiplication z IEEE-754.
  • W przypadku liczb zespolonych: mnożenie liczb zespolonych.
  • W przypadku typów skwantyzowanych:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor skwantowany na poziomie tensora; (C1)
(I2) rhs tensor lub tensor skwantowany na poziomie tensora; (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]

Więcej przykładów

negować,

Semantyka

Wykonuje negację elementów tensora operand i generuje tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb całkowitych ze znakiem: negacja liczby całkowitej.
  • W przypadku liczb całkowitych bez znaku: rzutowanie bitowe na liczbę całkowitą ze znakiem, negacja liczby całkowitej, rzutowanie bitowe z powrotem na liczbę całkowitą bez znaku.
  • W przypadku liczb zmiennoprzecinkowych: negate z IEEE-754.
  • W przypadku liczb zespolonych: negacja zespolona.
  • W przypadku typów skwantowanych:dequantize_op_quantize(negate, operand, type(result))

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]

Więcej przykładów

nie

Semantyka

Wykonuje operację NOT na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku wartości logicznych: operator logiczny NOT.
  • W przypadku liczb całkowitych: bitowe NIE.

Argumenty

Nazwa Typ Ograniczenia
operand tensor typu logicznego lub całkowitego, (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu logicznego lub całkowitego, (C1)

Ograniczenia

  • (C1) type(operand) = type(result).

Przykłady

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]

Więcej przykładów

optimization_barrier

Semantyka

Zapewnia, że operacje, które generują operand, są wykonywane przed operacjami zależnymi od result, i zapobiega przenoszeniu operacji przez barierę przez transformacje kompilatora. Poza tym operacja jest tożsamością, czyli result = operand.

Argumenty

Nazwa Typ Ograniczenia
operand zmienna liczba tensorów, skwantowane tensory lub tokeny dla każdego tensora; (C1)

Wyniki

Nazwa Typ Ograniczenia
result zmienna liczba tensorów, skwantowane tensory lub tokeny dla każdego tensora; (C1)

Ograniczenia

  • (C1) type(operand...) = type(result...).

Przykłady

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0

Więcej przykładów

lub

Semantyka

Wykonuje operację OR na poszczególnych elementach 2 tensorów lhsrhs oraz tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku wartości logicznych: operator logiczny LUB.
  • W przypadku liczb całkowitych: operacja bitowa LUB.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego lub logicznego (C1)
(I2) rhs tensor typu całkowitego lub logicznego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego lub logicznego (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]

Więcej przykładów

wyjście

Semantyka

Zapisuje inputs w pliku wyjściowym i generuje token result.

Semantyka outfeed_config jest zdefiniowana przez implementację.

Wejścia

Etykieta Nazwa Typ
(I1) inputs zmienna liczba tensorów lub tensorów skwantowanych,
(I2) token token
(I3) outfeed_config stała typu string

Wyniki

Nazwa Typ
result token

Przykłady

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token

Więcej przykładów

podkładka,

Semantyka

Rozszerza tensor operand, dodając dopełnienie wokół tensora oraz między jego elementami o podaną wartość padding_value.

Wartości edge_padding_lowedge_padding_high określają ilość dopełnienia dodanego odpowiednio na początku (obok indeksu 0) i na końcu (obok najwyższego indeksu) każdego wymiaru. Wartość dopełnienia może być ujemna. Wartość bezwzględna ujemnego dopełnienia wskazuje liczbę elementów do usunięcia z określonego wymiaru.

interior_padding określa ilość dopełnienia dodanego między dowolnymi 2 elementami w każdym wymiarze, która nie może być ujemna. Dopełnienie wewnętrzne występuje przed dopełnieniem krawędzi, więc ujemne dopełnienie krawędzi usunie elementy z operandu z dopełnieniem wewnętrznym.

Bardziej formalnie result[result_index] jest zdefiniowane jako:

  • operand[operand_index] if result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • padding_value w przeciwnym razie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1), (C2), (C4)
(I2) padding_value Tensor 0-wymiarowy lub tensor skwantyzowany na poziomie tensora (C1)
(I3) edge_padding_low 1-wymiarowa stała tensora typu si64 (C1), (C4)
(I4) edge_padding_high 1-wymiarowa stała tensora typu si64 (C1), (C4)
(I5) interior_padding 1-wymiarowa stała tensora typu si64 (C2-C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C3-C6)

Ograniczenia

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Przykłady

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_l<ow = arra>yi64: 0, 1,
  edge_padding_hi<gh = arra>yi64: 2, 1,
  interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Więcej przykładów

partition_id

Semantyka

Generuje partition_id bieżącego procesu.

Wyniki

Nazwa Typ
result Tensor 0-wymiarowy typu ui32

Przykłady

%result = "stablehlo.partition_id">;() : (<) - >tensorui32

Więcej przykładów

popcnt

Semantyka

Wykonuje zliczanie bitów ustawionych w tensorze operand i tworzy tensor result.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu całkowitego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego (C1)

Ograniczenia

  • (C1) type(operand) = type(result).

Przykłady

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]

Więcej przykładów

moc

Semantyka

Wykonuje potęgowanie elementów tensora lhs przez tensor rhs i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb całkowitych: potęgowanie liczb całkowitych.
  • W przypadku liczb zmiennoprzecinkowych: pow z IEEE-754.
  • W przypadku liczb zespolonych: potęgowanie liczb zespolonych.
  • W przypadku typów skwantowanych: dequantize_op_quantize(power, lhs, rhs, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)
(I2) rhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Więcej przykładów

rzeczywiste,

Semantyka

Wyodrębnia część rzeczywistą z każdego elementu tensora operand i tworzy tensor result. Bardziej formalnie: dla każdego elementu x: real(x) = is_complex(x) ? real_part(x) : x.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego (C1), (C2)

Ograniczenia

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) jest zdefiniowane jako:
    • complex_element_type(element_type(operand)) jeżeli is_complex(operand).
    • element_type(operand) w przeciwnym razie.

Przykłady

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]

Więcej przykładów

recv

Semantyka

Otrzymuje dane z kanału z channel_id i generuje results.

Jeśli is_host_transfer ma wartość true, operacja przenosi dane z hosta. W przeciwnym razie przenosi dane z innego urządzenia na podstawie wartości parametru source_target_pairs. Ten flag powiela informacje podane w channel_type, dlatego w przyszłości planujemy zachować tylko jeden z nich (#666). Jeśli is_host_transfer = false, a wartość source_target_pairs to None lub pusta wartość, jest to uznawane za niezdefiniowane działanie.

results składają się z wartości ładunku, które są na początku, oraz tokena, który jest na końcu. W przyszłości planujemy podzielić ładunek i token na 2 osobne dane wyjściowe, aby zwiększyć przejrzystość (#670).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) token token
(I2) channel_id stała typu si64
(I3) channel_type wyliczenie DEVICE_TO_DEVICE i DEVICE_TO_HOST (C5)
(I4) is_host_transfer stała typu i1 (C5-C6)
(I5) source_target_pairs Stała tensora dwuwymiarowego typu si64 (C1-C4), (C6)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, skwantowanych tensorów lub tokenów; (C2-C4)

Ograniczenia

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, gdzie N jest zdefiniowane jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_partitions, jeśli używana jest właściwość cross_partition.
  • (C5) channel_type jest zdefiniowane jako:
    • DEVICE_TO_HOST jeżeli is_host_transfer = true,
    • DEVICE_TO_DEVICE w przeciwnym razie.

Przykłady

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
  is_host_transfer = false,
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)

Więcej przykładów

zmniejszyć

Semantyka

Stosuje funkcję redukcji body do tensorów inputsinit_values wzdłuż osi dimensions i tworzy tensory results.

Kolejność redukcji jest zdefiniowana przez implementację, co oznacza, że operacje bodyinit_values muszą tworzyć monoid, aby zagwarantować, że operacja daje takie same wyniki dla wszystkich danych wejściowych we wszystkich implementacjach. Jednak w przypadku wielu popularnych redukcji ten warunek nie jest spełniony. Np. dodawanie liczb zmiennoprzecinkowych dla body i zero dla init_values nie tworzą monoidu, ponieważ dodawanie liczb zmiennoprzecinkowych nie jest łączne.

Bardziej formalnie: results...[j0, ..., jR-1] = reduce(input_slices_converted), gdzie:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], gdzie : są wstawiane w miejscu dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) dla pewnego drzewa binarnegoschedule, gdzie:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule to zdefiniowane przez implementację pełne drzewo binarne, którego przejście w porządku infiksowym składa się z:
    • wartości input_slices_converted...[index] dla wszystkich indexindex_space(input_slices_converted) w rosnącym porządku leksykograficznym index.
    • W miejscach określonych przez implementację wstawiono zdefiniowaną przez nią liczbę znaków init_values_converted.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1-C4), (C6), (C7)
(I2) init_values zmienna liczba tensorów 0-wymiarowych lub tensorów skwantowanych na poziomie tensora (C2), (C3)
(I3) dimensions 1-wymiarowa stała tensora typu si64 (C4), (C5), (C7)
(I4) body funkcja (C6)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C3), (C7), (C8)

Ograniczenia

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body jest typu (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), gdzie is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...) z wyjątkiem tego, że nie są uwzględniane rozmiary wymiarów inputs... odpowiadające dimensions.
  • (C8) element_type(results[i]) = Ei dla wszystkich i[0,N).

Przykłady

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
  dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]

Więcej przykładów

reduce_precision

Semantyka

Wykonuje konwersję elementów tensora operand na inny typ zmiennoprzecinkowy, który używa wartości exponent_bitsmantissa_bits, a następnie z powrotem na pierwotny typ zmiennoprzecinkowy i tworzy tensor output.

Bardziej formalnie:

  • Bity mantysy pierwotnej wartości są aktualizowane w celu zaokrąglenia pierwotnej wartości do najbliższej wartości, którą można przedstawić za pomocą mantissa_bits, zgodnie z semantyką roundToIntegralTiesToEven.
  • Jeśli mantissa_bits jest mniejsza niż liczba bitów mantysy pierwotnej wartości, bity mantysy są obcinane do mantissa_bits.
  • Jeśli bity wykładnika wyniku pośredniego nie mieszczą się w zakresie podanym przez exponent_bits, wynik pośredni przekracza zakres i staje się nieskończonością z pierwotnym znakiem lub spada poniżej zakresu i staje się zerem z pierwotnym znakiem.
  • W przypadku typów skwantyzowanych wykonuje operację dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)
(I2) exponent_bits stała typu si32 (C2)
(I3) mantissa_bits stała typu si32 (C3)

Wyniki

Nazwa Typ Ograniczenia
output tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Przykłady

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Więcej przykładów

reduce_scatter

Semantyka

reduce_scatter

W każdej grupie procesów w siatce procesów StableHLO wykonuje redukcję za pomocą funkcji computations na wartościach tensora operand z każdego procesu, dzieli wynik redukcji wzdłuż osi scatter_dimension na części i rozprasza podzielone części między procesy, aby uzyskać tensor result.

Operacja dzieli siatkę procesów StableHLO na process_groups, co jest zdefiniowane w ten sposób:

  • cross_replica(replica_groups) jeśli channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) jeśli channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) jeśli channel_id > 0 and use_global_device_ids = true.

Następnie w każdym process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] dla wszystkich senderprocess_group, gdzie receiver_index = process_group.index(receiver).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1), (C2), (C7), (C8)
(I2) scatter_dimension stała typu si64 (C1), (C2), (C8)
(I3) replica_groups Stała tensora dwuwymiarowego typu si64 (C3-C5)
(I4) channel_id stała typu si64 (C6)
(I5) use_global_device_ids stała typu i1 (C6)
(I6) computation funkcja (C7)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C8-C9)

Ograniczenia

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) jest zdefiniowane jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_replicas, jeśli używana jest właściwość cross_replica_and_partition.
    • num_processes, jeśli używana jest właściwość flattened_ids.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Jeśli use_global_device_ids = true, to channel_id > 0.
  • (C7) computation ma typ (tensor<E>, tensor<E>) -> (tensor<E>), gdzie is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) z wyjątkiem:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9)element_type(result) = E

Przykłady

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
  %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
  "stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
  scatter_dimension = 1 :< i64,
  >replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
  channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Więcej przykładów

reduce_window

Semantyka

Stosuje funkcję redukcji body do okien o rozmiarach inputsinit_values i generuje results.

Na diagramie poniżej pokazujemy, jak elementy w results... są obliczane na podstawie inputs... na konkretnym przykładzie.

reduce_window

Bardziej formalnie:results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (patrz reduce), gdzie:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values zmienna liczba tensorów 0-wymiarowych lub tensorów skwantowanych na poziomie tensora (C1), (C13)
(I3) window_dimensions 1-wymiarowa stała tensora typu si64 (C4), (C5), (C15)
(I4) window_strides 1-wymiarowa stała tensora typu si64 (C6), (C7), (C15)
(I5) base_dilations 1-wymiarowa stała tensora typu si64 (C8), (C9), (C15)
(I6) window_dilations 1-wymiarowa stała tensora typu si64 (C10), (C11), (C15)
(I7) padding Stała tensora dwuwymiarowego typu si64 (C12), (C15)
(I8) body funkcja (C13)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1), (C14-C16)

Ograniczenia

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9)0 < base_dilations
  • (C10)size(window_dilations) = rank(inputs[0])
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body ma typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), gdzie is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows, gdzie:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei za wszystkie i w sklepie [0,N).

Przykłady

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
  wind>ow_dimensions = arrayi64: <2, 1,
  w>indow_strides = arrayi64: <4, 1,
  b>ase_dilations = arrayi64: 2,< 1,
  win>dow_dilations = arr<ayi64: 3, 1,
  p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]

Więcej przykładów

reszta

Semantyka

Wykonuje operację reszty z dzielenia na tensorach dzielnej lhs i dzielnika rhs oraz tworzy tensor result.

Formalnie znak wyniku jest pobierany z dzielnej, a wartość bezwzględna wyniku jest zawsze mniejsza niż wartość bezwzględna dzielnika. Reszta jest obliczana jako lhs - d * rhs, gdzie d jest określone wzorem:

  • W przypadku liczb całkowitych: stablehlo.divide(lhs, rhs).
  • W przypadku liczb zmiennoprzecinkowych: division(lhs, rhs) z IEEE-754 z atrybutem zaokrąglania roundTowardZero.
  • W przypadku liczb zespolonych: TBD (#997).
  • W przypadku typów skwantyzowanych:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

W przypadku typów elementów zmiennoprzecinkowych ta operacja jest przeciwieństwem operacji remainder ze specyfikacji IEEE-754, gdzie d jest wartością całkowitą najbliższą dokładnej wartości lhs/rhs, a w przypadku remisu wybierana jest liczba parzysta.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)
(I2) rhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]

Więcej przykładów

replica_id

Semantyka

Generuje replica_id bieżącego procesu.

Wyniki

Nazwa Typ
result Tensor 0-wymiarowy typu ui32

Przykłady

%result = "stablehlo.replica_id">;() : (<) - >tensorui32

Więcej przykładów

przekształcać,

Semantyka

Zmienia kształt tensora operand na tensor result. W praktyce oznacza to zachowanie tej samej reprezentacji kanonicznej, ale potencjalną zmianę kształtu, np. z tensor<2x3xf32> na tensor<3x2xf32> lub tensor<6xf32>.

Bardziej formalnie: result[result_index] = operand[operand_index], gdzie result_indexoperand_index mają tę samą pozycję w porządku leksykograficznym index_space(result)index_space(operand).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantyzowany, (C1-C3)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C1-C3)

Ograniczenia

  • (C1) element_type(result) jest określone wzorem:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand) z wyjątkiem tego, że quantization_dimension(operand)quantization_dimension(result) mogą się różnić.
  • (C2) size(operand) = size(result).
  • (C3) Jeśli is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Przykłady

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]

Więcej przykładów

odwróć

Semantyka

Odwraca kolejność elementów w operand wzdłuż określonego dimensions i tworzy tensor result. Bardziej formalnie: result[result_index] = operand[operand_index] gdzie:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 jeśli d w dimensions.
  • operand_index[d] = result_index[d] w przeciwnym razie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1), (C3)
(I2) dimensions 1-wymiarowa stała tensora typu si64 (C2), (C3)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1), (C3)

Ograniczenia

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Przykłady

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]

Więcej przykładów

rng

Semantyka

Generuje liczby losowe za pomocą algorytmu rng_distribution i tworzy tensor result o danym kształcie shape.

Jeśli rng_distribution = UNIFORM, liczby losowe są generowane zgodnie z rozkładem jednostajnym w przedziale [a, b). Jeśli a >= b, zachowanie jest nieokreślone.

Jeśli rng_distribution = NORMAL, liczby losowe są generowane zgodnie z rozkładem normalnym o średniej = a i odchyleniu standardowym = b. Jeśli b < 0, zachowanie jest nieokreślone.

Dokładny sposób generowania liczb losowych jest zdefiniowany w implementacji. Na przykład mogą być deterministyczne lub nie, a także mogą korzystać ze stanu ukrytego lub nie.

W rozmowach z wieloma zainteresowanymi stronami ta operacja została uznana za w zasadzie wycofaną, więc w przyszłości planujemy ją usunąć (#597).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) a Tensor 0-wymiarowy typu całkowitego, logicznego lub zmiennoprzecinkowego (C1), (C2)
(I2) b Tensor 0-wymiarowy typu całkowitego, logicznego lub zmiennoprzecinkowego (C1), (C2)
(I3) shape 1-wymiarowa stała tensora typu si64 (C3)
(I4) rng_distribution wyliczenie UNIFORM i NORMAL (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, logicznego lub zmiennoprzecinkowego; (C1-C3)

Ograniczenia

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Jeśli rng_distribution = NORMAL, to is_float(a).
  • (C3) shape(result) = shape.

Przykłady

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semantyka

Zwraca ciąg output wypełniony jednorodnymi losowymi bitami i zaktualizowany stan wyjściowy output_state przy użyciu algorytmu generatora liczb pseudolosowych rng_algorithm na podstawie stanu początkowego initial_state. Wynik jest deterministyczną funkcją initial_state, ale nie musi być deterministyczny w przypadku różnych implementacji.

rng_algorithm to jedna z tych wartości:

  • DEFAULT: Algorytm zdefiniowany przez implementację.
  • THREE_FRY: zdefiniowana przez implementację odmiana algorytmu Threefry*.
  • PHILOX: zdefiniowana przez implementację odmiana algorytmu Philox*.

* Zobacz: Salmon et al. SC 2011. Równoległe liczby losowe: to proste jak 1, 2, 3.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) rng_algorithm enum: DEFAULT, THREE_FRYPHILOX (C2)
(I2) initial_state 1-wymiarowy tensor typu ui64 (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
output_state 1-wymiarowy tensor typu ui64 (C1)
output tensor typu całkowitego lub zmiennoprzecinkowego

Ograniczenia

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) jest zdefiniowane jako:
    • zdefiniowane przez implementację, jeśli rng_algorithm = DEFAULT.
    • 2 jeżeli rng_algorithm = THREE_FRY.
    • 2 lub 3, jeśli rng_algorithm = PHILOX.

Przykłady

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semantyka

Wykonuje zaokrąglanie elementów tensora operand do najbliższej liczby całkowitej, w przypadku remisu zaokrąglając w kierunku zera, i tworzy tensor result. Implementuje operację roundToIntegralTiesToAway ze specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Więcej przykładów

round_nearest_even

Semantyka

Wykonuje zaokrąglanie elementów tensora operand do najbliższej liczby całkowitej, a w przypadku remisu zaokrągla do najbliższej parzystej liczby całkowitej, i tworzy tensor result. Implementuje operację roundToIntegralTiesToEven ze specyfikacji IEEE-754. W przypadku typów skwantyzowanych wykonuje działanie dequantize_op_quantize(round_nearest_even, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor skwantyzowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Więcej przykładów

rsqrt

Semantyka

Wykonuje operację odwrotnego pierwiastka kwadratowego na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: rSqrt z IEEE-754.
  • W przypadku liczb zespolonych: złożony odwrotny pierwiastek kwadratowy.
  • W przypadku typów skwantowanych: dequantize_op_quantize(rsqrt, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Więcej przykładów

rozrzucać,

Semantyka

Tworzy tensory results, które są równe tensorom inputs, z wyjątkiem tego, że kilka wycinków określonych przez scatter_indices jest aktualizowanych za pomocą wartości updates przy użyciu update_computation.

Ten diagram pokazuje, jak elementy w updates... są mapowane na elementy w results... na konkretnym przykładzie. Na diagramie wybrano kilka przykładowych updates...indeksów i szczegółowo wyjaśniono, którym results...indeksom odpowiadają.

rozrzucać,

Bardziej formalnie, dla wszystkich update_indexindex_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index jest zdefiniowane jako:
    • scatter_indices[si0, ..., :, ..., siN], gdzie si to poszczególne elementy w update_scatter_index, a : jest wstawiany w indeksie index_vector_dim, jeśli index_vector_dim < rank(scatter_indices).
    • [scatter_indices[update_scatter_index]] w przeciwnym razie.
  • Dla d_inputaxes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] if d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 w przeciwnym razie.
  • Dla d_inputaxes(inputs[0]),
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] jeśli d_input = input_batching_dims[i_batching]d_start = scatter_indices_batching_dims[i_batching].
    • full_batching_index[d_input] = 0 w przeciwnym razie.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], gdzie wi to poszczególne elementy w update_window_index, a 0 jest wstawiany w indeksach od inserted_window_dims do input_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

Biorąc pod uwagę, że results = exec(schedule, inputs), gdzie:

  • schedule to zdefiniowana przez implementację permutacja index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results) where:
    • Jeśli result_index mieści się w zakresie shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results to kopia elementu results, w której wartość results...[result_index] ustawiono na updated_values....
    • W innym przypadku
    • updated_results = results.
  • exec([], results) = results.

Jeśli indices_are_sorted ma wartość true, implementacja może założyć, że scatter_indices są posortowane względem scatter_dims_to_operand_dims. W przeciwnym razie działanie jest niezdefiniowane. Bardziej formalnie: dla wszystkich i1 < i2 z zakresu indices(result), full_start_index(i1) <= full_start_index(i2).

Jeśli unique_indices to true, implementacja może założyć, że wszystkie indeksy result_index, do których są rozpraszane dane, są niepowtarzalne. Jeśli unique_indices jest równe true, ale indeksy, do których są rozpraszane dane, nie są unikalne, działanie jest niezdefiniowane.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices tensor typu całkowitego (C4), (C15), (C19), (C22)
(I3) updates zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C3-C6), (C8)
(I4) update_window_dims 1-wymiarowa stała tensora typu si64 (C2), (C4), (C7-C8)
(I5) inserted_window_dims 1-wymiarowa stała tensora typu si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims 1-wymiarowa stała tensora typu si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims 1-wymiarowa stała tensora typu si64 (C14-C18)
(I8) scatter_dims_to_operand_dims 1-wymiarowa stała tensora typu si64 (C19-C21)
(I9) index_vector_dim stała typu si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted stała typu i1
(I11) unique_indices stała typu i1
(I12) update_computation funkcja (C23)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C24-C25)

Ograniczenia

  • (C1) same(shape(inputs...)).
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes), gdzie:
    • update_scatter_dim_sizes = shape(scatter_indices) z wyjątkiem tego, że rozmiar wymiaru scatter_indices odpowiadający index_vector_dim nie jest uwzględniony.
    • update_window_dim_sizes <= shape(inputs[0]) z wyjątkiem tego, że nie są uwzględniane rozmiary wymiarów w inputs[0] odpowiadające inserted_window_dimsinput_batching_dims.
    • combine umieszcza update_scatter_dim_sizes na osiach odpowiadających update_scatter_dims, a update_window_dim_sizes na osiach odpowiadających update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10)is_sorted(inserted_window_dims)
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation ma typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), gdzie is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei dla wszystkich i[0,N).

Przykłady

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//           ],
//           [
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//            [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
    "stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
  scatter_dimensio<n_numbers = #stablehlo.scatter
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2>, 1],
    index_vector_dim = 3,
  indices_are_sorted = false,
  uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

Więcej przykładów

wybierz

Semantyka

Tworzy tensor result, w którym każdy element jest wybierany z tensora on_true lub on_false na podstawie wartości odpowiedniego elementu tensora pred. Bardziej formalnie: result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], gdzie pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. W przypadku typów skwantyzowanych wykonuje działanie dequantize_select_quantize(pred, on_true, on_false, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) pred tensor typu i1 (C1)
(I2) on_true tensor lub tensor skwantowany na poziomie tensora; (C1-C2)
(I3) on_false tensor lub tensor skwantowany na poziomie tensora; (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C2)

Ograniczenia

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Przykłady

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]

Więcej przykładów

select_and_scatter

Semantyka

Rozprasza wartości z tensora source za pomocą tensora scatter na podstawie wyniku reduce_window tensora input za pomocą tensora select i tworzy tensor result.

Ten diagram pokazuje, jak elementy w result są obliczane na podstawie operandsource na konkretnym przykładzie.

select_and_scatter

Bardziej formalnie:

  • selected_values = reduce_window_without_init(...) z tymi danymi wejściowymi:

    • inputs = [operand].
    • window_dimensions, window_strides i padding, które są używane w niezmienionej formie.
    • base_dilations = windows_dilations = 1.
    • body jest zdefiniowane jako:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    gdzie E = element_type(operand) i reduce_window_without_init działają dokładnie tak samo jak reduce_window, z tym wyjątkiem, że schedule bazowego reduce (patrz reduce) nie zawiera wartości początkowych. Obecnie nie jest określone, co się stanie, jeśli odpowiednie okno nie będzie miało wartości (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter)gdzie:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index jeśli selected_values[source_index] zawiera element operandoperand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1-C4), (C6), (C8-C11)
(I2) source tensor lub tensor skwantowany na poziomie tensora; (C1), (C2)
(I3) init_value Tensor 0-wymiarowy lub tensor skwantyzowany na poziomie tensora (C3)
(I4) window_dimensions 1-wymiarowa stała tensora typu si64 (C2), (C4), (C5)
(I5) window_strides 1-wymiarowa stała tensora typu si64 (C2), (C6), (C7)
(I6) padding Stała tensora dwuwymiarowego typu si64 (C2), (C8)
(I7) select funkcja (C9)
(I8) scatter funkcja (C10)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C11-C12)

Ograniczenia

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows, gdzie:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select ma typ (tensor<E>, tensor<E>) -> tensor<i1>, gdzie E = element_type(operand).
  • (C10) scatter ma typ (tensor<E>, tensor<E>) -> tensor<E>, gdzie is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Przykłady

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_di<rection = #stablehlocom>parison_directio<n G>E
    } <: (>ten>sori64,< t>ensori64) - tensori1
    "stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
  ^bb0(%<arg>0: tensori64, %arg1: tensori64):
    %0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
  >  "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
  window_dim<ensions => arrayi64: 3, 1,
  <window_strides => arrayi64<: 2, 1,>
  padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Więcej przykładów

wyślij

Semantyka

Wysyła inputs na kanał channel_id. Dane wejściowe są następnie wysyłane na inne urządzenia w kolejności określonej przez source_target_pairs. Operacja generuje token result.

Jeśli is_host_transfer ma wartość true, operacja przesyła dane do hosta. W przeciwnym razie przenosi dane na inne urządzenie na podstawie wartości parametru source_target_pairs. Ten flag powiela informacje podane w channel_type, dlatego w przyszłości planujemy zachować tylko jeden z nich (#666). Jeśli is_host_transfer = false, a wartość source_target_pairs to None lub pusta wartość, jest to uznawane za niezdefiniowane działanie.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub tensorów skwantowanych,
(I2) token token
(I3) channel_id stała typu si64
(I4) channel_type wyliczenie DEVICE_TO_DEVICE i DEVICE_TO_HOST (C5)
(I5) is_host_transfer stała typu i1 (C5-C6)
(I6) source_target_pairs Stała tensora dwuwymiarowego typu si64 (C1-C4), (C6)

Wyniki

Nazwa Typ
result token

Ograniczenia

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, gdzie N jest zdefiniowane jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_partitions, jeśli używana jest właściwość cross_partition.
  • (C5) channel_type jest zdefiniowane jako:
    • DEVICE_TO_HOST jeżeli is_host_transfer = true,
    • DEVICE_TO_DEVICE w przeciwnym razie.

Przykłady

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
  is_host_transfer = false,
  source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token

Więcej przykładów

shift_left

Semantyka

Wykonuje operację przesunięcia w lewo na poszczególnych elementach tensora lhsrhs bitów i zwraca tensor result.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego (C1)
(I2) rhs tensor typu całkowitego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]

Więcej przykładów

shift_right_arithmetic

Semantyka

Wykonuje operację przesunięcia bitowego w prawo na poszczególnych elementach tensora lhs o rhs bitów i tworzy tensor result.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego (C1)
(I2) rhs tensor typu całkowitego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]

Więcej przykładów

shift_right_logical

Semantyka

Wykonuje operację logicznego przesunięcia w prawo na poszczególnych elementach tensora lhs o rhs bitów i tworzy tensor result.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego (C1)
(I2) rhs tensor typu całkowitego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]

Więcej przykładów

podpisywanie

Semantyka

Zwraca znak elementu operand i tworzy tensor result. Bardziej formalnie, w przypadku każdego elementu x semantyka może być wyrażona za pomocą składni Pythona w ten sposób:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

W przypadku typów skwantyzowanych wykonuje działanie dequantize_op_quantize(sign, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor ze znakiem liczby całkowitej, zmiennoprzecinkowej lub zespolonej albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor ze znakiem liczby całkowitej, zmiennoprzecinkowej lub zespolonej albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Więcej przykładów

sinus

Semantyka

Wykonuje operację sinusową na każdym elemencie tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: sin z IEEE-754.
  • W przypadku liczb zespolonych: sinus zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(sine, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]

Więcej przykładów

wycinek

Semantyka

Wyodrębnia wycinek z operand za pomocą statycznie obliczonych indeksów początkowych i tworzy tensor result. start_indices zawierają indeksy początkowe wycinka dla każdego wymiaru, limit_indices zawierają indeksy końcowe (wyłączone) wycinka dla każdego wymiaru, a strides zawierają kroki dla każdego wymiaru.

Bardziej formalnie: result[result_index] = operand[operand_index] gdzie operand_index = start_indices + result_index * strides.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantowany na poziomie tensora; (C1-C3), (C5)
(I2) start_indices 1-wymiarowa stała tensora typu si64 (C2), (C3), (C5)
(I3) limit_indices 1-wymiarowa stała tensora typu si64 (C2), (C3), (C5)
(I4) strides 1-wymiarowa stała tensora typu si64 (C2), (C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantowany na poziomie tensora; (C1), (C5)

Ograniczenia

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Przykłady

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indic<es = arra>yi64: 1, 2,
  limit_indic<es = arra>yi64: 3, 4,
  strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Więcej przykładów

sortuj

Semantyka

Sortuje 1-wymiarowe wycinki tensora inputs wzdłuż wymiaru dimension zgodnie z funkcją comparator i zwraca tensor results.

W przeciwieństwie do podobnych danych wejściowych w innych operacjach dimension dopuszcza wartości ujemne, których znaczenie opisano poniżej. W przyszłości może to być niedozwolone ze względu na spójność (#1377).

Jeśli wartość is_stable to „true”, sortowanie jest stabilne, tzn. względna kolejność elementów uznanych za równe przez komparator jest zachowywana. W przypadku pojedynczego wejścia 2 elementy e1e2 są uznawane za równe przez komparator wtedy i tylko wtedy, gdy comparator(e1, e2) = comparator(e2, e1) = false. Poniżej znajdziesz formalizację, która pokazuje, jak uogólnić to na wiele danych wejściowych.

Bardziej formalnie, dla wszystkich result_indexindex_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1], gdzie riN to poszczególne elementy w result_index, a : jest wstawiany w adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • gdzie sort sortuje 1-wymiarowy wycinek w kolejności niemalejącej, oczekując, że comparator_together zwróci true, jeśli argument po lewej stronie jest mniejszy niż argument po prawej stronie.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C1-C5)
(I2) dimension stała typu si64 (C4)
(I3) is_stable stała typu i1
(I4) comparator funkcja (C5)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub skwantowane tensory dla każdego tensora; (C2), (C3)

Ograniczenia

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, gdzie R = rank(inputs[0]).
  • (C5) comparator ma typ (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, gdzie Ei = element_type(inputs[i]).

Przykłady

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_di<rection = #stablehlocom>parison_directio<n G>T
    } <: (>ten>sori64,< t>ensori64) - tensori1
    "stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
  dimension = 0 : i64,
<  is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Więcej przykładów

sqrt

Semantyka

Wykonuje operację pierwiastka kwadratowego na każdym elemencie tensora operand i zwraca tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: squareRoot z IEEE-754.
  • W przypadku liczb zespolonych: pierwiastek kwadratowy z liczby zespolonej.
  • W przypadku typów skwantowanych: dequantize_op_quantize(sqrt, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]

Więcej przykładów

odejmować,

Semantyka

Wykonuje odejmowanie elementów dwóch tensorów lhsrhs i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb całkowitych: odejmowanie liczb całkowitych.
  • W przypadku liczb zmiennoprzecinkowych: subtraction z IEEE-754.
  • W przypadku liczb zespolonych: odejmowanie liczb zespolonych.
  • W przypadku typów skwantyzowanych:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)
(I2) rhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]

Więcej przykładów

tan

Semantyka

Wykonuje operację styczną na poszczególnych elementach tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: tan z IEEE-754.
  • W przypadku liczb zespolonych: tangens liczby zespolonej.
  • W przypadku typów skwantowanych: dequantize_op_quantize(tan, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

Więcej przykładów

tanh

Semantyka

Wykonuje operację tangensa hiperbolicznego na każdym elemencie tensora operand i tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku liczb zmiennoprzecinkowych: tanh z IEEE-754.
  • W przypadku liczb zespolonych: tangens hiperboliczny liczby zespolonej.
  • W przypadku typów skwantyzowanych:
    • dequantize_op_quantize(tanh, operand, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]

Więcej przykładów

transponować

Semantyka

Przestawia wymiary tensora operand za pomocą permutation i tworzy tensor result. Bardziej formalnie: result[result_index] = operand[operand_index] gdzie result_index[d] = operand_index[permutation[d]].

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor skwantyzowany, (C1-C4)
(I2) permutation 1-wymiarowa stała tensora typu si64 (C2-C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor skwantyzowany, (C1), (C3-C4)

Ograniczenia

  • (C1) element_type(result) jest określone wzorem:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand) z wyjątkiem tego, że quantization_dimension(operand)quantization_dimension(result) mogą się różnić.
  • (C2) permutation jest permutacją range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Jeśli is_per_axis_quantized(result), to quantization_dimension(operand) = permutation(quantization_dimension(result)).

Przykłady

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Więcej przykładów

triangular_solve

Semantyka

Rozwiązuje partie układów równań liniowych z macierzami współczynników w postaci trójkątnej dolnej lub górnej.

Bardziej formalnie, przy danych ab wartość result[i0, ..., iR-3, :, :] jest rozwiązaniem równania op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :], gdy left_side jest równe true, lub równania x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :], gdy left_side jest równe false. Rozwiązujemy równanie dla zmiennej x, gdzie wartość op(a) jest określana przez transpose_a, które może być jedną z tych wartości:

  • NO_TRANSPOSE: Wykonaj operację przy użyciu a w obecnej postaci.
  • TRANSPOSE: wykonaj operację na transponowanej macierzy a.
  • ADJOINT: wykonaj operację na transponowanej macierzy sprzężonej a.

Dane wejściowe są odczytywane tylko z dolnego trójkąta macierzy a, jeśli lower ma wartość true, lub z górnego trójkąta macierzy a w pozostałych przypadkach. Dane wyjściowe są zwracane w tym samym trójkącie. Wartości w drugim trójkącie są zdefiniowane przez implementację.

Jeśli wartość unit_diagonal to prawda, implementacja może założyć, że elementy diagonalne a są równe 1, w przeciwnym razie działanie jest niezdefiniowane.

W przypadku typów skwantyzowanych wykonuje działanie dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) a tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1-C3)
(I2) b tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1-C4)
(I3) left_side stała typu i1 (C3)
(I4) lower stała typu i1
(I5) unit_diagonal stała typu i1
(I6) transpose_a enum NO_TRANSPOSE, TRANSPOSEADJOINT

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego albo tensor skwantowany na poziomie tensora (C1)

Ograniczenia

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) Zależność między shape(a)shape(b) jest zdefiniowana w ten sposób:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Przykłady

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

krotka

Semantyka

Tworzy krotkę result z wartości val.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) val zmienna liczba wartości, (C1)

Wyniki

Nazwa Typ Ograniczenia
result krotka (C1)

Ograniczenia

  • (C1) result jest typu tuple<E0, ..., EN-1>, gdzie Ei = type(val[i]).

Przykłady

// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))

Więcej przykładów

uniform_dequantize

Semantyka

Przeprowadza konwersję tensora poddanego kwantyzacji operand na tensor zmiennoprzecinkowy result zgodnie z parametrami kwantyzacji zdefiniowanymi przez typ operand.

Bardziej formalnie: result = dequantize(operand).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand skwantowany tensor, (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego (C1), (C2)

Ograniczenia

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Przykłady

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]

uniform_quantize

Semantyka

Przeprowadza konwersję tensora zmiennoprzecinkowego lub tensora poddanego kwantyzacji operand na tensor poddany kwantyzacji result zgodnie z parametrami kwantyzacji zdefiniowanymi przez typ result.

Bardziej formalnie:

  • Jeśli is_float(operand):
    • result = quantize(operand, type(result)).
  • Jeśli is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub skwantyzowanego (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result skwantowany tensor, (C1), (C2)

Ograniczenia

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Przykłady

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]

podczas

Semantyka

Wykonuje body funkcję 0 lub więcej razy, dopóki cond funkcja nie zwróci wartości true. Bardziej formalnie semantykę można wyrazić za pomocą składni Pythona w ten sposób:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Działanie nieskończonej pętli jest do ustalenia (#383).

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) operand zmienna liczba wartości, (C1-C3)
(I2) cond funkcja (C1)
(I3) body funkcja (C2)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba wartości, (C3)

Ograniczenia

  • (C1) cond ma typ (T0, ..., TN-1) -> tensor<i1>, gdzie Ti = type(operand[i]).
  • (C2) body jest typu (T0, ..., TN-1) -> (T0, ..., TN-1), gdzie Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Przykłady

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_di<rection = #stablehlocom>parison_directio<n L>T
    } <: (>ten>sori64,< t>ensori64) - tensori1
    stablehlo.r<et>urn %cond : tensori1
  }, {
<  ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
    %new_sum = stablehlo.add <%ar>g1, %one : tensori64
    %new_i = stablehlo.add <%ar>g0, %one : tensori64
    stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10

Więcej przykładów

xor

Semantyka

Wykonuje operację XOR na poszczególnych elementach 2 tensorów lhsrhs oraz tworzy tensor result. W zależności od typu elementu wykonuje te czynności:

  • W przypadku wartości logicznych: logiczne XOR.
  • W przypadku liczb całkowitych: bitowa operacja XOR.

Wejścia

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu logicznego lub całkowitego, (C1)
(I2) rhs tensor typu logicznego lub całkowitego, (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu logicznego lub całkowitego, (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]

Więcej przykładów

Dialect Interop

Obecnie programy StableHLO w praktyce zawierają czasami operacje, które nie są zdefiniowane przez StableHLO.

Moduł, funkcja, wywołanie i wartość zwracana

StableHLO używa operacji MLIR wyższego poziomu w przypadku operacji ModuleOp, FuncOp, CallOp i ReturnOp. Zrobiliśmy to, aby zapewnić lepszą interoperacyjność z istniejącymi mechanizmami MLIR, ponieważ wiele przydatnych przekształceń jest napisanych pod kątem operacji FuncOp i ModuleOp, a wiele potoków kompilacji oczekuje, że te operacje będą obecne. W przypadku tych operacji obowiązują pełne gwarancje zgodności. Jeśli w przyszłości te operacje ulegną zmianom w sposób niezgodny z obecnym (np. zostaną usunięte), dodamy ich odpowiedniki w StableHLO, aby zachować zgodność.

CHLO

Zestaw operacji CHLO zawiera operacje wyższego poziomu, które są rozkładane na StableHLO. Obecnie nie ma gwarancji zgodności w przypadku CHLO. Aby zagwarantować zgodność, przed serializacją należy użyć chlo-legalize-to-stablehlo pass.

Operacje na kształtach

W społeczności często stosuje się w dynamicznych programach StableHLO pewne operacje z podstawowych dialektów MLIR do obliczania kształtów. Najczęściej są to operatory shapedialektu, takie jak shape_of lub num_elements, operatory tensordialektu, takie jak dim lub from_elements, oraz wbudowany typ index.

Dokument Dynamism RFC > O2 określa je jako wykraczające poza zakres, ale ze względu na interoperacyjność uwzględniono obsługę niektórych typów index. Nie ma gwarancji zgodności tych operacji ani typów. Za pomocą przekształcenia shape-legalize-to-stablehlo można przekonwertować te operacje na w pełni obsługiwane operacje StableHLO.

Wycofane operacje

Istnieje kilka operacji StableHLO odziedziczonych z MHLO, które są wycofywane i zostaną usunięte z StableHLO. Szczegółowe informacje o tych zmianach znajdziesz w StableHLO w wersji 1.0 – czyszczenie nr 2283. Problem śledzenia tych wycofań to #2340.

Te operacje dzielą się na kilka kategorii:

  • Operacje StableHLO z kategorii „Not in HLO” (nie należące do HLO) – początkowo były częścią zestawu operacji StableHLO, ale później uznano, że do niego nie pasują:broadcast, create_token, cross-replica-sum, dot, einsum,torch_index_select, unary_einsum (#3).
  • Nieużywane operacje – te operacje mogły być kiedyś przydatne, ale były niedopracowane lub potoki, które ich używały, zostały zmodyfikowane tak, aby nie były już potrzebne. Obejmuje to map, tuple (#598), get_tuple_element, rng, complex porównania #560, i splot window_reversal (#1181).

Niektóre z tych operacji można łatwo usunąć, ponieważ można je wyrazić za pomocą istniejących operacji (broadcast, create_token, cross-replica-sum, dot, unary_einsum). Zostaną one usunięte po upływie okresu zgodności (6 miesięcy). W przypadku innych operatorów rozważamy usunięcie (einsum, get_tuple_element, map, rng torch_index_select, tuple, complex porównania, window_reversal). W zależności od opinii społeczności te operatory zostaną usunięte lub dodane do specyfikacji z pełną obsługą. Dopóki nie będą znane przyszłe wersje tych systemów operacyjnych, gwarantujemy tylko 6-miesięczną kompatybilność.

Wykonanie

Wykonywanie sekwencyjne

Program StableHLO jest wykonywany przez podanie wartości wejściowych do funkcji main i obliczenie wartości wyjściowych. Wartości wyjściowe funkcji są obliczane przez wykonanie wykresu operacji zakorzenionych w odpowiedniej operacji return.

Kolejność wykonywania jest zdefiniowana przez implementację, o ile jest zgodna z przepływem danych, tzn. operacje są wykonywane przed ich użyciem. W StableHLO wszystkie operacje wywołujące efekty uboczne zużywają 1 token i wytwarzają 1 token (wiele tokenów można multipleksować w 1 token za pomocą operacji after_all), więc kolejność wykonywania efektów ubocznych jest również zgodna z przepływem danych. Na przykład w programie poniżej istnieją 2 możliwe kolejności wykonywania: %0%1%2return%1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Bardziej formalnie proces StableHLO to połączenie: 1) programu StableHLO, 2) stanów operacji (jeszcze nie wykonano, już wykonano) i 3) wartości pośrednich, nad którymi pracuje proces. Proces rozpoczyna się od wartości wejściowych funkcji main, przechodzi przez graf operacji aktualizujących stany operacji i wartości pośrednie, a kończy się wartościami wyjściowymi. Dalsze formalizowanie zostanie określone w przyszłości (#484).

Wykonanie równoległe

Programy StableHLO można wykonywać równolegle, organizując je w dwuwymiarową siatkę procesów o wymiarach num_replicas × num_partitions, gdzie oba wymiary są typu ui32.

W siatce procesów StableHLO jednocześnie wykonywanych jest num_replicas * num_partitions procesów StableHLO. Każdy proces ma unikalny process_id = (replica_id, partition_id), gdzie replica_idreplica_ids = range(num_replicas)partition_idpartition_ids = range(num_partitions) mają typ ui32.

Rozmiar siatki procesów jest znany statycznie dla każdego programu (w przyszłości planujemy uczynić go jawną częścią programów StableHLO #650), a pozycja w siatce procesów jest znana statycznie dla każdego procesu. Każdy proces ma dostęp do swojego położenia w siatce procesów za pomocą operacji replica_idpartition_id.

W siatce procesów programy mogą być takie same (w stylu „Jeden program, wiele danych”), różne (w stylu „Wiele programów, wiele danych”) lub coś pomiędzy. W przyszłości planujemy wprowadzić obsługę innych idiomów definiowania równoległych programów StableHLO, w tym GSPMD (#619).

W siatce procesów procesy są w większości od siebie niezależne – mają oddzielne stany operacji, oddzielne wartości wejściowe, pośrednie i wyjściowe, a większość operacji jest wykonywana oddzielnie między procesami, z wyjątkiem niewielkiej liczby operacji zbiorczych opisanych poniżej.

Większość operacji korzysta tylko z wartości z tego samego procesu, więc zwykle nie ma wątpliwości, do których wartości się odnosić, jeśli używa się ich nazw. Jednak w przypadku opisywania semantyki operacji zbiorowych jest to niewystarczające, dlatego używamy notacji name@process_id, aby odwoływać się do wartości name w ramach określonego procesu. (Z tej perspektywy niekwalifikowane name można traktować jako skrót od name@(replica_id(), partition_id())).

Kolejność wykonywania w różnych procesach jest zdefiniowana przez implementację, z wyjątkiem synchronizacji wprowadzanej przez komunikację typu punkt-punkt i operacje zbiorowe, jak opisano poniżej.

Komunikacja punkt-punkt

Procesy StableHLO mogą się ze sobą komunikować za pomocą kanałów StableHLO. Kanał jest reprezentowany przez dodatni identyfikator typu si64. Za pomocą różnych operacji można wysyłać wartości do kanałów i odbierać je z kanałów.

Dalsze formalizowanie, np. skąd pochodzą te identyfikatory kanałów, jak procesy programów je wykrywają i jakiego rodzaju synchronizację wprowadzają, zostanie określone w przyszłości (#484).

Komunikacja strumieniowa

Każdy proces StableHLO ma dostęp do 2 interfejsów strumieniowych:

  • Infeed, z którego można odczytywać dane.
  • Plik wyjściowy, do którego można zapisywać dane.

W przeciwieństwie do kanałów, które służą do komunikacji między procesami i dlatego mają procesy na obu końcach, w przypadku kanałów wejściowych i wyjściowych drugi koniec jest zdefiniowany przez implementację.

Dalsza formalizacja, np. jak komunikacja strumieniowa wpływa na kolejność wykonywania i jakiego rodzaju synchronizację wprowadza, zostanie określona w przyszłości (#484).

Operacje zbiorcze

W StableHLO jest 6 operacji zbiorowych: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permutereduce_scatter. Wszystkie te operacje dzielą procesy w siatce procesów StableHLO na grupy procesów StableHLO i wykonują wspólne obliczenia w każdej grupie procesów, niezależnie od innych grup procesów.

W każdej grupie procesów operacje zbiorcze mogą wprowadzać barierę synchronizacji. Dalsze sformalizowanie, np. wyjaśnienie, kiedy dokładnie następuje synchronizacja, jak dokładnie procesy docierają do tej bariery i co się stanie, jeśli tego nie zrobią, zostanie określone w przyszłości (#484).

Jeśli grupa procesów obejmuje komunikację między partycjami, tzn. w grupie procesów znajdują się procesy o różnych identyfikatorach partycji, wykonanie operacji zbiorczej wymaga kanału, a operacja zbiorcza musi zwracać dodatnią wartość channel_id typu si64. Komunikacja między replikami nie wymaga kanałów.

Obliczenia wykonywane przez operacje zbiorcze są specyficzne dla poszczególnych operacji i opisane w sekcjach powyżej. Strategie, według których siatka procesów jest dzielona na grupy procesów, są jednak wspólne dla tych operacji i opisane w tej sekcji. StableHLO obsługuje te 4 strategie:

cross_replica

W każdej grupie procesów odbywa się tylko komunikacja między replikami. Ta strategia przyjmuje replica_groups – listę list identyfikatorów replik – i oblicza iloczyn kartezjański replica_groupspartition_ids. replica_groups musi zawierać unikalne elementy i obejmować wszystkie replica_ids. Bardziej formalnie, używając składni Pythona:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Na przykład dla wartości replica_groups = [[0, 1], [2, 3]]num_partitions = 2 funkcja cross_replica zwróci wartość [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

W każdej grupie procesów odbywa się tylko komunikacja między partycjami. Ta strategia przyjmuje partition_groups – listę list identyfikatorów partycji – i oblicza iloczyn kartezjański partition_groups przez replica_ids. partition_groups musi zawierać unikalne elementy i obejmować wszystkie partition_ids. Bardziej formalnie, używając składni Pythona:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Na przykład dla wartości partition_groups = [[0, 1]]num_replicas = 4 funkcja cross_partition zwróci wartość [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

W każdej grupie procesów może dochodzić do komunikacji między replikami i między partycjami. Ta strategia przyjmuje replica_groups – listę list identyfikatorów replik – i oblicza iloczyny kartezjańskie każdej listy replica_group przez partition_ids. replica_groups musi zawierać unikalne elementy i obejmować wszystkie replica_ids. Bardziej formalnie, używając składni Pythona:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Na przykład dla wartości replica_groups = [[0, 1], [2, 3]]num_partitions = 2 funkcja cross_replica_and_partition zwróci wartość [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Ta strategia przyjmuje flattened_id_groups – listę list „spłaszczonych” identyfikatorów procesów w formacie replica_id * num_partitions + partition_id – i przekształca je w identyfikatory procesów. flattened_id_groups musi zawierać unikalne elementy i obejmować wszystkie process_ids. Bardziej formalnie, używając składni Pythona:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Na przykład dla flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4num_partitions = 2 funkcja flattened_ids zwróci [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Dokładność

Obecnie StableHLO nie gwarantuje dokładności numerycznej, ale w przyszłości może się to zmienić (#1156).

Semantyka wykonania operacji skwantowanej

Interpretacja skwantyzowanych operacji StableHLO może się różnić w zależności od wymagań i możliwości sprzętowych. Na przykład niektóre urządzenia mogą interpretować operacje kwantyzowane za pomocą strategii „dekwantyzacja, wykonanie operacji zmiennoprzecinkowej i kwantyzacja”. Inne mogą wykonywać całe obliczenia za pomocą arytmetyki liczb całkowitych. W związku z tym interpretacja skwantyzowanych operacji StableHLO zależy wyłącznie od konkretnej implementacji. Interpretacja kwantyzacji hybrydowej (#1575) powinna opierać się na jej semantyce określonej w specyfikacji (w 1792).

Błędy

Programy StableHLO są weryfikowane za pomocą obszernego zestawu ograniczeń dotyczących poszczególnych operacji, co wyklucza wiele klas błędów przed czasem działania. Nadal jednak mogą wystąpić błędy, np. przepełnienie liczb całkowitych, dostęp poza zakresem itp. O ile nie zaznaczono inaczej, wszystkie te błędy powodują zachowanie zdefiniowane przez implementację, ale w przyszłości może się to zmienić (#1157).

Wyjątki zmiennoprzecinkowe

Wyjątkiem od tej reguły są wyjątki zmiennoprzecinkowe w programach StableHLO, które mają dobrze zdefiniowane działanie. Operacje, które powodują wyjątki zdefiniowane w standardzie IEEE-754 (nieprawidłowa operacja, dzielenie przez zero, przepełnienie, niedomiar lub nieprecyzyjne wyjątki), generują domyślne wyniki (zgodnie ze standardem) i są kontynuowane bez podnoszenia odpowiedniej flagi stanu, podobnie jak obsługa wyjątków raiseNoFlag w standardzie. Wyjątki dotyczące niestandardowych operacji (np. złożonych operacji arytmetycznych i niektórych funkcji transcendentalnych) są zdefiniowane w implementacji.

Niezgodności kształtów

StableHLO obsługuje tensory o dynamicznym kształcie. Kształty muszą jednak być zgodne w czasie działania, w przeciwnym razie zachowanie jest nieokreślone. StableHLO nie udostępnia jawnie operacji, która mogłaby potwierdzić, że tensor ma w czasie działania określony kształt. Za wygenerowanie prawidłowego kodu odpowiada producent.

Poniższy program jest prawidłowy. Jednak w czasie działania programu dokładne kształty %arg0%arg1 muszą być takie same, w przeciwnym razie działanie programu jest nieokreślone:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notacja

Do opisu składni w tym dokumencie używamy zmodyfikowanej wersji notacji EBNF zgodnej z normą ISO (ISO/IEC 14977:1996, Wikipedia) z 2 modyfikacjami: 1) reguły są definiowane za pomocą znaku ::= zamiast =,

2) łączenie jest wyrażane za pomocą zestawienia, a nie znaku ,.

Do opisywania semantyki (np. w sekcjach „Typy”, „Stałe” i „Operacje”) używamy formuł opartych na składni Pythona rozszerzonej o obsługę zwięzłego wyrażania operacji na tablicach, jak opisano poniżej. Sprawdza się to w przypadku małych fragmentów kodu, ale w rzadkich przypadkach, gdy potrzebne są większe fragmenty kodu, używamy czystej składni Pythona, która jest zawsze wprowadzana w sposób jawny.

Wzory

Przyjrzyjmy się, jak działają formuły, na przykładzie z dot_generalspecyfikacji. Jedno z ograniczeń tej operacji wygląda tak:dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)

Nazwy użyte w tej formule pochodzą z 2 źródeł: 1) funkcji globalnych, np. dim, 2) definicji elementów odpowiedniego programu, np. lhs, lhs_batching_dimensions, rhsrhs_batching_dimensions, które są zdefiniowane w sekcji „Dane wejściowe” w dot_general.

Jak wspomnieliśmy powyżej, składnia tego wzoru jest oparta na języku Python i zawiera pewne rozszerzenia ułatwiające zwięzłość. Aby zrozumieć formułę, przekształćmy ją w czystą składnię Pythona.

A) W tych formułach używamy symbolu =, aby oznaczyć równość. Pierwszym krokiem do uzyskania składni Pythona jest więc zastąpienie symbolu = symbolem ==, jak poniżej:dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)

B) Te formuły obsługują też wielokropki (...), które przekształcają wyrażenia skalarne w wyrażenia tensorowe. W skrócie f(xs...) oznacza mniej więcej „dla każdego skalaru x w tensorze xs oblicz skalar f(x), a następnie zwróć wszystkie te wyniki skalarne razem jako wynik tensorowy”. W składni zwykłego Pythona nasza przykładowa formuła wygląda tak:[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions]

Dzięki wielokropkom często można uniknąć pracy na poziomie poszczególnych skalarów. W niektórych skomplikowanych przypadkach można jednak używać składni niższego poziomu, np. w start_indices[bi0, ..., :, ..., biN] wzorzegathergather specyfikacji. Aby zachować zwięzłość, nie podajemy dokładnego formalizmu tłumaczenia takiej składni na czystego Pythona, mając nadzieję, że w poszczególnych przypadkach będzie ona nadal intuicyjnie zrozumiała. Jeśli niektóre formuły wydają Ci się niejasne, daj nam znać, a postaramy się je ulepszyć.

Zauważysz też, że w formułach używamy wielokropka do rozwijania wszelkiego rodzaju list, w tym tensorów, list tensorów (które mogą np. pochodzić z różnej liczby tensorów) itp. Jest to kolejny obszar, w którym nie podajemy dokładnego formalizmu (np. listy nie są nawet częścią systemu typów StableHLO), a zamiast tego polegamy na intuicyjnej zrozumiałości.

C) Ostatnim wartym uwagi narzędziem notacyjnym, którego używamy, jest niejawne rozgłaszanie. Chociaż zestaw operacji StableHLO nie obsługuje niejawnego rozgłaszania, formuły to robią, również w celu zwięzłości. Krótko mówiąc, jeśli skalar jest używany w kontekście, w którym oczekiwany jest tensor, skalar jest rozgłaszany do oczekiwanego kształtu.

Aby kontynuować przykład z dot_general, podajemy kolejne ograniczenie: 0 <= lhs_batching_dimensions < rank(lhs) Zgodnie ze specyfikacją dot_general, lhs_batching_dimensions jest tensorem, ale 0rank(lhs) są skalarami. Po zastosowaniu niejawnego rozgłaszania formuła zmieni się na [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

W przypadku zastosowania do konkretnego dot_general działania ta formuła zwróci tensor wartości logicznych. Gdy formuły są używane jako ograniczenia, ograniczenie jest spełnione, jeśli formuła zwraca wartość true lub tensor, który zawiera tylko elementy true.

Nazwy

W formułach zakres leksykalny obejmuje: 1) funkcje globalne, 2) definicje elementów,

3) definicje lokalne. Listę funkcji globalnych znajdziesz poniżej. Lista definicji elementów zależy od elementu programu, do którego odnosi się notacja:

  • W przypadku operacji definicje elementów obejmują nazwy wprowadzone w sekcjach „Dane wejściowe” i „Dane wyjściowe”.
  • W przypadku pozostałych elementów definicje elementów programu obejmują części strukturalne elementu programu, nazwane zgodnie z odpowiednimi nieterminalami EBNF. W większości przypadków nazwy tych części strukturalnych są uzyskiwane przez przekształcenie nazw nieterminali na format snake_case (np. IntegerLiteral => integer_literal), ale czasami nazwy są w tym procesie skracane (np. QuantizationStorageType => storage_type). W takim przypadku nazwy są wprowadzane w sposób jawny, podobnie jak sekcje „Dane wejściowe” i „Dane wyjściowe” w specyfikacjach operacji.
  • Dodatkowo definicje elementów programu zawsze zawierają self, aby odwoływać się do odpowiedniego elementu programu.

Wartości

Podczas obliczania formuł są używane te typy wartości:1) Value (wartości rzeczywiste, np. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; zawsze znają swoje typy),2) Placeholder (wartości przyszłe, np. lhs, rhs lub result; ich rzeczywiste wartości nie są jeszcze znane, znane są tylko ich typy),3) Type (typy zdefiniowane w sekcji „Typy”),4) Function (funkcje globalne zdefiniowane w sekcji „Funkcje”).

W zależności od kontekstu nazwy mogą odnosić się do różnych wartości. W sekcji „Semantyka” w przypadku operacji (i odpowiedników w przypadku innych elementów programu) zdefiniowana jest logika czasu działania, więc wszystkie dane wejściowe są dostępne jako Value. Z kolei sekcja „Constraints” (Ograniczenia) w przypadku operacji (i odpowiedników) definiuje logikę „czasu kompilacji”, czyli coś, co jest zwykle wykonywane przed czasem działania, więc tylko stałe dane wejściowe są dostępne jako Value, a inne dane wejściowe są dostępne tylko jako Placeholder.

Nazwy W sekcji „Semantyka” W sekcji „Ograniczenia”
Funkcje globalne Function Function
Stałe dane wejściowe Value Value
Dane wejściowe, które nie są stałe Value Placeholder
Wyniki Value Placeholder
Definicje lokalne Zależy od definicji Zależy od definicji

Rozważmy przykładową operację transpose:

%result = "stablehlo.transpose"(%operand) {
  permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32

W przypadku tej operacji permutation jest stałą, więc jest dostępna jako Value zarówno w semantyce, jak i w ograniczeniach. Z kolei wartości operandresult są dostępne jako Value w semantyce, ale tylko jako Placeholder w ograniczeniach.

Funkcje

Konstrukcja typów

Nie ma funkcji, których można użyć do tworzenia typów. Zamiast tego używamy bezpośrednio składni typów, ponieważ jest ona zwykle bardziej zwięzła. Np. (tensor<E>, tensor<E>) -> (tensor<E>) zamiast function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Funkcje dotyczące typów

  • element_type jest zdefiniowana dla typów tensorów i skwantowanych typów tensorów i zwraca odpowiednio część TensorElementType lub QuantizedTensorElementType odpowiedniego TensorType lub QuantizedTensorType.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value to skrót do is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value to skrót do is_quantized(x) and quantization_dimension(x) is None.

  • Sprawdza, czy typ x można promować do typu y.is_promotable(x: Type, y: Type) -> bool Gdy xyQuantizedTensorElementType, promocja jest stosowana tylko do storage_type. Ta konkretna wersja promocji jest obecnie używana w kontekście obliczania obniżki (więcej informacji znajdziesz w RFC).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value to skrót do: is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Dostępne dla wszystkich typów. Na przykład is_float(x) zwraca true, jeśli x jest FloatType. Jeśli x jest wartością lub symbolem zastępczym, ta funkcja jest skrótem do is_type_name(type(x)).

  • max_value(x: Type) -> Value zwraca maksymalną wartość z TensorElementType. Jeśli x nie jest TensorElementType, zwraca None.

  • min_value(x: Type) -> Value zwraca najmniejszą możliwą wartość an TensorElementType. Jeśli x nie jest TensorElementType, zwraca None.

  • member_name(x: Value | Placeholder | Type) -> Any. Dostępne dla wszystkich definicji elementów member_name wszystkich typów. Na przykład tensor_element_type(x) zwraca część TensorElementType odpowiedniego elementu TensorType. Jeśli x jest wartością lub symbolem zastępczym, ta funkcja jest skrótem do member_name(type(x)). Jeśli x nie jest typem, który ma odpowiedni element, lub wartością bądź symbolem zastępczym takiego typu, zwraca None.

  • is_empty_algorithm(*args: Type) sprawdza, czy wszystkie pola algorytmu kropkowego są ustawione na None. Jest to konieczne, ponieważ algorytmy kropkowe mają zdefiniowane w implementacji domyślne zachowania, więc określenie wartości domyślnej byłoby nieprawidłowe.

Budowanie wartości

  • operation_name(*xs: Value | Type) -> Value. Dostępne w przypadku wszystkich operacji. Na przykład funkcja add(lhs, rhs) przyjmuje 2 wartości tensora lhsrhs oraz zwraca wynik obliczenia operacji add z tymi danymi wejściowymi. W przypadku niektórych operacji, np. broadcast_in_dim, typy ich wyników są „nośne”, tzn. potrzebne do oceny operacji. W takim przypadku funkcja przyjmuje te typy jako argumenty.

Funkcje dotyczące wartości

  • Dostępne są wszystkie operatory i funkcje Pythona. Np. zarówno notacja subscription, jak i slicing z Pythona są dostępne do indeksowania tensorów, tensorów skwantowanych i krotek.

  • to_destination_type(x: Value, destination_type: Type) -> Value jest zdefiniowana na tensorach i zwraca przekonwertowaną wartość x na podstawie type(x)destination_type w ten sposób:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Trwają wstępne dyskusje na temat połączenia operacji convert, uniform_quantizeuniform_dequantize (#1576). Po połączeniu nie potrzebujemy już powyższej funkcji i zamiast convert możemy używać nazwy operacji.

  • Funkcja is_nan(x: Value) -> Value jest zdefiniowana na tensorach i zwraca wartość true, jeśli wszystkie elementy xNaN, lub false w przeciwnym razie. Jeśli x nie jest tensorem, zwraca None.

  • is_sorted(x: Value) -> Value jest zdefiniowana na tensorach i zwraca wartość true, jeśli elementy tensora x są posortowane w kolejności rosnącej względem rosnącej kolejności leksykograficznej ich indeksów, lub false w przeciwnym razie. Jeśli x nie jest tensorem, zwraca wartość None.

  • Funkcja is_unique(x: Value) -> Value jest zdefiniowana na tensorach i zwraca wartość true, jeśli x nie zawiera zduplikowanych elementów, lub wartość false w przeciwnym razie. Jeśli x nie jest tensorem, zwraca None.

  • member_name(x: Value) -> Any jest zdefiniowany dla wszystkich definicji elementówmember_name wszystkich wartości. Na przykład real_part(x) zwraca część RealPart odpowiedniego elementu ComplexConstant. Jeśli x nie jest wartością, która ma odpowiedni element, zwraca None.

  • same(x: Value) -> Value jest zdefiniowana na tensorach i zwraca wartość true, jeśli elementy x są równe, lub false w przeciwnym razie. Jeśli tensor nie zawiera elementów, jest to traktowane jako „wszystkie elementy są równe”, tzn. funkcja zwraca wartość true. Jeśli x nie jest tensorem, zwraca None.

  • Funkcja split(x: Value, num_results: Value, axis: Value) -> Value jest zdefiniowana na tensorach i zwraca num_results wycinków tensora x wzdłuż osi axis. Jeśli x nie jest tensorem ani dim(x, axis) % num_results != 0, zwraca None.

  • is_defined_in_parent_scope(x: Value) -> Value jest zdefiniowana na ciągach znaków i zwraca true, jeśli x jest nazwą funkcji zdefiniowanej w tym samym zakresie co funkcja nadrzędna odpowiedniej operacji.

  • is_namespaced_op_name(x: Value) -> Value jest zdefiniowane w ciągach i zwraca true, jeśli x jest prawidłową nazwą operacji, czyli spełnia to wyrażenie regularne: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

Obliczenia kształtów

  • axes(x: Value | Placeholder | Type) -> Value to skrót do: range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value to skrót do: shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List to skrót do: list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value jest zdefiniowane na tensorach i zwraca indeksy size(x) dla odpowiednich TensorType posortowanych w porządku rosnącym, czyli [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Jeśli x nie jest typem tensora, skwantyzowanym typem tensora, wartością lub symbolem zastępczym jednego z tych typów, zwraca None.

  • rank(x: Value | Placeholder | Type) -> Value to skrót do: size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value jest zdefiniowane w sekcji „Funkcje w typach” za pomocą member_name.

  • size(x: Value | Placeholder | Type) -> Value to skrót do: reduce(lambda x, y: x * y, shape(x)).

Obliczenia kwantyzacji

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type to skrót do element_type(baseline_type(x)).

  • baseline_type jest zdefiniowany dla typów tensorów i skwantowanych typów tensorów i przekształca je w „wartość bazową”, czyli typ o tym samym kształcie, ale z parametrami kwantyzacji typu elementu zresetowanymi do wartości domyślnych. Jest to przydatny trik do jednolitego porównywania typów tensorów i skwantowanych tensorów, co jest dość często potrzebne. W przypadku typów skwantyzowanych umożliwia to porównywanie typów z pominięciem parametrów kwantyzacji, tzn. wszystkie parametry shape, storage_type, expressed_type, storage_min, storage_maxquantization_dimension (w przypadku typu skwantyzowanego na osi) muszą być zgodne, ale parametry scaleszero points mogą się różnić.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize jest zdefiniowany w przypadku typów tensorów skwantowanych i przekształca je w typy tensorów zmiennoprzecinkowych. Odbywa się to poprzez przekształcenie skwantyzowanych elementów, które reprezentują wartości całkowite typu pamięci, na odpowiednie wartości zmiennoprzecinkowe typu wyrażonego za pomocą punktu zerowego i skali powiązanych z typem skwantyzowanego elementu.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize jest zdefiniowana dla typów tensorów zmiennoprzecinkowych i przekształca je w skwantowane typy tensorów. Odbywa się to przez przekształcenie wartości zmiennoprzecinkowych wyrażonego typu na odpowiednie wartości całkowite typu pamięci za pomocą punktu zerowego i skali powiązanych z kwantyzowanym typem elementu.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize służy do określania obliczeń na poszczególnych elementach skwantowanych tensorów. Dekompresuje, czyli przekształca skwantowane elementy w ich typy wyrażone, a następnie wykonuje operację i kwantyzuje, czyli przekształca wyniki z powrotem w ich typy pamięci. Obecnie ta funkcja działa tylko w przypadku kwantyzacji na poziomie tensora. Kwantyzacja na osi jest w toku (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op służy do określania kwantyzacji tylko wagi w przypadku operacji hybrydowej, która akceptuje lewą stronę w postaci liczby zmiennoprzecinkowej, a prawą w postaci typu skwantowanego. Dekompresuje skwantowane dane wejściowe do ich wyrażonych typów i wykonuje obliczenia w formacie zmiennoprzecinkowym. Typ elementu tensora zmiennoprzecinkowego po lewej stronie i typ wyrażony skwantowanego tensora po prawej stronie powinny być identyczne.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Obliczenia w siatce

  • cross_partition(replica_groups: Value) -> Value. Patrz sekcja „cross_replica” powyżej.

  • cross_replica(replica_groups: Value) -> Value. Patrz sekcja „cross_replica” powyżej.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Patrz sekcja „cross_replica_and_partition” powyżej.

  • flattened_ids(replica_groups: Value) -> Value. Patrz sekcja „flattened_ids” powyżej.

Dynamizm

Wartości StableHLO mogą mieć dynamiczne rozmiary wymiarów, np. tensor<?xi64>. Wartości StableHLO nie mogą jednak mieć dynamicznej liczby wymiarów (dynamizm bez rangi, np. tensor<*xi64>). Operandy i wyniki mogą używać dynamicznych rozmiarów wymiarów, nawet jeśli istnieją ograniczenia dotyczące rozmiarów. Ograniczenia będą weryfikowane statycznie, jeśli to możliwe. W przeciwnym razie weryfikacja zostanie odroczona do czasu wykonania, a niezgodności spowodują nieokreślone zachowanie. Przykłady znajdziesz poniżej.

Niezgodność kształtów w przypadku jednoargumentowych operacji na elementach

Rozważmy ten program dotyczący zabawek:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Taki program jest nietypowy, ponieważ zwykle znamy kształt danych wejściowych, a nie kształt wyniku. Jest to jednak prawidłowy program StableHLO. Nie można statycznie zweryfikować operacji abs w tym programie, ponieważ dokładny kształt operandu jest nieznany. Kształty są jednak z pewnością zgodne, co można sprawdzić statycznie: ? może w czasie działania programu okazać się 2 i nie będzie z tym żadnego problemu. Jednak ? może też być inną liczbą całkowitą, w którym to przypadku działanie jest niezdefiniowane.

Pamiętaj, że jeśli rozmiar wymiaru jest dynamiczny w wyniku, nie może wystąpić niezdefiniowane zachowanie. Nie ma bowiem „oczekiwanego” rozmiaru, więc nie może wystąpić niezgodność.

Niezgodność kształtów w przypadku binarnych operacji na poszczególnych elementach

Rozważmy ten program dotyczący zabawek:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

W przypadku binarnych operacji na elementach kształty danych wejściowych i wyniku muszą być zgodne w czasie działania. W momencie kompilacji wymiary statyczne muszą być równe, w przeciwnym razie muszą być tylko zgodne. Jeśli którykolwiek wymiar jest dynamiczny w danych wejściowych, w czasie działania może wystąpić niezdefiniowane zachowanie, ponieważ rozmiar dynamiczny może nie pasować do odpowiedniego rozmiaru w drugim operacji (statycznym lub dynamicznym). Jeśli wszystkie dane wejściowe są statyczne, to czy wynik jest dynamiczny, czy nie, nie ma znaczenia: wymiary znane statycznie będą sprawdzane statycznie, a wymiary dynamiczne nie narzucają żadnych ograniczeń.

Niezgodność kształtów w przypadku operacji, które przyjmują kształt danych wyjściowych jako operand

Rozważmy ten program dotyczący zabawek:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

Wartości w operacji kształtu w czasie działania muszą być zgodne z kształtem wyniku, w przeciwnym razie zachowanie jest niezdefiniowane. Oznacza to, że w czasie działania programu zmienna %arg0 musi mieć wartość dense<[3, 4]> : tensor<2xi32>. Jeśli operand kształtu jest stałą, można to sprawdzić statycznie. Jeśli kształt wyniku jest w pełni dynamiczny, nie może wystąpić niezgodność.