Specyfikacja StableHLO

StableHLO to zestaw operacji wysokiego poziomu (HLO) w modelach systemów uczących się. StableHLO działa jako warstwa przenoszenia między różnymi platformami ML i kompilatorami ML: platformy ML tworzące programy StableHLO są zgodne z kompilatorami ML, które korzystają z programów StableHLO.

Naszym celem jest uproszczenie i przyspieszenie programowania ML przez zapewnienie większej interoperacyjności między różnymi platformami ML (np. TensorFlow, JAX i PyTorch) i kompilatorami ML (np. XLA i IREE). W tym celu zamieściliśmy specyfikację języka programowania StableHLO.

Specyfikacja zawiera 3 główne sekcje. Sekcja Programy opisuje strukturę programów StableHLO, które składają się z funkcji StableHLO, które składają się z operacji StableHLO. W tej strukturze sekcja Operacje określa semantykę poszczególnych operacji. Sekcja Wykonanie zawiera semantykę wszystkich tych operacji wykonywanych razem w programie. W sekcji Notacja omawiamy też zapis stosowany w całej specyfikacji.

Aby wyświetlić specyfikację z poprzedniej wersji StableHLO, otwórz repozytorium w oznaczonym wydaniu. Na przykład specyfikacja StableHLO w wersji 0.19.0. Aby wyświetlić zmiany wprowadzone w każdej małej wersji StableHLO, zapoznaj się z logiem wersji w pliku VhloDialect.td.

Programy

Program ::= {Func}

Programy StableHLO składają się z dowolnej liczby funkcji StableHLO. Poniżej znajduje się przykładowy program z funkcją @main, która ma 3 parametry wejściowe (%image, %weights%bias) oraz 1 wyjście. Treść funkcji zawiera 6 operacji.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Funkcje

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Funkcje StableHLO (nazywane też funkcjami nazwanymi) mają identyfikator, wejścia/wyjścia i ciało. W przyszłości planujemy wprowadzić dodatkowe metadane funkcji, aby zwiększyć zgodność z HLO (#425, #626, #740, #744).

Identyfikatory

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Identyfikatory StableHLO są podobne do identyfikatorów w wielu językach programowania, ale mają 2 szczególne cechy: 1) wszystkie identyfikatory mają sygnatury, które odróżniają różne rodzaje identyfikatorów, 2) identyfikatory wartości mogą być całkowicie numeryczne, aby uprościć generowanie programów StableHLO.

Typy

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Typy StableHLO są podzielone na typy wartości (nazywane też typami najwyższej klasy), które reprezentują wartości StableHLO, oraz typy bez wartości, które opisują inne elementy programu. Typy StableHLO są podobne do typów w wielu językach programowania, a ich główną osobliwością jest specyfika danej domeny, która powoduje pewne nietypowe wyniki (np. typy skalarne nie są typami wartości).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'

Typy Tensor reprezentują tensory, czyli tablice wielowymiarowe. Mają one kształt i typ elementu, gdzie kształt reprezentuje nieujemne lub nieznane rozmiary wymiarów w kolejności rosnącej odpowiadających im wymiarów (zwanych też osiami), których numery wahają się od 0 do R-1. Liczba wymiarów R to ranking. Na przykład tensor<2x3xf32> to typ tensora o kształcie 2x3 i typie elementu f32. Ma 2 wymiary (czyli 2 osi) – 0 i 1, których rozmiary wynoszą odpowiednio 2 i 3. Jego pozycja to 2.

Kształty mogą być częściowo lub całkowicie nieznane (dynamiczne), np. tensor<?x2xf64> jest częściowo nieznany, a tensor<?x?xf64> jest całkowicie nieznany. Dynamiczne rozmiary wymiarów są reprezentowane za pomocą ?. Kształty nie mogą być niesklasyfikowane.

W przyszłości planujemy rozszerzyć typy tensorów poza rozmiary wymiarów i typy elementów, np. o układy (#629) i sparsity (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nazwa Typ Ograniczenia
storage_type typ liczba całkowita (C1-C3), (C8)
storage_min stała liczbowa (C1), (C3), (C7)
storage_max stała liczbowa (C2), (C3), (C7)
expressed_type typ zmiennoprzecinkowy, (C4)
quantization_dimension opcjonalna liczba całkowita (C10-C12)
scales zmienna liczba stałych zmiennoprzecinkowych (C4-C6), (C9), (C10), (C13)
zero_points zmienna liczba stałych całkowitych (C7-C9)

Typy elementów z kwantyzacją reprezentują wartości całkowite typu magazynowania w zakresie od storage_min do storage_max (łącznie) odpowiadające wartościom zmiennoprzecinkowym typu wyrażonego. Dla danej wartości całkowitej i odpowiadającą wartość zmiennoprzecinkową f można obliczyć jako f = (i - zero_point) * scale, gdzie scalezero_point to parametry kwantyzacji. Parametry storage_minstorage_max są opcjonalne w gramatyce, ale mają wartości domyślne odpowiednio min_value(storage_type)max_value(storage_type). Elementy typu „kwantowany” mają te ograniczenia:

  • (C1) type(storage_min) = storage_type.
  • (C2) type(storage_max) = storage_type.
  • (C3) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C4) type(scales...) = expressed_type.
  • (C5) 0 < scales.
  • (C6) is_finite(scales...).
  • (C7) storage_min <= zero_points <= storage_max.
  • (C8) type(zero_points...) = storage_type.
  • (C9) size(scales) = size(zero_points).
  • (C10) Jeśli is_empty(quantization_dimension), to size(scales) = 1.
  • (C11) 0 <= quantization_dimension.

Obecnie QuantizationScale jest stałą zmiennoprzecinkową, ale istnieje duże zainteresowanie skalami opartymi na liczbach całkowitych, reprezentowanymi przez mnożniki i przesunięcia. W najbliższej przyszłości planujemy wziąć to pod uwagę (#1404).

Trwa dyskusja na temat semantyki funkcji QuantizationZeroPoint, w tym jej typu, wartości i tego, czy w skwantyzowanym typie tensora może istnieć tylko jeden punkt zerowy, czy może on mieć większą liczbę punktów zerowych. Na podstawie wyników tej dyskusji specyfikacja dotycząca punktów 0 może w przyszłości ulec zmianie (#1405).

Kolejna trwająca dyskusja dotyczy semantyki QuantizationStorageMin i QuantizationStorageMax, aby określić, czy należy nałożyć jakieś ograniczenia na te wartości i wartości skończonych tensorów (#1406).

Na koniec planujemy zgłębić temat przedstawiania nieznanych skal i punktów zerowych w sposób podobny do tego, w jaki planujemy przedstawiać nieznane rozmiary wymiarów (#1407).

Typy kwantyzowanych tensorów reprezentują tensory z kwantyzowanymi elementami. Te tensory są dokładnie takie same jak zwykłe tensory, z tą różnicą, że ich elementy mają kwantowane typy elementów zamiast zwykłych typów elementów.

W przypadku skończonych tensorów kwantyzowanie może być na poziomie tensora, co oznacza, że dla całego tensora jest jeden element scalezero_point, lub na poziomie osi, co oznacza, że dla danej osi danego wymiaru quantization_dimension jest wiele elementów scaleszero_points. Bardziej formalnie, w tensorze t z kwantyzacją na osi występują dim(t, quantization_dimension) krojenia quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] itd. Wszystkie elementy w itym krojeniu używają wartości scales[i]zero_points[i] jako parametrów kwantowania. Typy zaokrąglonych tensorów mają te ograniczenia:

  • W przypadku kwantyzacji na centrów:
    • Brak dodatkowych ograniczeń.
  • W przypadku kwantyzacji na oś:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Typy tokenów to tokeny, czyli nieprzezroczyste wartości tworzone i wykorzystywane przez niektóre operacje. Tokeny są używane do określania kolejności wykonywania operacji zgodnie z opisem w sekcji Wykonywanie.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Typy tupla reprezentują tuple, czyli listy niejednorodne. Kropki to starsza funkcja, która zapewnia zgodność z HLO. W HLO tuple służą do reprezentowania zmiennych danych wejściowych i wyjściowych. W StableHLO obsługiwane są natywny wejścia i wyjścia, a jedynym zastosowaniem tupletów w StableHLO jest kompleksowe reprezentowanie HLO ABI, w którym np. T, tuple<T>tuple<tuple<T>> mogą się znacznie różnić w zależności od konkretnej implementacji. W przyszłości planujemy wprowadzić zmiany w interfejsie HLO ABI, które mogą pozwolić nam usunąć typy tuple z StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
            | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
            | 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Typy elementów reprezentują elementy typu tensora. W odróżnieniu od wielu języków programowania te typy nie są typu first class w StableHLO. Oznacza to, że programy StableHLO nie mogą bezpośrednio przedstawiać wartości tych typów (w efekcie są to wartości skalarne typu T z wartościami 0-wymiarowymi tensora typu tensor<T>).

  • Typ logiczny reprezentuje wartości logiczne truefalse.
  • Typy całkowite mogą być typu ze znakiem (si) lub bez znaku (ui) i mieć jedną z obsługiwanych szerokości bitów (2, 4, 8, 16, 32 lub 64). Podpisane typy siN reprezentują liczby całkowite z zakresu od -2^(N-1) do 2^(N-1)-1 włącznie, a bez znaku uiN – wartości całkowite z zakresu od 0 do 2^N-1.
  • Typy zmiennoprzecinkowe mogą być następujące:
  • Typy zespolone reprezentują wartości zespolone, które mają część rzeczywistą i część urojoną tego samego typu elementu. Obsługiwane złożone typy to complex<f32> (obie części są typu f32) i complex<f64> (obie części są typu f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Typy funkcji reprezentują zarówno funkcje nazwane, jak i anonimowe. Mają typy wejścia (lista typów po lewej stronie ->) i typy wyjścia (lista typów po prawej stronie ->). W wielu językach programowania typy funkcji są typu first class, ale nie w StableHLO.

StringType ::= 'string'

Typ ciągu tekstowego reprezentuje sekwencje bajtów. W przeciwieństwie do wielu języków programowania typ ciągu znaków nie jest pierwszą klasą w StableHLO i służy tylko do określania statycznych metadanych elementów programu.

Operacje

Operacje StableHLO (nazywane też operacjami) stanowią zamknięty zbiór operacji wysokiego poziomu w modelach uczenia maszynowego. Jak już wspomnieliśmy, składnia StableHLO jest mocno inspirowana MLIR, co niekoniecznie jest najbardziej ergonomiczną alternatywą, ale prawdopodobnie najlepiej pasuje do celu StableHLO, którym jest zwiększenie interoperacyjności między frameworkami i kompilatorami uczenia maszynowego.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Operacje StableHLO (nazywane też operacjami) mają nazwę, dane wejściowe/wyjściowe i podpis. Nazwa składa się z prefiksu stablehlo. i mnemotechniki, która jednoznacznie identyfikuje jedno z obsługiwanych operacji. Pełną listę wszystkich obsługiwanych operacji znajdziesz poniżej.

OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Operacje wykorzystują dane wejściowe i generują dane wyjściowe. Dane wejściowe są podzielone na wartości wejściowe (obliczone podczas wykonywania), funkcje wejściowe (podawane statycznie, ponieważ w StableHLO funkcje nie są wartościami najwyższej klasy) oraz atrybuty wejściowe (również podawane statycznie). Rodzaj danych wejściowych i wyjściowych zużywanych oraz wytwarzanych przez operatora zależy od jego skrótu. Na przykład funkcja add op pobiera 2 wartości wejściowe i zwraca 1 wartość wyjściową. Natomiast operator select_and_scatter wymaga 3 wartości wejściowych, 2 funkcji wejściowych i 3 atrybutów wejściowych.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Funkcje wejściowe (nazywane też funkcjami anonimowymi) są bardzo podobne do funkcji nazwanych, z tym że: 1) nie mają identyfikatora (stąd nazwa „anonimowa”) i 2) nie deklarują typów danych wyjściowych (typy danych są wywnioskowane z opcji return w ramach funkcji).

Składnia funkcji wejściowych zawiera obecnie nieużywaną część (patrz Unused produkcja powyżej), która jest tam ze względu na zgodność z MLIR. W MLIR występuje bardziej ogólna koncepcja „regionów”, które mogą mieć wiele „bloków” operacji połączonych ze sobą za pomocą operacji przeskoku. Te bloki mają identyfikatory odpowiadające Unused produkcji, dzięki czemu można je odróżnić od siebie. StableHLO nie obsługuje wykonywania operacji przejść, więc odpowiadająca mu część składni MLIR nie jest używana (ale nadal jest).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Atrybuty wejściowe mają nazwę i wartość, która jest jedną z obsługiwanych stałych. Stanowią one podstawowy sposób określania metadanych statycznych elementów programu. Na przykład operator concatenate używa atrybutu dimension do określania wymiaru, wzdłuż którego wartości wejściowe są łączone. Podobnie operacja slice używa wielu atrybutów, takich jak start_indices i limit_indices, aby określić granice używane do wycinania wartości wejściowej.

Obecnie programy StableHLO działające w środowisku naturalnym mogą czasami zawierać atrybuty, które nie są opisane w tym dokumencie. W przyszłości planujemy uwzględnić te atrybuty w opcji StableHLO lub zabronić ich wykorzystywania w programach StableHLO. Oto lista tych atrybutów:

  • layout (#629).
  • mhlo.frontend_attributes(#628).
  • mhlo.sharding (#619).
  • output_operand_aliases(#740).
  • metadane lokalizacji (#594);
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

Podpis operacji składa się z typów wszystkich wartości wejściowych (lista typów po lewej stronie ->) oraz typów wszystkich wartości wyjściowych (lista typów po prawej stronie ->). Ściśle mówiąc, typy wejścia są zbędne, a typy wyjścia prawie zawsze są zbędne (ponieważ w przypadku większości operacji StableHLO typy wyjścia można wywnioskować z danych wejściowych). Mimo to op signature jest celowo częścią składni StableHLO, aby zapewnić zgodność z MLIR.

Poniżej znajduje się przykład operacji, której skrót to select_and_scatter. Przyjmuje 3 wartości wejściowe (%operand, %source i %init_value), 2 funkcje wejściowe i 3 atrybuty wejściowe (window_dimensions, window_strides i padding). Zwróć uwagę, że podpis operacji obejmuje tylko typy wartości wejściowych (ale nie typy funkcji i atrybutów wejściowych podane w tekście).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Stałe

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Stałe StableHLO mają literał i typ, które razem reprezentują wartość StableHLO. Typ jest zwykle częścią składni stałej, z wyjątkiem sytuacji, gdy jest jednoznaczny (np. stała logiczna ma jednoznacznie typ i1, podczas gdy stała całkowita może mieć kilka możliwych typów).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Stałe logiczne reprezentują wartości logiczne truefalse. Stałe logiczne mają typ i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Stałe liczbowe reprezentują wartości liczb całkowitych za pomocą ciągów tekstowych, które używają zapisu dziesiętnego lub szesnastkowego. Inne systemy liczbowe, np. binarny czy octal, nie są obsługiwane. Stałe liczbowe są objęte tymi ograniczeniami:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Stałe zmiennoprzecinkowe reprezentują wartości zmiennoprzecinkowe za pomocą ciągów znaków, które używają zapisu dziesiętnego lub wykładniczego. Dodatkowo przy użyciu notacji szesnastkowej można bezpośrednio określać bazowe bity w formacie zmiennoprzecinkowym odpowiedniego typu. Stałe zmiennoprzecinkowe są objęte tymi ograniczeniami:

  • (C1) Jeśli używany jest zapis w formacie innym niż szesnastkowy, is_wellformed(float_literal, float_type).
  • (C2) Jeśli używana jest notacja szesnastkowa, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Stałe zespolone reprezentują wartości zespolone za pomocą list części rzeczywistej (pojawia się jako pierwsza) i części urojonej (pojawia się jako druga). Na przykład (1.0, 0.0) : complex<f32> oznacza 1.0 + 0.0i, a (0.0, 1.0) : complex<f32> oznacza 0.0 + 1.0i. Kolejność, w jakiej te części są przechowywane w pamięci, jest zdefiniowana przez implementację. Stałe złożone mają te ograniczenia:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Stałe tensora reprezentują wartości tensora za pomocą zagnieżdżonych list określonych za pomocą notacji NumPy. Na przykład dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> reprezentuje wartość tensora z tym mapowaniem indeksów na elementy: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. Kolejność, w jakiej te elementy są przechowywane w pamięci, jest zdefiniowana przez implementację. Stałe tensora mają następujące ograniczenia:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), gdzie:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), gdzie:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • w przeciwnym razie: false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Zwartości skonwertowanego tensora reprezentują wartości skonwertowanego tensora za pomocą tej samej notacji co stałe tensora, a ich elementy są określone jako stałe ich typu magazynu. Poddane kwantyzacji stałe tensora mają następujące ograniczenia:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Literały łańcuchowe składają się z bajtów określonych za pomocą znaków ASCII i sekwencji ucieczki. Są one niezależne od kodowania, więc interpretacja tych bajtów jest zdefiniowana przez implementację. Literały łańcuchowe mają typ string.

Operacje

abs

Semantyka

Wykonuje elementarną operację abs na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • W przypadku liczb całkowitych ze znakiem: moduł liczby całkowitej.
  • Dla liczb zmiennoprzecinkowych: abs z IEEE-754.
  • W przypadku liczb zespolonych: moduł zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(abs, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora (C1-C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu signed integer lub zmiennoprzecinkowego albo tensor kwantyzowany na podstawie tensora (K1–C2)

Ograniczenia

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) zdefiniowano jako:
    • complex_element_type(element_type(operand)) jeżeli is_complex(operand).
    • baseline_element_type(operand) w innych przypadkach.

Przykłady

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

 Więcej przykładów

dodaj

Semantyka

Wykonuje dodawanie elementów dwóch tensorów lhsrhs i tworzy tensor result. W zależności od typu elementu:

  • W przypadku wartości logicznych: operator logiczny LUB.
  • W przypadku liczb całkowitych: dodawanie liczb całkowitych.
  • Dla liczb zmiennoprzecinkowych: addition z IEEE-754.
  • W przypadku liczb zespolonych: dodawanie zespolone.
  • W przypadku typów skwantowanych: dequantize_op_quantize(add, lhs, rhs, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub kwantyzowany tensor (C1-C6)
(I2) rhs tensor lub kwantyzowany tensor (C1-C5), (C7)

Wyniki

Nazwa Typ Ograniczenia
result tensor (tensor kwantowy) (C1-C7)

Ograniczenia

  • Jeśli operacja używa niespakowanych tensorów:
    • (C1) type(lhs) = type(rhs) = type(result).
  • Jeśli operacja używa kwantowanych tensorów:
    • (C2) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C3) storage_type(lhs) = storage_type(rhs) = storage_type(result).
    • (C4) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C5) (is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result).
    • (C6) Jeśli is_per_axis_quantized(lhs), to quantization_dimension(lhs) = quantization_dimension(result).
    • (C7) Jeśli is_per_axis_quantized(rhs), to quantization_dimension(rhs) = quantization_dimension(result).

Przykłady

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Więcej przykładów

after_all

Semantyka

Zapewnia, że operacje generujące inputs są wykonywane przed operacjami zależnymi od result. Wykonywanie tej operacji nie powoduje żadnych zmian. Służy ona tylko do ustanowienia zależności danych z poziomu result do inputs.

Dane wejściowe

Etykieta Nazwa Typ
(I1) inputs zmienna liczba token

Wyniki

Nazwa Typ
result token

Przykłady

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

 Więcej przykładów

all_gather

Semantyka

W każdej grupie procesów w siatce procesów StableHLO konkatenuje wartości tensorów operands z każdego procesu wzdłuż wymiaru all_gather_dim i tworzy tensory results.

Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:

  • cross_replica(replica_groups) jeśli channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) if channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

Następnie w każdym wierszu process_group:

  • operands...@receiver = [operand@sender for sender in process_group] za wszystkie receiver w process_group.
  • results...@process = concatenate(operands...@process, all_gather_dim) za wszystkie process w process_group.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operands zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C1), (C6)
(I2) all_gather_dim stała typu si64 (C1), (C6)
(I3) replica_groups 2-wymiarowa stała tensora typu si64 (C2-C4)
(I4) channel_id stała typu si64 (C5)
(I5) use_global_device_ids stała typu i1 (K5)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C6)

Ograniczenia

  • (C1) 0 <= all_gather_dim < rank(operands...).
  • (C2) is_unique(replica_groups).
  • (C3) Parametr size(replica_groups) jest zdefiniowany jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_replicas, jeśli używana jest właściwość cross_replica_and_partition.
    • num_processes, jeśli używana jest właściwość flattened_ids.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Jeśli use_global_device_ids = true, to channel_id > 0.
  • (C6) type(results...) = type(operands...) z wyjątkiem:
    • dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).

Przykłady

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]

Więcej przykładów

all_reduce

Semantyka

W ramach każdej grupy procesów w siatce procesów StableHLO stosuje funkcję redukcji computation do wartości tensorów operands z każdego procesu i generuje tensory results.

Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:

  • cross_replica(replica_groups) jeśli channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) if channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

Następnie w każdym wierszu process_group:

  • results...@process[result_index] = exec(schedule) dla niektórych drzew binarnych schedule gdzie:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule to drzewo binarne zdefiniowane przez implementację, którego traversal w kolejności to to_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operands zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C5), (C6)
(I2) replica_groups zmienna liczba stałych tensora jednowymiarowego typu si64 (K1–C3)
(I3) channel_id stała typu si64 (C4)
(I4) use_global_device_ids stała typu i1 (C4)
(I5) computation funkcja (C5)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C6-C7)

Ograniczenia

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) zdefiniowano jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_replicas, jeśli używana jest właściwość cross_replica_and_partition.
    • num_processes, jeśli używana jest właściwość flattened_ids.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Jeśli use_global_device_ids = true, to channel_id > 0.
  • (C5) computation ma typ (tensor<E>, tensor<E>) -> (tensor<E>), gdzie is_promotable(element_type(operand), E).
  • (C6) shape(results...) = shape(operands...).
  • (C7) element_type(results...) = E.

Przykłady

// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]

Więcej przykładów

all_to_all

Semantyka

all_to_all

W ramach każdej grupy procesów w siatce procesów StableHLO dzieli wartości tensorów operands wzdłuż split_dimension na części, rozprasza te części między procesami, konkatenuje rozproszone części wzdłuż concat_dimension i tworzy tensory results. Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:

  • cross_replica(replica_groups) jeżeli channel_id <= 0.
  • cross_partition(replica_groups) jeżeli channel_id > 0.

Następnie w każdym wierszu process_group:

  • split_parts...@sender = split(operands...@sender, split_count, split_dimension)dla wszystkich senderprocess_group.
  • scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group] gdzie receiver_index = process_group.index(receiver).
  • results...@process = concatenate(scattered_parts...@process, concat_dimension).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operands zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C1-C3), (C9)
(I2) split_dimension stała typu si64 (C1), (C2), (C9)
(I3) concat_dimension stała typu si64 (C3), (C9)
(I4) split_count stała typu si64 (C2), (C4), (C8), (C9)
(I5) replica_groups 2-wymiarowa stała tensora typu si64 (C5-C8)
(I6) channel_id stała typu si64

Wyniki

Nazwa Typ Ograniczenia
results liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor (C9)

Ograniczenia

  • (C1) 0 <= split_dimension < rank(operands...).
  • (C2) dim(operands..., split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operands...).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) jest zdefiniowany jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_partitions, jeśli używana jest właściwość cross_partition.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(results...) = type(operands...) oprócz tych, które split_dimension != concat_dimension:
    • dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.
    • dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.

Przykłady

// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
//                    [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
//                    [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
//                    [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
//                    [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
  // channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]

 Więcej przykładów

i

Semantyka

Wykonuje z punktu widzenia elementu ORAZ dwa tensory lhs i rhs, tworzy intensywność result. W zależności od typu elementu:

  • W przypadku wartości logicznych: operator logiczny OR.
  • W przypadku liczb całkowitych: bitowe ORAZ.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu logicznego lub całkowitego (C1)
(I2) rhs tensor typu logicznego lub całkowitego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu logicznego lub całkowitego (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

 Więcej przykładów

atan2

Semantyka

Wykonuje elementową operację atan2 na tensorach lhsrhs, tworząc tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: atan2 z IEEE-754.
  • W przypadku liczb zespolonych: complex atan2.
  • W przypadku typów skwantowanych: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)
(I2) rhs tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

 Więcej przykładów

batch_norm_grad

Semantyka

Oblicza gradienty kilku wejść batch_norm_training z backpropagatinggrad_output i tworzy tensory grad_operand, grad_scalegrad_offset. Bardziej formalnie ta operacja może zostać wyrażona jako dekompozycja istniejących operacji StableHLO z użyciem składni Pythona:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

W przypadku typów skwantowanych wykonuje działanie dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1-C3), (C5)
(I2) scale 1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora (C2), (C4), (C5)
(I3) mean 1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora (C2), (C4)
(I4) variance 1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora (C2), (C4)
(I5) grad_output tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C2), (C3)
(I6) epsilon stała typu f32
(I7) feature_index stała typu si64 (C1), (C5)

Wyniki

Nazwa Typ Ograniczenia
grad_operand tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C2), (C3)
grad_scale 1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora (C2), (C4)
grad_offset Jednowymiarowy tensor typu kwantyzowanego typu zmiennoprzecinkowego lub na tensor (C2), (C4)

Ograniczenia

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale i grad_offset mają te same wartości baseline_element_type.
  • (C3) operand, grad_outputgrad_operand mają ten sam kształt.
  • (C4) Znaczniki scale, mean, variance, grad_scalegrad_offset mają ten sam kształt.
  • (C5) size(scale) = dim(operand, feature_index).

Przykłady

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semantyka

Normalizuje tensor operand we wszystkich wymiarach z wyjątkiem wymiaru feature_index i tworzy tensor result. Bardziej formalnie tę operację można wyrazić jako dekompozycję istniejących operacji StableHLO za pomocą składni Pythona w ten sposób:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

W przypadku typów kwantowych wybiera dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub kwantyzowany tensor na poziomie procesora (C1-C7)
(I2) scale 1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora (C2), (C3)
(I3) offset 1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora (C2), (C4)
(I4) mean 1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora (K5)
(I5) variance 1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora (C2), (C6)
(I6) epsilon stała typu f32
(I7) feature_index stała typu si64 (C1), (C3-C6)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C2), (C7)

Ograniczenia

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance i result mają te same ustawienia baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semantyka

Oblicza średnią i wariancję we wszystkich wymiarach oprócz wymiaru feature_index oraz normalizuje tensor operand, tworząc tensory output, batch_meanbatch_var. Bardziej formalnie ta operacja może być wyrażona jako dekompozycja istniejących operacji StableHLO z użyciem składni Pythona:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

W przypadku typów skwantowanych wykonuje działanie dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)
(I2) scale Jednowymiarowy tensor kwantyzowany (zmiennoprzecinkowy lub na tensor) (C2), (C3)
(I3) offset 1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora (C2), (C4)
(I4) epsilon stała typu f32 (C1), (C3-C6)
(I5) feature_index stała typu si64 (C1), (C3-C6)

Wyniki

Nazwa Typ Ograniczenia
output tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C7)
batch_mean 1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora (C2), (C5)
batch_var 1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora (C2), (C6)

Ograniczenia

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_varoutput mają ten sam baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Przykłady

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semantyka

Wykonuje operację bitcast na tensorze operand i tworzy tensor result, w którym bity całego tensora operand są ponownie interpretowane przy użyciu typu tensora result.

Bardziej formalnie, jeśli E = element_type(operand), E' = element_type(result)R = rank(operand):

  • Jeśli num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Jeśli num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Jeśli num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

Funkcja bits zwraca reprezentację danej wartości w pamięci, a jej działanie jest definiowane przez implementację, ponieważ dokładna reprezentacja tensorów jest zdefiniowana przez implementację, a dokładna reprezentacja typów elementów również jest zdefiniowana w implementacji.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub kwantyzowany tensor (C1-C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub kwantyzowany tensor (K1–C2)

Ograniczenia

  • (C1) Przy założeniu, że E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result)R = rank(operand):
    • Jeśli num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Jeśli num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) dla wszystkich 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Jeśli num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) za wszystkie 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Jeśli is_complex(operand) or is_complex(result), to is_complex(operand) and is_complex(result).

Przykłady

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

 Więcej przykładów

broadcast_in_dim

Semantyka

Rozwija wymiary lub rangę wejściowego tensora przez powielanie danych w tensorze operand i tworzenie tensora result. Formalnie: result[result_index] = operand[operand_index] dla wszystkich d w ramach axes(operand):

  • operand_index[d] = 0 jeżeli dim(operand, d) = 1.
  • operand_index[d] = result_index[broadcast_dimensions[d]] w innych przypadkach.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub kwantyzowany tensor (C1-C2), (C5-C6)
(I2) broadcast_dimensions Jednowymiarowa stała tensora typu si64 (C2-C6)

Wyniki

Nazwa Typ Ograniczenia
result tensor (tensor kwantowy) (C1), (C3), (C5-C6)

Ograniczenia

  • (C1) element_type(result) otrzymuje:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand), z tym że quantization_dimension(operand), scales(operand)zero_points(operand) mogą się różnić od quantization_dimension(result), scales(result)zero_points(result) odpowiednio.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) W przypadku wszystkich daxes(operand):
    • dim(operand, d) = 1 lub
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jeśli is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jeśli dim(operand, quantization_dimension(operand)) = 1, to scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Przykłady

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 Więcej przykładów

etui

Semantyka

Generuje dane wyjściowe w wyniku wykonania dokładnie 1 funkcji z funkcji branches w zależności od wartości index. Bardziej formalnie: result = selected_branch()gdzie:

  • selected_branch = branches[index] jeżeli 0 <= index < size(branches).
  • selected_branch = branches[-1] w innych przypadkach.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) index Tensor 0-wymiarowy typu si32
(I2) branches zmienna liczba funkcji (C1-C4)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów (C4)

Ograniczenia

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Przykłady

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Więcej przykładów

Cbrt

Semantyka

Wykonuje elementarną operację pierwiastka sześciennego na tensorze operand i tworzy tensor result. W zależności od typu elementu wykonuje te działania:

  • W przypadku jednostek zmiennoprzecinkowych: rootn(x, 3) w standardzie IEEE-754.
  • Liczby zespolone: pierwiastek sześcienny zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(cbrt, operand, type(result))

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

 Więcej przykładów

ceil

Semantyka

Przeprowadza element po elemencie zaokrąglenie w górę tensora operand i generuje tensor result. Realizuje operację roundToIntegralTowardPositive według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(ceil, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Więcej przykładów

Cholesky

Semantyka

Oblicza rozkład Choleskiego dla zbioru macierzy.

W bardziej formalnym ujęciu dla wszystkich iindex_space(result) result[i0, ..., iR-3, :, :] jest rozkładem Cholesky'ego a[i0, ..., iR-3, :, :] w postaci dolnej macierzy trójkątnej (jeśli lower jest true) lub górnej macierzy trójkątnej (jeśli lower jest false). Wartości wyjściowe w trójkącie przeciwległym, tj. odpowiednio ścisły górny trójkąt lub odpowiednio ścisły trójkąt dolny, są definiowane na podstawie implementacji.

Jeśli istnieje element i, w którym macierz wejściowy nie jest macierzową o dodatnie dodatnim, działanie jest niezdefiniowane.

W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) a tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1-C3)
(I2) lower Stałe tensora 0-wymiarowego typu i1

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Przykłady

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

ograniczać (zakres)

Semantyka

Łączy każdy element tensora operand między wartością minimalną a maksymalną, tworząc tensor result. Bardziej formalnie: result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), gdzie min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(clamp, min, operand, max, type(result)).

Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) min tensor lub tensor zagregowany z tensorów (C1), (C3)
(I2) operand tensor lub tensor zagregowany z tensorów (C1-C4)
(I3) max tensor lub tensor zagregowany z tensorów (C2), (C3)

Wyniki

Nazwa Typ Ograniczenia
result kwantowy tensor lub tensor kwantowy (C4)

Ograniczenia

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Przykłady

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Więcej przykładów

collective_broadcast

Semantyka

W każdej grupie procesów w siatce procesów StableHLO wyślij wartość tensora operand z procesu źródłowego do procesów docelowych i utwórz tensor result.

Ta operacja dzieli siatkę procesów StableHLO na siatkę procesów process_groups, która jest zdefiniowana w ten sposób:

  • cross_replica(replica_groups) jeżeli channel_id <= 0.
  • cross_partition(replica_groups) jeżeli channel_id > 0.

Następnie result@process jest obliczana jako:

  • operand@process_groups[i, 0], jeśli istnieje i, który określa, że proces znajduje się w regionie process_groups[i].
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))w innym przypadku.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor zagregowany z tensorów (C3)
(I2) replica_groups zmienna liczba stałych tensora jednowymiarowego typu si64 (C1), (C2)
(I3) channel_id stała typu si64

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C3)

Ograniczenia

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, gdzie N jest zdefiniowany jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_partitions, jeśli używana jest właściwość cross_partition.
  • (C3) type(result) = type(operand).

Przykłady

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semantyka

W każdej grupie procesów w siatce procesów StableHLO wysyła wartość tensora operand z procesu źródłowego do procesu docelowego i tworzy tensor result.

Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:

  • cross_replica(source_target_pairs), jeśli channel_id <= 0.
  • cross_partition(source_target_pairs) jeżeli channel_id > 0.

Później wartość result@process jest określana przez:

  • operand@process_groups[i, 0], jeśli istnieje i takie, że process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand kwantowy tensor lub tensor kwantowy (K5)
(I2) source_target_pairs 2-wymiarowa stała tensora typu si64 (C1-C4)
(I3) channel_id stała typu si64

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1)

Ograniczenia

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, gdzie N jest zdefiniowany jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_partitions, jeśli używana jest właściwość cross_partition.
  • (C5) type(result) = type(operand).

Przykłady

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Więcej przykładów

porównaj

Semantyka

Porównuje elementy tensorów lhsrhs zgodnie z definicjami comparison_directioncompare_type, tworząc tensor result.

Wartości comparison_directioncompare_type mają następującą semantykę:

W przypadku typów elementów wartości logicznych i liczb całkowitych:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

W przypadku typów elementów zmiennoprzecinkowych z compare_type = FLOAT operator op implementuje te operacje IEEE-754:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

W przypadku typów elementów zmiennoprzecinkowych z wartością compare_type = TOTALORDER operator używa kombinacji operacji totalOrdercompareQuietEqual z normy IEEE-754.

W przypadku złożonych typów elementów przeprowadzane jest porównanie leksykograficzne par (real, imag) za pomocą podanych wartości comparison_directioncompare_type. Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych, gdy comparison_direction = GE, GT, LE lub LT (#560).

W przypadku typów poddanych kwantyzacji wykonuje działanie dequantize_compare(lhs, rhs, comparison_direction).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor zagregowany z tensorów (C1-C3)
(I2) rhs tensor lub tensor zagregowany z tensorów (C1-C2)
(I3) comparison_direction enum EQ, NE, GE, GT, LELT
(I4) compare_type enum FLOAT, TOTALORDER, SIGNEDUNSIGNED (C3)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu logicznego (K2)

Ograniczenia

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type jest zdefiniowany jako:
    • SIGNED jeżeli is_signed_integer(element_type(lhs)).
    • UNSIGNED jeżeli is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT lub TOTALORDER, jeśli is_float(element_type(lhs)).
    • FLOAT jeżeli is_complex(element_type(lhs)).

Przykłady

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

 Więcej przykładów

złożone

Semantyka

Przeprowadza konwersję element po elemencie na wartość zespoloną z pary wartości rzeczywistych i urojonych, lhsrhs, oraz generuje tensor result.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu f32 lub f64 (C1-C3)
(I2) rhs tensor typu f32 lub f64 (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zespolonego (C2), (C3)

Ograniczenia

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) ma typ complex<E>, gdzie E = element_type(lhs).

Przykłady

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

 Więcej przykładów

wieloskładnikowa

Semantyka

Zawiera operację złożoną z innych operacji StableHLO, przyjmując inputs i composite_attributes oraz zwraca results. Semantyka operacji jest implementowana przez atrybut decomposition. Opcję composite można zastąpić jej rozkładem bez zmiany semantyki programu. Jeśli wstawienie dekompozycji w kod źródłowy nie zapewnia tej samej semantyki op, użyj custom_call.

Pole version (domyślnie 0) służy do informowania o zmianie semantyki elementu złożonego.

Dane wejściowe

Etykieta Nazwa Typ
(I1) inputs zmienna liczba wartości
(I2) name stała typu string
(I3) composite_attributes słownik atrybutów
(I4) decomposition stała typu string
(I5) version stała typu si32

Wyniki

Nazwa Typ
results zmienna liczba wartości

Ograniczenia

  • (C1) is_namespaced_op_name(name)
  • (C2) is_defined_in_parent_scope(decomposition)
  • (C3) types(inputs...) == input_types(decomposition)
  • (C4) types(results...) == output_types(decomposition)

Przykłady

%results = "stablehlo.composite"(%input0, %input1) {
  name = "my_namespace.my_op",
  composite_attributes = {
    my_attribute = "my_value"
  },
  decomposition = @my_op,
  version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>

Więcej przykładów

konkatenacja

Semantyka

Łączy elementy inputs wzdłuż wymiaru dimension w takim samym porządku jak podane argumenty i tworzy tensor result. Bardziej formalnie: result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], gdzie:

  1. id = d0 + ... + dk-1 + kd.
  2. d jest równy dimension, a d0, ... to drozmiary inputs.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (K1–C6)
(I2) dimension stała typu si64 (C2), (C4), (C6)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C5–C6)

Ograniczenia

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)) z wyjątkiem dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) oprócz:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Przykłady

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

 Więcej przykładów

stała

Semantyka

Tworzy tensor output na podstawie stałej value.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) value stała (C1)

Wyniki

Nazwa Typ Ograniczenia
output tensor lub kwantyzowany tensor (C1)

Ograniczenia

  • (C1) type(value) = type(output).

Przykłady

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Więcej przykładów

dokonają konwersji

Semantyka

Przeprowadza konwersję elementów z jednego typu na inny w tensorze operand i tworzy tensor result.

W przypadku konwersji boolean-to-any-supported-type wartość false jest zamieniana na 0, a wartość true na 1. W przypadku konwersji typu any-supported-type-to-boolean wartość 0 jest konwertowana na false, a wartości inne niż zero – na true. Poniżej znajdziesz informacje o tym, jak to działa w przypadku typów złożonych.

W przypadku konwersji zawierających liczby całkowite na liczbę całkowitą, liczbę zmiennoprzecinkową lub zmiennoprzecinkową na liczbę zmiennoprzecinkową wartość źródłową może być dokładnie reprezentowana w typie miejsca docelowego. W przeciwnym razie zachowanie jest nieokreślone (#180).

W przypadku konwersji zawierających floating-point-to-integer część ułamkowa jest skracana. Jeśli skrócona wartość nie może być przedstawiona w przypadku typu miejsca docelowego, zachowanie jest nieokreślone (#180).

Konwersja z typu zespolonego na typ zespolony działa tak samo jak konwersja z typu zmiennoprzecinkowego na typ zmiennoprzecinkowy w przypadku konwersji części rzeczywistej i części urojonej.

W przypadku konwersji z typu złożonego na dowolny inny typz dowolnego innego typu na złożony wartość wyobrażona źródła jest ignorowana lub docelowa wartość wyobrażona jest ustawiana na 0. Konwersja części rzeczywistej następuje zgodnie z konwersją na liczby zmiennoprzecinkowe.

Zasadniczo ta operacja może wyrażać dekwantyzację (przekształcanie tensorów regularnych w tensory regularne), kwantyzację (konwersja z tensorów regularnych na tensory kwantyzowane) i rekwantyzację (przekształcanie między kwantyzowanymi procesorami), ale obecnie mamy dla nich specjalne operacje – uniform_dequantize dla pierwszego przypadku użycia i uniform_quantize w drugim i trzecim przypadku użycia. W przyszłości te 2 operacje mogą zostać połączone w convert (#1576).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor (C1)

Ograniczenia

  • (C1) shape(operand) = shape(result).

Przykłady

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

 Więcej przykładów

splotu

Semantyka

Oblicza iloczyny skalarne między oknami lhs i wycinkami rhs oraz generuje result. Na diagramie poniżej widać, jak elementy w elementach result są obliczane na podstawie elementów lhsrhs.

splotu

Bardziej formalnie rozważ następujące zmiany w danych wejściowych w postaci elementu lhs, aby można było tworzyć przedziały czasu lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Ta zmiana kadrowania wykorzystuje te funkcje pomocnicze:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1] gdzie j[d] = i[permutation[d]].

Jeśli feature_group_count = 1 i batch_group_count = 1, to dla wszystkich output_spatial_index w index_space(dim(result, output_spatial_dimensions...)):result[result_shape(:, output_spatial_index, :)] = dot_product gdzie:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Ta funkcja wydaje się nieużywana, więc planujemy ją usunąć w przyszłości (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Jeśli feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Jeśli batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

W przypadku typów kwantowych wybiera dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

W przypadku typów hybrydowych kwantyzowana wartość wylicza hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor zagregowany z tensorów (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34)
(I2) rhs tensor lub kwantyzowany tensor (C1), (C14–C16), (C25), (C27–C29), (C31–C34)
(I3) window_strides Jednowymiarowa stała tensora typu si64 (C2-C3), (C25)
(I4) padding 2-wymiarowa stała tensora typu si64 (C4), (C25)
(I5) lhs_dilation Jednowymiarowa stała tensora typu si64 (C5-C6), (C25)
(I6) rhs_dilation 1-wymiarowa stała tensora typu si64 (C7-C8), (C25)
(I7) window_reversal 1-wymiarowa stała tensora typu i1 (C9)
(I8) input_batch_dimension stała typu si64 (C10), (C13), (C25)
(I9) input_feature_dimension stała typu si64 (C11), (C13-C14)
(I10) input_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C12), (C13), (C25)
(I11) kernel_input_feature_dimension stała typu si64 (C14), (C18)
(I12) kernel_output_feature_dimension stała typu si64 (C15–C16), (C18), (C25), (C29)
(I13) kernel_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C17-C18), (C25)
(I14) output_batch_dimension stała typu si64 (C20), (C25)
(I15) output_feature_dimension stała typu si64 (C20), (C25), (C30)
(I16) output_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C19-C20), (C25)
(I17) feature_group_count stała typu si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count stała typu si64 (C10), (C15), (C22), (C23), (C25)
(I19) precision_config zmienna liczba typów enumeracji DEFAULT, HIGH i HIGHEST (C24)

Wyniki

Nazwa Typ Ograniczenia
result tensor (tensor kwantowy) (C25-C28), (C30), (C32-34)

Ograniczenia

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Zgodnie z input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Podany kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Zgodnie z output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) jest zdefiniowany jako:
    • dim(lhs, input_batch_dimension) / batch_group_count jeżeli result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jeżeli result_dim = output_feature_dimension.
    • num_windows w przeciwnym razie, gdzie:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jeśli operacja używa tensorów niekwantowych:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jeśli operacja używa kwantowanych tensorów:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Jeśli is_per_axis_quantized(rhs), to quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Jeśli is_per_axis_quantized(result), to quantization_dimension(result) = output_feature_dimension.
    • Jeśli is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Jeśli is_per_tensor_quantized(rhs), to is_per_tensor_quantized(result).
    • Jeśli !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Przykłady

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs: [
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]],
//        [[[1]], [[1]], [[1]]]
//       ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = array<i64: 4, 4>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  batch_group_count = 1 : i64,
  feature_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

 Więcej przykładów

cosinus

Semantyka

Wykonuje elementarną operację cosinusa na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: cos z IEEE-754.
  • W przypadku liczb zespolonych: cosinus zespolony.
  • W przypadku typów kwantowych: dequantize_op_quantize(cosine, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

 Więcej przykładów

count_leading_zeros

Semantyka

Wykonuje element po elemencie zliczanie liczby początkowych zer w tensorze operandi tworzy tensor result.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu liczba całkowita (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu liczba całkowita (C1)

Ograniczenia

  • (C1) type(operand) = type(result).

Przykłady

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Więcej przykładów

custom_call

Semantyka

Zawiera zdefiniowaną przez implementację operację call_target_name, która przyjmuje argumenty inputscalled_computations oraz zwraca results. Atrybuty has_side_effect, backend_config i api_version mogą służyć do udostępniania dodatkowych metadanych zdefiniowanych przez implementację.

Obecnie ta operacja zawiera dość nieuporządkowaną kolekcję metadanych, która odzwierciedla organiczną ewolucję jej odpowiednika w kompilatorze XLA. W przyszłości planujemy ujednolicić te metadane (#741).

Dane wejściowe

Etykieta Nazwa Typ
(I1) inputs zmienna liczba wartości
(I2) call_target_name stała typu string
(I3) has_side_effect stała typu i1
(I4) backend_config stała typu string lub słownik atrybutów
(I5) api_version stała typu si32
(I6) called_computations liczba zmiennoprzecinkowa typu string

Wyniki

Nazwa Typ
results liczba zmiennoprzecinkowa

Przykłady

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = {bar = 42 : i32},
  api_version = 4 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

dzielenie

Semantyka

Wykonuje element po elemencie dzielenie tensorów dzielnika lhs i dzielenia rhs oraz zwraca tensor result. W zależności od typu elementu:

  • W przypadku liczb całkowitych: dzielenie liczb całkowitych, które zwraca iloraz algebraiczny z pominięciem części ułamkowej.
  • Dla liczb zmiennoprzecinkowych: division z IEEE-754.
  • W przypadku liczb zespolonych: dzielenie zespolone.
  • W przypadku typów skwantowanych:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)
(I2) rhs tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Więcej przykładów

dot_general

Semantyka

Oblicza iloczyn skalarny między wycinkami lhs i rhs, uzyskując tensor result.

Bardziej formalnie: result[result_index] = dot_product, gdzie:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_indexgdzie size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) i size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

W przypadku typów skwantowanych wykonuje operację dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

W przypadku hybrydowych typów kwantyzacji wykonuje działanie hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs).

precision_config kontroluje kompromis między szybkością a dokładnością obliczeń na backendzie akceleratora. Może to być jedna z tych wartości (obecnie semantyka tych wartości jest niewystarczająco sprecyzowana, ale planujemy rozwiązać ten problem w bładze zgłoszenia 755):

  • DEFAULT: najszybsze obliczenia, ale najmniej dokładne przybliżenie do pierwotnego wyniku.
  • HIGH: wolniejsze obliczenia, ale dokładniejsze przybliżenie do pierwotnej wartości.
  • HIGHEST: najwolniejsze obliczenia, ale najbardziej dokładne przybliżenie do pierwotnej wartości.

DotAlgorithm definiuje główne właściwości algorytmu używanego do implementacji operacji kropki, która określa też dokładność. Jeśli pola atrybutów algorytmu są ustawione, precision_config musi mieć wartość DEFAULT. DotAlgorithms nie mają wartości domyślnej, ponieważ parametry domyślne są definiowane przez implementację. Dlatego wszystkie pola algorytmu kropki mogą być ustawione na None, aby określić pusty algorytm kropki, który zamiast tego użyje wartości precision_config.

Pola DotAlgorithm:

  • lhs_precision_typerhs_precision_type, dokładność, do której zaokrąglane są wartości po lewej i po prawej stronie operacji. Typy dokładności są niezależne od typów magazynowania danych wejściowych i wyjściowych.
  • accumulation_type dokładność użyta do skumulowania.
  • lhs_component_count, rhs_component_countnum_primitive_operations stosuje się, gdy używamy algorytmu, który rozkłada lewą lub prawą stronę na kilka komponentów i wykonuje na tych wartościach wiele „prostych” operacji dot – zwykle w celu emulacji większej precyzji (np. Korzystanie z typu danych bfloat16 do obliczeń o większej precyzji: bf16_6x tf32_3x itp.). W przypadku algorytmów bez dekompozycji te wartości powinny wynosić 1.
  • allow_imprecise_accumulation, aby określić, czy akumulacja z mniejszą dokładnością jest dozwolona w przypadku niektórych kroków (np. CUBLASLT_MATMUL_DESC_FAST_ACCUM).

Przykładowe atrybuty DotAlgorithm:

// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
 rhs_precision_type = tf32,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = false}


// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
 rhs_precision_type = bf16,
 accumulation_type = f32,
 lhs_component_count = 3,
 rhs_component_count = 3,
 num_primitive_operations = 6,
 allow_imprecise_accumulation = false}


// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
 rhs_precision_type = f8e5m2,
 accumulation_type = f32,
 lhs_component_count = 1,
 rhs_component_count = 1,
 num_primitive_operations = 1,
 allow_imprecise_accumulation = true}

To od implementacji zależy, które kombinacje są obsługiwane. Ogólnie nie ma gwarancji, że każdy algorytm jest obsługiwany przez każdy typ akceleratora przez użytkownika StableHLO. Jeśli dany algorytm nie jest obsługiwany, należy zgłosić błąd zamiast korzystać z alternatywnego rozwiązania. Weryfikacja StableHLO zapewni najlepszą weryfikację, zapobiegając działaniu algorytmów, które nie są obsługiwane na żadnym sprzęcie.

Niektóre obsługiwane wartości algorytmu znajdziesz w sekcji xla_data.proto > Algorithm. zgłoszenie #2483 zawiera plan stworzenia scentralizowanego dokumentu na temat obsługiwanych algorytmów na zapleczu.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor zagregowany z tensorów (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20)
(I2) rhs tensor lub kwantyzowany tensor (C7-C10), (C12-C20)
(I3) lhs_batching_dimensions 1-wymiarowa stała tensora typu si64 (C1), (C3), (C5), (C9), (C12)
(I4) rhs_batching_dimensions 1-wymiarowa stała tensora typu si64 (C1), (C4), (C7), (C9)
(I5) lhs_contracting_dimensions 1-wymiarowa stała tensora typu si64 (C2), (C3), (C6), (C10)
(I6) rhs_contracting_dimensions 1-wymiarowa stała tensora typu si64 (C2), (C4), (C8), (C10), (C16)
(I7) precision_config zmienna liczba typów enumeracji DEFAULT, HIGH i HIGHEST (C11), (C21)
(I8) lhs_precision_type FloatType lub TensorFloat32 (C21)
(I9) rhs_precision_type FloatType lub TensorFloat32 (C21)
(I10) accumulation_type FloatType lub TensorFloat32 (C21)
(I11) lhs_component_count stała typu si32 (C21), (C22)
(I12) rhs_component_count stała typu si32 (C21), (C23)
(I13) num_primitive_operations stała typu si32 (C21), (C24)
(I14) allow_imprecise_accumulation stała typu bool (C21)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub kwantyzowany tensor (C12), (C14), (C18–C20)

Ograniczenia

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Jeśli operacja używa niespakowanych tensorów:
    • (C13) element_type(lhs) = element_type(rhs).
  • Jeśli operacja używa kwantowanych tensorów:
    • (C14) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C15) zero_points(rhs) = 0.
    • (C16) Jeśli is_per_axis_quantized(rhs), to quantization_dimension(rhs) nie jest w rhs_contracting_dimensions.
    • Jeśli is_quantized(lhs):
    • (C17) storage_type(lhs) = storage_type(rhs).
    • (C18) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C19) Jeśli is_per_tensor_quantized(rhs), to is_per_tensor_quantized(result).
    • Jeśli !is_quantized(lhs):
    • (C20) element_type(lhs) = expressed_type(rhs) = element_type(result).
  • Jeśli !is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):
    • (C21) precision_config... = DEFAULT.
    • (C22) 0 < lhs_component_count.
    • (C23) 0 < rhs_component_count.
    • (C24) 0 < num_primitive_operations.

Przykłady

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
  algorithm = #stablehlo.dot_algorithm<
    lhs_precision_type = tf32,
    rhs_precision_type = tf32,
    accumulation_type = f32,
    lhs_component_count = 1,
    rhs_component_count = 1,
    num_primitive_operations = 1,
    allow_imprecise_accumulation = false
  >
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

 Więcej przykładów

dynamic_broadcast_in_dim

Semantyka

Ta operacja jest pod względem funkcjonalnym identyczna z operacją broadcast_in_dim, ale kształt wyniku jest określany dynamicznie za pomocą parametru output_dimensions.

Operacja akceptuje też opcjonalne atrybuty known_expanding_dimensions i known_nonexpanding_dimensions, które służą do wyrażania statycznej wiedzy o zachowaniu rozszerzania wymiarów. Jeśli nie są określone, przyjmuje się, że wszystkie wymiary mogą się rozszerzać.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor (tensor kwantowy) (C1–C2), (C5–C6), (C9)
(I2) output_dimensions Jednowymiarowy tensor typu liczby całkowitej (C7)
(I3) broadcast_dimensions 1-wymiarowy tensor stały typu całkowitego (C2-C6)
(I4) known_expanding_dimensions Jednowymiarowy tensor stały typu liczby całkowitej (C8-C9)
(I5) known_nonexpanding_dimensions 1-wymiarowy tensor stały typu całkowitego (C8–C9)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub kwantyzowany tensor (C1), (C3), (C5-C7)

Ograniczenia

  • (C1) element_type(result) otrzymuje:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand), z tym że quantization_dimension(operand), scales(operand)zero_points(operand) mogą się różnić od quantization_dimension(result), scales(result)zero_points(result) odpowiednio.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) W przypadku wszystkich daxes(operand):
    • dim(operand, d) = 1 lub
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Jeśli is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Jeśli dim(operand, quantization_dimension(operand)) = 1, to scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
  • (C7) size(output_dimensions) = rank(result).
  • (C8) is_unique(known_expanding_dimensions + known_nonexpanding_dimensions).
  • (C9) 0 <= known_expanding_dimensions < rank(operand).
  • (C10) 0 <= known_nonexpanding_dimensions < rank(operand).

Przykłady

// %operand: [
//            [1, 2, 3]
//           ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
  broadcast_dimensions = array<i64: 2, 1>,
  known_expanding_dimensions = array<i64: 0>,
  known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

 Więcej przykładów

dynamic_conv

Semantyka

Ta operacja jest pod względem funkcjonalnym identyczna z operacją convolution, ale wypełnienie jest podawane dynamicznie za pomocą padding.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor lub tensor zagregowany z tensorów (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33)
(I2) rhs tensor lub kwantyzowany tensor (C1), (C14-C16), (C26-C28), (C30-C33)
(I3) padding 2-wymiarowy tensor typu całkowitego (C4)
(I4) window_strides Jednowymiarowa stała tensora typu si64 (C2-C3)
(I5) lhs_dilation Jednowymiarowa stała tensora typu si64 (C5-C6)
(I6) rhs_dilation 1-wymiarowa stała tensora typu si64 (C7-C8)
(I7) window_reversal 1-wymiarowa stała tensora typu i1 (C9)
(I8) input_batch_dimension stała typu si64 (C10), (C13)
(I9) input_feature_dimension stała typu si64 (C11), (C13-C14)
(I10) input_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C12), (C13)
(I11) kernel_input_feature_dimension stała typu si64 (C14), (C18)
(I12) kernel_output_feature_dimension stała typu si64 (C15-C16), (C18), (C28)
(I13) kernel_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C17-C18)
(I14) output_batch_dimension stała typu si64 (C20)
(I15) output_feature_dimension stała typu si64 (C20), (C29)
(I16) output_spatial_dimensions 1-wymiarowa stała tensora typu si64 (C19-C20)
(I17) feature_group_count stała typu si64 (C11), (C14), (C16), (C21), (C23)
(I18) batch_group_count stała typu si64 (C10), (C15), (C22), (C23)
(I19) precision_config zmienna liczba typów enumeracji DEFAULT, HIGH i HIGHEST (C24)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub kwantyzowany tensor (C25-C27), (C29), (C31-C33)

Ograniczenia

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Zgodnie z input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Podany kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Zgodnie z output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) jest zdefiniowany jako:
    • dim(lhs, input_batch_dimension) / batch_group_count jeżeli result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) jeżeli result_dim = output_feature_dimension.
    • num_windows w przeciwnym razie, gdzie:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Jeśli operacja używa tensorów niekwantowych:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Jeśli operacja używa kwantowanych tensorów:
    • (C28) is_quantized(lhs) = is_quantized(result) and is_quantized(rhs).
    • (C29) Jeśli is_per_axis_quantized(rhs), to quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C30) Jeśli is_per_axis_quantized(result), to quantization_dimension(result) = output_feature_dimension.
    • Jeśli is_quantized(lhs):
    • (C31) storage_type(lhs) = storage_type(rhs).
    • (C32) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C33) Jeśli is_per_tensor_quantized(rhs), to is_per_tensor_quantized(result).
    • Jeśli !is_quantized(lhs):
    • (C34) element_type(lhs) = expressed_type(rhs) = element_type(result).

Przykłady

// %lhs: [[
//        [[1], [2], [5], [6]],
//        [[3], [4], [7], [8]],
//        [[10], [11], [14], [15]],
//        [[12], [13], [16], [17]]
//      ]]
//
// %rhs: [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
// %padding: [[1, 1],
//            [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
  window_strides = array<i64: 4, 4>,
  lhs_dilation = array<i64: 2, 2>,
  rhs_dilation = array<i64: 1, 1>,
  window_reversal = array<i1: false, false>,
  dimension_numbers = #stablehlo.conv<raw
    input_batch_dimension = 0,
    input_feature_dimension = 3,
    input_spatial_dimensions = [0, 1],
    kernel_input_feature_dimension = 2,
    kernel_output_feature_dimension = 3,
    kernel_spatial_dimensions = [0, 1],
    output_batch_dimension = 0,
    output_feature_dimension = 3,
    output_spatial_dimensions = [1, 2]
  >,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
//            [[1], [5]],
//            [[10], [14]]
//          ]]

Więcej przykładów

dynamic_gather

Semantyka

Ta operacja jest pod względem funkcjonalnym identyczna z gather op, z tym że slice_sizes jest dynamicznie określany jako wartość.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor zagregowany z tensorów (C1), (C7), (C10-C12), (C14)
(I2) start_indices tensor typu liczba całkowita (C2), (C3), (C13)
(I3) slice_sizes 1-wymiarowy tensor typu całkowitego (C8), (C11–C13)
(I4) offset_dims 1-wymiarowa stała tensora typu si64 (C1), (C4–C5), (C13)
(I5) collapsed_slice_dims 1-wymiarowa stała tensora typu si64 (C1), (C6-C8), (C13)
(I6) start_index_map 1-wymiarowa stała tensora typu si64 (C3), (C9), (C10)
(I7) index_vector_dim stała typu si64 (C2), (C3), (C13)
(I8) indices_are_sorted stała typu i1

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C5), (C13–C14)

Ograniczenia

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) gdzie:
    • batch_dim_sizes = shape(start_indices), z tym że nie uwzględnia wymiaru start_indices odpowiadającego wymiarowi index_vector_dim.
    • offset_dim_sizes = shape(slice_sizes), z tym że nie uwzględnia wymiarów slice_sizes odpowiadających wymiarowi collapsed_slice_dims.
    • Funkcja combine umieszcza batch_dim_sizes na osi odpowiadającej batch_dims, a funkcja offset_dim_sizes – na osi odpowiadającej offset_dims.
  • (C14) element_type(operand) = element_type(result).

Przykłady

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Więcej przykładów

dynamic_iota

Semantyka

Ta operacja jest pod względem funkcjonalnym identyczna z operacją iota, ale kształt wyniku jest dynamicznie określany za pomocą parametru output_shape.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) output_shape 1-wymiarowy tensor typu całkowitego (C1), (C2)
(I2) iota_dimension si64 (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C2)

Ograniczenia

  • (C1) 0 <= iota_dimension < size(output_shape).
  • (C2) rank(result) = size(output_shape).

Przykłady

%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
  iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

 Więcej przykładów

dynamic_pad

Semantyka

Ta operacja jest funkcjonalnie identyczna z pad, ale z wartościami edge_padding_low, edge_padding_highinterior_padding określonymi dynamicznie.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand kwantowy tensor lub tensor kwantowy (C1), (C2), (C4)
(I2) padding_value Tensor 0-wymiarowy lub kwantowy tensor na tensor (C1)
(I3) edge_padding_low Jednowymiarowy tensor typu liczby całkowitej (C1), (C4)
(I4) edge_padding_high 1-wymiarowy tensor typu całkowitego (C1), (C4)
(I5) interior_padding Jednowymiarowy tensor typu liczby całkowitej (C2-C4)

Wyniki

Nazwa Typ Ograniczenia
result kwantowy tensor lub tensor kwantowy (C3-C6)

Ograniczenia

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Przykłady

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
  %edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Więcej przykładów

dynamic_reshape

Semantyka

Ta operacja jest pod względem funkcjonalnym identyczna z operacją reshape, ale kształt wyniku jest dynamicznie określany za pomocą argumentu output_shape.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub kwantyzowany tensor (K1–C3)
(I2) output_shape 1-wymiarowy tensor typu całkowitego (C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub kwantyzowany tensor (C1–C4)

Ograniczenia

  • (C1) element_type(result) otrzymuje:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand) – oprócz tych quantization_dimension(operand) i quantization_dimension(result) mogą się różnić.
  • (C2) size(operand) = size(result).
  • (C3) Jeśli is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
  • (C4) size(output_shape) = rank(result).

Przykłady

// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]

 Więcej przykładów

dynamic_slice

Semantyka

Wyodrębnia wycinek z operand, używając dynamicznie obliczanych indeksów początkowych i tworzy tensor result. start_indices zawiera indeksy początkowe przekroju dla każdego wymiaru, który może zostać dostosowany, a slice_sizes zawiera rozmiary przekroju dla każdego wymiaru. Bardziej oficjalnie: result[result_index] = operand[operand_index], gdzie:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand kwantowy tensor lub tensor kwantowy (C1), (C2), (C4)
(I2) start_indices zmienna liczba 0-wymiarowych tensorów typu całkowitego (C2), (C3)
(I3) slice_sizes Jednowymiarowa stała tensora typu si64 (C2), (C4), (C5)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1), (C5)

Ograniczenia

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Przykłady

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

 Więcej przykładów

dynamic_update_slice

Semantyka

Tworzy tensor result, który jest równy tensorowi operand, ale wycinek zaczynający się od start_indices jest aktualizowany o wartości w update. Bardziej formalnie result[result_index] jest zdefiniowana jako:

  • update[update_index] jeśli 0 <= update_index < shape(update) gdzie:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • W przeciwnym razie: operand[result_index].

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor zagregowany z tensorów (C1-C4), (C6)
(I2) update tensor lub tensor zagregowany z tensorów (C2), (C3), (C6)
(I3) start_indices zmienna liczba 0-wymiarowych tensorów typu całkowitego (C4), (C5)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1)

Ograniczenia

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Przykłady

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Więcej przykładów

wykładniczo

Semantyka

Wykonuje elementarną operację wykładniczą na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: exp z IEEE-754.
  • Liczby zespolone: wykładnik zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(exponential, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

 Więcej przykładów

exponential_minus_one

Semantyka

Wykonuje elementarną operację wykładniczą minus 1 na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: expm1 z IEEE-754.
  • W przypadku liczb zespolonych: zespolona wykładnicza wartość minus 1.
  • W przypadku typów skwantowanych: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

 Więcej przykładów

FFT

Semantyka

Wykonuje bezpośrednie i odwrotne transformacje Fouriera w przypadku rzeczywistych i zespolonych wejść/wyjść.

fft_type może mieć jedną z tych wartości:

  • FFT: FFT kompleksowe w przód.
  • IFFT: odwrotna transformacja FFT z kompleksowej na złożoną.
  • RFFT: przesuwanie widoku w kierunku rzeczywistym do złożonego.
  • IRFFT: odwrotna transformacja Fouriera z realnego na zespolony (czyli przyjmuje zespolony, zwraca rzeczywisty).

Bardziej formalnie, jeśli dana funkcja fft przyjmuje jako dane wejściowe 1-wymiarowe tensory złożonych typów, to zwraca 1-wymiarowe tensory tego samego typu co dane wyjściowe i wylicza dyskretną transformację Fouriera:

W przypadku fft_type = FFT wartość result jest zdefiniowana jako końcowy wynik serii obliczeń L, gdzie L = size(fft_length). Na przykład w przypadku L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Ponadto funkcja ifft, która ma ten sam podpis typu i oblicza odwrotność fft:

W przypadku fft_type = IFFT wartość result jest zdefiniowana jako odwrotna wartość obliczeń dla fft_type = FFT. Na przykład w przypadku L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Funkcja rfft przyjmuje 1-wymiarowe tensory typu zmiennoprzecinkowego i tworzy 1-wymiarowe tensory typu zespolonego o tej samej semantyce zmiennoprzecinkowej. Działa ona w ten sposób:

  • rfft(real_operand) = truncated_result gdzie
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(gdy dyskretna transformacja Fouriera jest obliczana dla rzeczywistych operandów, pierwsze N/2 + 1 elementy wyniku jednoznacznie określają pozostałą część wyniku, dlatego wynik rfft jest obcinany, aby uniknąć obliczania zbędących elementów).

W przypadku funkcji fft_type = RFFT wartość result jest definiowana jako końcowy wynik serii obliczeń L, gdzie L = size(fft_length). Na przykład w przypadku L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Na koniec, jeśli dana funkcja irfft ma tę samą deklarację typu i oblicza odwrotność funkcji rfft:

W przypadku fft_type = IRFFT wartość result jest zdefiniowana jako odwrotna wartość obliczeń dla fft_type = RFFT. Na przykład w przypadku L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego (C1), (C2), (C4), (C5)
(I2) fft_type enum FFT, IFFT, RFFTIRFFT (C2), (C5)
(I3) fft_length 1-wymiarowa stała tensora typu si64 (C1), (C3), (C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego (C2), (C4), (C5)

Ograniczenia

  • (C1) size(fft_length) <= rank(operand).
  • (C2) Związek między typami elementów operandresult jest różny:
    • Jeśli fft_type = FFT, element_type(operand) i element_type(result) mają ten sam typ złożony.
    • Jeśli fft_type = IFFT, element_type(operand) i element_type(result) mają ten sam typ złożony.
    • Jeśli fft_type = RFFT, element_type(operand) jest typem zmiennoprzecinkowym, a element_type(result) jest złożonym typem o tej samej semantyce zmiennoprzecinkowej.
    • Jeśli fft_type = IRFFT, element_type(operand) jest typem złożonym, a element_type(result) jest typem zmiennoprzecinkowym o tej samej semantyce zmiennoprzecinkowej.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Jeśli wśród elementów operandresult znajduje się tensor real typu zmiennoprzecinkowego, to shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) z wyjątkiem:
    • Jeśli fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Jeśli fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Przykłady

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

piętro

Semantyka

Przeprowadza element po elemencie zaokrąglenie w dół tensora operand i generuje tensor result. Realizuje operację roundToIntegralTowardNegative według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(floor, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Więcej przykładów

zbierać

Semantyka

Gromadzi wycinki z tensora operand z przesunięć określonych w start_indices i tworzy tensor result.

Ten diagram pokazuje na konkretnym przykładzie, jak elementy w result są mapowane na elementy w operand. Diagram wybiera kilka przykładowych result indeksów i szczegółowo wyjaśnia, którym indeksom operand odpowiadają indeksy.

zbierać

Bardziej formalnie: result[result_index] = operand[operand_index], gdzie:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index jest zdefiniowany jako:
    • start_indices[bi0, ..., :, ..., biN], gdzie bi to poszczególne elementy w batch_index, a element : jest wstawiany w indeksie index_vector_dim, jeśli index_vector_dim < rank(start_indices).
    • [start_indices[batch_index]] w innych przypadkach.
  • W przypadku d_operand w axes(operand):
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) if d_operand = start_index_map[d_start].
    • W przeciwnym razie: full_start_index[d_operand] = 0.
  • W przypadku d_operand w axes(operand):
    • full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)] jeśli d_operand = operand_batching_dims[i_batching] i d_start = start_indices_batching_dims[i_batching].
    • full_batching_index[d_operand] = 0 w innych przypadkach.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], gdzie oi to poszczególne elementy w tablicy offset_index, a element 0 jest wstawiany na pozycjach z tablic collapsed_slice_dimsoperand_batching_dims.
  • operand_index = full_start_index + full_batching_index + full_offset_index.

Jeśli indices_are_sorted ma wartość true, implementacja może zakładać, że dane start_indices są posortowane według argumentu start_index_map. W przeciwnym razie działanie jest niezdefiniowane. Bardziej formalnie, w przypadku wszystkich i1 < i2indices(result):full_start_index(i1) <= full_start_index(i2).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor zagregowany z tensorów (C1), (C8), (C11), (C17), (C19-C21), (C23)
(I2) start_indices tensor typu liczba całkowita (C2–C3), (C14), (C17), (C22)
(I3) offset_dims 1-wymiarowa stała tensora typu si64 (C1), (C4-C5), (C22)
(I4) collapsed_slice_dims 1-wymiarowa stała tensora typu si64 (C1), (C6-C9), (C22)
(I5) operand_batching_dims 1-wymiarowa stała tensora typu si64 (C1), (C6), (C10-C12), (C16-C18), (C22)
(I6) start_indices_batching_dims 1-wymiarowa stała tensora typu si64 (C13-C17)
(I7) start_index_map 1-wymiarowa stała tensora typu si64 (C3), (C18-C19)
(I8) index_vector_dim stała typu si64 (C2-C3), (C15), (C22)
(I9) slice_sizes 1-wymiarowa stała tensora typu si64 (C9), (C12), (C20–C22)
(I10) indices_are_sorted stała typu i1

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C5), (C22-C23)

Ograniczenia

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
  • (C7) is_sorted(collapsed_slice_dims).
  • (C8) 0 <= collapsed_slice_dims < rank(operand).
  • (C9) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C10) is_sorted(operand_batching_dims).
  • (C11) 0 <= operand_batching_dims < rank(operand).
  • (C12) slice_sizes[operand_batching_dims...] <= 1.
  • (C13) is_unique(start_indices_batching_dims).
  • (C14) 0 <= start_indices_batching_dims < rank(start_indices).
  • (C15) index_vector_dim not in start_indices_batching_dims.
  • (C16) size(operand_batching_dims) == size(start_indices_batching_dims).
  • (C17) dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...).
  • (C18) is_unique(concatenate(start_index_map, operand_batching_dims)).
  • (C19) 0 <= start_index_map < rank(operand).
  • (C20) size(slice_sizes) = rank(operand).
  • (C21) 0 <= slice_sizes <= shape(operand).
  • (C22) shape(result) = combine(batch_dim_sizes, offset_dim_sizes) gdzie:
    • batch_dim_sizes = shape(start_indices), z tym że nie uwzględnia wymiaru start_indices odpowiadającego wymiarowi index_vector_dim.
    • offset_dim_sizes = slice_sizes, z tą różnicą, że rozmiary wymiarów w polu slice_sizes odpowiadające wartościom collapsed_slice_dims i operand_batching_dims nie są uwzględniane.
    • Funkcja combine umieszcza batch_dim_sizes na osiach odpowiadających wartości batch_dims i offset_dim_sizes na osiach odpowiadających wartości offset_dims.
  • (C23) element_type(operand) = element_type(result).

Przykłady

// %operand: [
//            [
//             [[1, 2], [3, 4], [5, 6], [7, 8]],
//             [[9, 10],[11, 12], [13, 14], [15, 16]],
//             [[17, 18], [19, 20], [21, 22], [23, 24]]
//            ],
//            [
//             [[25, 26], [27, 28], [29, 30], [31, 32]],
//             [[33, 34], [35, 36], [37, 38], [39, 40]],
//             [[41, 42], [43, 44], [45, 46], [47, 48]]
//            ]
//           ]
// %start_indices: [
//                  [
//                   [[0, 0], [1, 0], [2, 1]],
//                   [[0, 1], [1, 1], [0, 9]]
//                  ],
//                  [
//                   [[0, 0], [2, 1], [2, 2]],
//                   [[1, 2], [0, 1], [1, 0]]
//                  ]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [3, 4],
    collapsed_slice_dims = [1],
    operand_batching_dims = [0],
    start_indices_batching_dims = [1],
    start_index_map = [2, 1],
    index_vector_dim = 3>,
  slice_sizes = array<i64: 1, 1, 2, 2>,
  indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[3, 4], [5, 6]],
//             [[13, 14], [15, 16]]
//            ],
//            [
//             [[33, 34], [35, 36]],
//             [[35, 36], [37, 38]],
//             [[41, 42], [43, 44]]
//            ]
//           ],
//           [
//            [
//             [[1, 2], [3, 4]],
//             [[13, 14], [15, 16]],
//             [[21, 22], [23, 24]]
//            ],
//            [
//             [[43, 44], [45, 46]],
//             [[33, 34], [35, 36]],
//             [[27, 28], [29, 30]]
//            ]
//           ]
//          ]

Więcej przykładów

get_dimension_size

Semantyka

Zwraca rozmiar danego dimension w ramach operand. Bardziej formalnie: result = dim(operand, dimension). Semantyka dotyczy tylko komponentu kształtu danego typu. Typ elementu może być dowolny.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub kwantyzowany tensor (C1)
(I2) dimension stała typu si64 (C1)

Wyniki

Nazwa Typ
result Tensor 0-wymiarowy typu si32

Ograniczenia

  • (C1) 0 <= dimension < rank(operand).

Przykłady

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

 Więcej przykładów

get_tuple_element

Semantyka

Wyodrębnia element w pozycji index krotki operand i tworzy result. Więcej formalnie: result = operand[index].

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tablice (C1), (C2)
(I2) index stała typu si32 (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result dowolny obsługiwany typ (C2)

Ograniczenia

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Przykłady

// %operand: ([1.0, 2.0], (3))
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

 Więcej przykładów

jeśli

Semantyka

Wynikiem jest wynik wykonania dokładnie jednej funkcji z true_branch lub false_branch w zależności od wartości elementu pred. W bardziej formalnym ujęciu: result = pred ? true_branch() : false_branch().

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) pred 0-wymiarowy tensor typu i1
(I2) true_branch funkcja (C1-C3)
(I3) false_branch funkcja (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów (C3)

Ograniczenia

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Przykłady

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Więcej przykładów

imag

Semantyka

Wyodrębnia część urojona z elementów tensora operand i tworzy tensor result. W formalnej postaci dla każdego elementu x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego (C1), (C2)

Ograniczenia

  • (C1) shape(result) = shape(operand).
  • (C2) Parametr element_type(result) jest zdefiniowany jako:
    • complex_element_type(element_type(operand)) jeżeli is_complex(operand).
    • W przeciwnym razie: element_type(operand).

Przykłady

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Więcej przykładów

infeed

Semantyka

Odczytuje dane z infeedu i generuje results.

Semantyka elementu infeed_config jest zdefiniowana przez implementację.

results składa się z wartości ładunku, które występują na początku, oraz tokena, który występuje na końcu. W przyszłości planujemy podzielić ładunek i token na 2 osobne dane wyjściowe, aby zwiększyć przejrzystość (#670).

Dane wejściowe

Etykieta Nazwa Typ
(I1) token token
(I2) infeed_config stała typu string

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów (C1-C3)

Ograniczenia

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) lub is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Przykłady

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

 Więcej przykładów

iota

Semantyka

Wypełnia tensor output wartościami w rosnącej kolejności, zaczynając od 0 w wymiarze iota_dimension. Bardziej formalnie,

output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) iota_dimension si64 (C1)

Wyniki

Nazwa Typ Ograniczenia
output tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)

Ograniczenia

  • (C1) 0 <= iota_dimension < rank(output).

Przykłady

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Więcej przykładów

is_finite

Semantyka

Sprawdza, czy wartość w x jest skończona (tzn.nie jest skończona (tzn. nie jest liczbą +Inf, -Inf ani NaN) i generuje tensor y. Realizuje operację isFinitez specyfikacji IEEE-754. W przypadku typów kwantowych wynik to zawsze true.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) x tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
y tensor typu logicznego (C1)

Ograniczenia

  • (C1) shape(x) = shape(y).

Przykłady

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

 Więcej przykładów

log

Semantyka

Wykonuje operację logarytmiczną z uwzględnieniem elementów na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: log z IEEE-754.
  • W przypadku liczb zespolonych: logarytm zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(log, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

 Więcej przykładów

log_plus_one

Semantyka

Wykonuje logarytm z elementami oraz 1 operację na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: logp1 z IEEE-754.
  • W przypadku liczb zespolonych: logarytm zespolony plus 1.
  • W przypadku typów skwantowanych: dequantize_op_quantize(log_plus_one, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

 Więcej przykładów

logistyczna

Semantyka

Wykonuje operacje logistyczne związane z elementami na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: division(1, addition(1, exp(-x))) z IEEE-754.
  • W przypadku liczb zespolonych: złożona logistyczna.
  • W przypadku typów skwantowanych: dequantize_op_quantize(logistic, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

 Więcej przykładów

mapa

Semantyka

Stosuje funkcję mapy computation do inputs na dimensions i tworzy tensor result.

Więcej formalnie: result[result_index] = computation(inputs...[result_index]).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C1–C4)
(I2) dimensions 1-wymiarowa stała tensora typu si64 (K3)
(I3) computation funkcja (K4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1), (C4)

Ograniczenia

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation ma typ (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, gdzie Ei = element_type(inputs[i]) i E' = element_type(result).

Przykłady

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

 Więcej przykładów

maksimum

Semantyka

Wykonuje operację elementarnego maksimum na tensorach lhsrhs oraz zwraca tensor result. W zależności od typu elementu wykonuje te działania:

  • W przypadku wartości logicznych: logiczny LUB.
  • W przypadku liczb całkowitych: maksymalna liczba całkowita.
  • Dla liczb zmiennoprzecinkowych: maximum z IEEE-754.
  • W przypadku liczb zespolonych: maksymalna wartość leksykograficzna pary (real, imaginary). Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560).
  • W przypadku typów skwantowanych:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs kwantowy tensor lub tensor kwantowy (C1)
(I2) rhs tensor lub tensor zagregowany z tensorów (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

 Więcej przykładów

minimum

Semantyka

Wykonuje elementarną operację min na tensorach lhsrhs oraz zwraca tensor result. W zależności od typu elementu wykonuje te działania:

  • W przypadku wartości logicznych: operator logiczny OR.
  • W przypadku liczb całkowitych: minimalna liczba całkowita.
  • Dla liczb zmiennoprzecinkowych: minimum z IEEE-754.
  • W przypadku liczb zespolonych: leksykograficzne minimum dla pary (real, imaginary). Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560).
  • W przypadku typów kwantowych:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs kwantowy tensor lub tensor kwantowy (C1)
(I2) rhs tensor lub tensor zagregowany z tensorów (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Więcej przykładów

pomnóż

Semantyka

Wykonuje element po elemencie iloczyn dwóch tensorów lhsrhs oraz generuje tensor result. W zależności od typu elementu:

  • W przypadku wartości logicznych: logiczne I.
  • W przypadku liczb całkowitych: mnożenie liczb całkowitych.
  • W przypadku jednostek zmiennoprzecinkowych: multiplication w standardzie IEEE-754.
  • Liczby zespolone: mnożenie zespolone.
  • W przypadku typów skwantowanych:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs kwantowy tensor lub tensor kwantowy (C1)
(I2) rhs tensor lub tensor zagregowany z tensorów (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

 Więcej przykładów

negacja

Semantyka

Wykonuje negację tensora operand z uwzględnieniem elementów i generuje tensor result. W zależności od typu elementu:

  • W przypadku liczb całkowitych ze znakiem: negacja liczby całkowitej.
  • W przypadku liczb bez znaku: bitowy zapis jako liczba ze znakiem, negacja liczby, bitowy zapis z powrotem jako liczba bez znaku.
  • Dla liczb zmiennoprzecinkowych: negate z IEEE-754.
  • Liczby zespolone: negacja zespolona.
  • W przypadku typów kwantowych: dequantize_op_quantize(negate, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Więcej przykładów

nie

Semantyka

Przeprowadza elementarną operację NOT na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • W przypadku wartości logicznych: NOT logiczna.
  • W przypadku liczb całkowitych: bitowa operacja NOT.

Argumenty

Nazwa Typ Ograniczenia
operand tensor typu logicznego lub całkowitego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu logicznego lub całkowitego (C1)

Ograniczenia

  • (C1) type(operand) = type(result).

Przykłady

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

 Więcej przykładów

optimization_barrier

Semantyka

Zapewnienie, że operacje, które generują operand, są wykonywane przed operacjami, które zależą od result, oraz zapobieganie przemieszczaniu operacji przez barierę przez przekształcenia kompilatora. W przeciwnym razie operacja jest tożsamością, czyli result = operand.

Argumenty

Nazwa Typ Ograniczenia
operand zmienna liczba tensorów, tensorów skwantowanych na podstawie tensora lub tokenów (C1)

Wyniki

Nazwa Typ Ograniczenia
result zmienna liczba tensorów, tensorów skwantowanych na podstawie tensora lub tokenów (C1)

Ograniczenia

  • (C1) type(operand...) = type(result...).

Przykłady

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Więcej przykładów

lub

Semantyka

Wykonuje elementarną operację OR na dwóch tensorach lhsrhs, tworząc tensor result. W zależności od typu elementu:

  • W przypadku wartości logicznych: operator logiczny LUB.
  • W przypadku liczb całkowitych: operator bitowy LUB.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu liczba całkowita lub logiczna (C1)
(I2) rhs tensor typu liczba całkowita lub logiczna (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu liczba całkowita lub logiczna (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

Więcej przykładów

outfeed

Semantyka

Zapisuje inputs w outfiedzie i generuje token result.

Semantyka elementu outfeed_config jest zdefiniowana przez implementację.

Dane wejściowe

Etykieta Nazwa Typ
(I1) inputs zmienna liczba tensorów lub zagęszczonych tensorów
(I2) token token
(I3) outfeed_config stała typu string

Wyniki

Nazwa Typ
result token

Przykłady

%result = "stablehlo.outfeed"(%input0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

 Więcej przykładów

pad

Semantyka

Rozszerza operand przez wypełnienie przestrzeni wokół tensora oraz między elementami tensora za pomocą podanego padding_value.

Parametry edge_padding_lowedge_padding_high określają ilość wypełnień dodawanych na niskich wartościach (obok indeksu 0) i na wysokich wartościach (obok najwyższego indeksu) w każdym wymiarze. Ilość wypełnienia może być ujemna, a wartość bezwzględna ujemnego wypełnienia wskazuje liczbę elementów do usunięcia z wybranego wymiaru.

interior_padding określa dopełnienie między dwoma dowolnymi elementami w każdym wymiarze, które może nie być ujemne. Wypełnienie wewnętrzne występuje przed wypełnieniem krawędzi, dzięki czemu wypełnienie krawędzi o ujemnej wartości spowoduje usunięcie elementów z operanda z wypełnieniem wewnętrznym.

Bardziej formalnie result[result_index] jest zdefiniowana jako:

  • operand[operand_index] jeśli result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • W przeciwnym razie: padding_value.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand kwantowy tensor lub tensor kwantowy (C1), (C2), (C4)
(I2) padding_value Tensor 0-wymiarowy lub kwantowy tensor na tensor (C1)
(I3) edge_padding_low 1-wymiarowa stała tensora typu si64 (C1), (C4)
(I4) edge_padding_high 1-wymiarowa stała tensora typu si64 (C1), (C4)
(I5) interior_padding 1-wymiarowa stała tensora typu si64 (C2-C4)

Wyniki

Nazwa Typ Ograniczenia
result kwantowy tensor lub tensor kwantowy (C3-C6)

Ograniczenia

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Przykłady

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = array<i64: 0, 1>,
  edge_padding_high = array<i64: 2, 1>,
  interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Więcej przykładów

partition_id

Semantyka

Tworzy partition_id bieżącego procesu.

Wyniki

Nazwa Typ
result 0-wymiarowy tensor typu ui32

Przykłady

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

 Więcej przykładów

Popcnt

Semantyka

Wykonuje za pomocą elementu liczby bitów ustawione w tensorze operand i tworzy tensor result.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu liczba całkowita (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu liczba całkowita (C1)

Ograniczenia

  • (C1) type(operand) = type(result).

Przykłady

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

 Więcej przykładów

moc

Semantyka

Wykonuje elementową potęgowanie tensora lhs przez tensor rhs i tworzy tensor result. W zależności od typu elementu:

  • W przypadku liczb całkowitych: wykładnik całkowity.
  • Dla liczb zmiennoprzecinkowych: pow z IEEE-754.
  • W przypadku liczb zespolonych: wykładnik zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(power, lhs, rhs, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)
(I2) rhs tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora według tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

 Więcej przykładów

real

Semantyka

Wyodrębnia rzeczywistą część z operand element po elemencie i tworzy tensor result. W formalnej postaci dla każdego elementu x: real(x) = is_complex(x) ? real_part(x) : x.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego (C1), (C2)

Ograniczenia

  • (C1) shape(result) = shape(operand).
  • (C2) Parametr element_type(result) jest zdefiniowany jako:
    • complex_element_type(element_type(operand)) jeżeli is_complex(operand).
    • W przeciwnym razie: element_type(operand).

Przykłady

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

 Więcej przykładów

recv

Semantyka

Pobiera dane z kanału z parametrem channel_id i tworzy results.

Jeśli is_host_transfer to true, operacja przenosi dane z hosta. W przeciwnym razie dane są przenoszone z innego urządzenia. Co to oznacza, zależy od implementacji. Ta flaga powiela informacje podane w flagach channel_type, dlatego w przyszłości planujemy zachować tylko jedną z nich (#666).

results składa się z wartości ładunku, które występują na początku, oraz tokena, który występuje na końcu. W przyszłości planujemy podzielić ładunek i token na 2 osobne dane wyjściowe, aby zwiększyć przejrzystość (#670).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) token token (C4)
(I2) channel_id stała typu si64
(I3) channel_type enum DEVICE_TO_DEVICEHOST_TO_DEVICE (C1)
(I4) is_host_transfer stała typu i1 (C1)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów (C2–C4)

Ograniczenia

  • (C1) channel_type jest zdefiniowany jako:
    • HOST_TO_DEVICE jeżeli is_host_transfer = true,
    • DEVICE_TO_DEVICE w innych przypadkach.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) lub is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Przykłady

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Więcej przykładów

zmniejszyć

Semantyka

Stosuje funkcję redukcji body do inputs i init_values wzdłuż dimensions i zwraca tensory results.

Kolejność redukcji jest zdefiniowana w implementacji, co oznacza, że body i init_values muszą utworzyć monoid, aby zagwarantować, że operacja da takie same wyniki przy wszystkich danych wejściowych we wszystkich implementacjach. Jednak w przypadku wielu popularnych rabatów to założenie nie jest spełnione. Przykładowo dodawanie liczb zmiennoprzecinkowych w przypadku wartości body i zera w przypadku wartości init_values nie tworzy monoidu, ponieważ dodawanie liczb zmiennoprzecinkowych nie jest skojarzone.

Bardziej formalnie: results...[j0, ..., jR-1] = reduce(input_slices_converted), gdzie:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], gdzie : są wstawiane w miejscu dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) dla niektórych drzew binarnych schedule gdzie:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule to pełne drzewo binarne zdefiniowane przez implementację, którego przeglądanie w kolejności zgodnej z ich poziomem w drzewie obejmuje:
    • wartości input_slices_converted...[index] dla wszystkich index w tablicy index_space(input_slices_converted) w rosnącym porządku leksykograficznym index.
    • Wplecione z określoną przez implementację ilością znaków init_values_converted w określonych przez implementację pozycjach.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C1-C4), (C6), (C7)
(I2) init_values zmienna liczba 0-wymiarowych tensorów lub tensorów skwantowanych na tensor (C2), (C3)
(I3) dimensions 1-wymiarowa stała tensora typu si64 (C4), (C5), (C7)
(I4) body funkcja (C6)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C3), (C7), (C8)

Ograniczenia

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body ma typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) gdzie is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), z tym że nie uwzględnia wymiaru rozmiary w elementach inputs... odpowiadających elementom dimensions.
  • (C8) element_type(results[i]) = Ei dla wszystkich i[0,N).

Przykłady

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

 Więcej przykładów

reduce_precision

Semantyka

Przeprowadza konwersję elementów operand na inny typ zmiennoprzecinkowy, który używa exponent_bits i mantissa_bits, a następnie z powrotem na pierwotny typ zmiennoprzecinkowy. Tworzy też tensor output.

Bardziej formalnie:

  • Bity mantissy pierwotnej wartości są aktualizowane w celu zaokrąglenia pierwotnej wartości do najbliższej wartości, którą można przedstawić za pomocą funkcji mantissa_bits przy użyciu semantyki roundToIntegralTiesToEven.
  • Następnie, jeśli mantissa_bits jest mniejsza od pierwotnej wartości, bity modliszki są obcinane do mantissa_bits.
  • Następnie, jeśli bity wykładnika wyniku pośredniego nie mieszczą się w zakresie określonym przez exponent_bits, wynik pośredni jest przepełniony do nieskończoności z użyciem znaku pierwotnego lub jest podpięty do zera z użyciem znaku pierwotnego.
  • W przypadku typów skwantowanych wykonuje operację dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)
(I2) exponent_bits stała typu si32 (C2)
(I3) mantissa_bits stała typu si32 (K3)

Wyniki

Nazwa Typ Ograniczenia
output tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Przykłady

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

 Więcej przykładów

reduce_scatter

Semantyka

reduce_scatter

W ramach każdej grupy procesów w siatce procesów StableHLO wykonuje redukcję (za pomocą funkcji computations) wartości tensora operand z każdego procesu, dzieli wynik redukcji wzdłuż osi scatter_dimension na części, a potem rozprasza te części między procesami, aby wygenerować result.

Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:

  • cross_replica(replica_groups) jeśli channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) if channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) if channel_id > 0 and use_global_device_ids = true.

Następnie w każdym wierszu process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] dla wszystkich sender w process_group, gdzie receiver_index = process_group.index(receiver).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand kwantowy tensor lub tensor kwantowy (C1), (C2), (C7), (C8)
(I2) scatter_dimension stała typu si64 (C1), (C2), (C8)
(I3) replica_groups 2-wymiarowa stała tensora typu si64 (C3-C5)
(I4) channel_id stała typu si64 (C6)
(I5) use_global_device_ids stała typu i1 (C6)
(I6) computation funkcja (C7)

Wyniki

Nazwa Typ Ograniczenia
result kwantowy tensor lub tensor kwantowy (C8-C9)

Ograniczenia

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) jest zdefiniowana jako:
    • num_replicas, jeśli używana jest właściwość cross_replica.
    • num_replicas, jeśli używana jest właściwość cross_replica_and_partition.
    • num_processes, jeśli używana jest właściwość flattened_ids.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Jeśli use_global_device_ids = true, to channel_id > 0.
  • (C7) computation ma typ (tensor<E>, tensor<E>) -> (tensor<E>), gdzie is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) z wyjątkiem:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Przykłady

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Więcej przykładów

reduce_window

Semantyka

Stosuje funkcję redukcji body do okien inputsinit_values oraz zwraca wartość results.

Ten diagram pokazuje, jak elementy w elementach results... są obliczane na podstawie elementów inputs... na konkretnym przykładzie.

reduce_window

Bardziej oficjalnie: results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (patrz reduce), gdzie:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) inputs liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor (C1–C4), (C6), (C8), (C10), (C12), (C13), (C15)
(I2) init_values liczba zmiennoprzecinkowa tensorów 0-wymiarowych lub kwantyzowanych tensorów na tensor (C1), (C13)
(I3) window_dimensions Jednowymiarowa stała tensora typu si64 (C4), (C5), (C15)
(I4) window_strides 1-wymiarowa stała tensora typu si64 (C6), (C7), (C15)
(I5) base_dilations 1-wymiarowa stała tensora typu si64 (C8), (C9), (C15)
(I6) window_dilations 1-wymiarowa stała tensora typu si64 (C10), (C11), (C15)
(I7) padding 2-wymiarowa stała tensora typu si64 (C12), (C15)
(I8) body funkcja (C13)

Wyniki

Nazwa Typ Ograniczenia
results liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor (C1), (C14–C16)

Ograniczenia

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body ma typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>) gdzie is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows gdzie:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei dla wszystkich i[0,N).

Przykłady

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 2, 1>,
  window_strides = array<i64: 4, 1>,
  base_dilations = array<i64: 2, 1>,
  window_dilations = array<i64: 3, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

 Więcej przykładów

reszta

Semantyka

Wykonuje element po elemencie resztę dzielnika lhs i dzielonki rhs tensorów oraz zwraca tensor result.

Bardziej oficjalnie znak wyniku jest wyliczany z dywidendy, a wartość bezwzględna wyniku jest zawsze mniejsza od wartości bezwzględnej dzielnika. Reszta jest obliczana według wzoru: lhs - d * rhs, gdzie d jest określona przez:

  • W przypadku liczb całkowitych: stablehlo.divide(lhs, rhs).
  • W przypadku liczb zmiennoprzecinkowych: division(lhs, rhs) z IEEE-754 z atrybutem zaokrąglenia roundTowardZero.
  • W przypadku liczb zespolonych: do ustalenia (#997).
  • W przypadku typów skwantowanych:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

W przypadku elementów zmiennoprzecinkowych ta operacja jest niezgodna z operacją remainder ze specyfikacji IEEE-754, w której d jest wartością całkowitą zbliżoną do dokładnej wartości parametru lhs/rhs z opisem równomiernym.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)
(I2) rhs tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

 Więcej przykładów

replica_id

Semantyka

Tworzy replica_id bieżącego procesu.

Wyniki

Nazwa Typ
result 0-wymiarowy tensor typu ui32

Przykłady

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

 Więcej przykładów

zmienić kształt

Semantyka

Przekształca tensor operand w tensor result. W zasadzie jest to zachowanie tego samego reprezentacji kanonicznej, ale potencjalnie zmiany kształtu, np. z tensor<2x3xf32> na tensor<3x2xf32> lub tensor<6xf32>.

Dokładniej rzecz ujmując: result[result_index] = operand[operand_index], gdzie result_index i operand_index mają to samo miejsce w kolejności leksykograficznej index_space(result) i index_space(operand).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub kwantyzowany tensor (C1-C3)

Wyniki

Nazwa Typ Ograniczenia
result tensor (tensor kwantowy) (C1-C3)

Ograniczenia

  • (C1) element_type(result) otrzymuje:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand) – oprócz tych quantization_dimension(operand) i quantization_dimension(result) mogą się różnić.
  • (C2) size(operand) = size(result).
  • (C3) Jeśli is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Przykłady

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

 Więcej przykładów

odwróć

Semantyka

Odwraca kolejność elementów w operand wzdłuż podanego wymiaru dimensions i tworzy tensor result. Bardziej formalnie: result[result_index] = operand[operand_index] gdzie:

  • operand_index[d] = dim(result, d) - result_index[d] - 1jeżeli ddimensions.
  • operand_index[d] = result_index[d] w innych przypadkach.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor zagregowany z tensorów (C1), (C3)
(I2) dimensions 1-wymiarowa stała tensora typu si64 (C2), (C3)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1), (C3)

Ograniczenia

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Przykłady

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

 Więcej przykładów

rng

Semantyka

Generuje liczby losowe za pomocą algorytmu rng_distribution i tworzy tensor result o kształcie shape.

Jeśli rng_distribution = UNIFORM, liczby losowe są generowane zgodnie z rozkładem stałym w przedziałzie [a, b). Jeśli a >= b, zachowanie jest nieokreślone.

Jeśli rng_distribution = NORMAL, liczby losowe są generowane zgodnie z rozkładem normalnym, ze średnią = a i odchyleniem standardowym = b. Jeśli b < 0, zachowanie jest nieokreślone.

Dokładny sposób generowania liczb losowych jest zdefiniowany w implementacji. Mogą na przykład być deterministyczne, ale nie muszą.

W rozmowach z wielu interesariuszami okazało się, że ta opcja jest w istocie wycofana, dlatego w przyszłości planujemy jej usunięcie (#597).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) a 0-wymiarowy tensor typu liczby całkowitej, wartości logicznej lub zmiennoprzecinkowej (C1), (C2)
(I2) b 0-wymiarowy tensor typu liczby całkowitej, wartości logicznej lub zmiennoprzecinkowej (C1), (C2)
(I3) shape 1-wymiarowa stała tensora typu si64 (C3)
(I4) rng_distribution wyliczenie UNIFORM i NORMAL (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowita, logiczna lub zmiennoprzecinkowa (C1-C3)

Ograniczenia

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Jeśli rng_distribution = NORMAL, to is_float(a).
  • (C3) shape(result) = shape.

Przykłady

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semantyka

Zwraca pole output wypełnione jednolitymi losowymi bitami i zaktualizowany stan wyjściowy output_state za pomocą algorytmu generatora liczb pseudorandom rng_algorithm z przypisanym stanem początkowym initial_state. Wyjście jest zdefiniowane jako funkcja deterministyczna initial_state, ale nie jest deterministyczna w różnych implementacjach.

rng_algorithm jest jedną z tych:

  • DEFAULT: algorytm zdefiniowany przez implementację.
  • THREE_FRY: wariant algorytmu Threefry zdefiniowany przez implementację.*
  • PHILOX: wariant algorytmu Philox zdefiniowany przez implementację.*

* Patrz: Salmon et al. SC 2011. Losowe liczby równoległe: to bardzo proste.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) rng_algorithm wyliczenie DEFAULT, THREE_FRY i PHILOX (C2)
(I2) initial_state Jednowymiarowy tensor typu ui64 (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
output_state 1-wymiarowy tensor typu ui64 (C1)
output tensor typu liczba całkowita lub zmiennoprzecinkowa

Ograniczenia

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) zdefiniowano jako:
    • zdefiniowaną przez implementację, jeśli rng_algorithm = DEFAULT.
    • 2 jeżeli rng_algorithm = THREE_FRY.
    • 2 lub 3, jeśli rng_algorithm = PHILOX.

Przykłady

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semantyka

Wykonuje zaokrąglanie z poziomu elementu do najbliższej liczby całkowitej, zrywając remisy od zera na tensorze operand i tworzy tensor result. Realizuje operację roundToIntegralTiesToAway zgodnie ze specyfikacją IEEE-754. W przypadku typów skompresowanych wykonuje działanie dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

 Więcej przykładów

round_nearest_even

Semantyka

Zaokrągla elementy tensora operand do najbliższej liczby całkowitej, rozstrzygając remisy na korzyść parzystej liczby całkowitej, i tworzy tensor result. Realizuje operację roundToIntegralTiesToEven według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(round_nearest_even, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

 Więcej przykładów

rsqrt

Semantyka

Przeprowadza elementarną operację odwrotnego pierwiastka kwadratowego na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: rSqrt z IEEE-754.
  • Liczby zespolone: odwrotny pierwiastek kwadratowy z liczby zespolonej.
  • W przypadku typów kwantowych: dequantize_op_quantize(rsqrt, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

 Więcej przykładów

rozproszenie

Semantyka

Wyświetla tensory results, które są równe tensorom inputs, ale kilka wycinków określonych przez scatter_indices zostało zaktualizowanych wartościami updates przy użyciu metody update_computation.

Ten diagram pokazuje na konkretnym przykładzie, jak elementy w updates... są mapowane na elementy w results.... Diagram przedstawia kilka przykładowych indeksów updates... i szczegółowo wyjaśnia, z którymi indeksami results... są one powiązane.

rozproszenie

Więcej formalnie dla wszystkich update_index w index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index to:
    • scatter_indices[si0, ..., :, ..., siN], gdzie si to poszczególne elementy w update_scatter_index, a element : jest wstawiany w indeksie index_vector_dim, jeśli index_vector_dimrank(scatter_indices).
    • [scatter_indices[update_scatter_index]] w innych przypadkach.
  • Za d_input w axes(inputs[0]):
    • full_start_index[d_input] = start_index[d_start] jeśli d_input = scatter_dims_to_operand_dims[d_start].
    • full_start_index[d_input] = 0 w innych przypadkach.
  • W przypadku d_input w axes(inputs[0]):
    • full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)] jeśli d_input = input_batching_dims[i_batching] i d_start = scatter_indices_batching_dims[i_batching].
    • W przeciwnym razie: full_batching_index[d_input] = 0.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], gdzie wi to poszczególne elementy w tablicy update_window_index, a element 0 jest wstawiany na pozycjach z tablic inserted_window_dimsinput_batching_dims.
  • result_index = full_start_index + full_batching_index + full_window_index.

W związku z tym results = exec(schedule, inputs), gdzie:

  • schedule to określona przez implementację permutacja funkcji index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results), gdzie:
    • Jeśli result_index mieści się w zakresie shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results to kopia results z wartością results...[result_index] ustawioną na updated_values....
    • W innym przypadku
    • updated_results = results.
  • exec([], results) = results.

Jeśli indices_are_sorted to true, implementacja może założyć, że scatter_indices są posortowane zgodnie z scatter_dims_to_operand_dims. W przeciwnym razie zachowanie jest nieokreślone. Bardziej formalnie dotyczy wszystkich i1 < i2 od indices(result), full_start_index(i1) <= full_start_index(i2).

Jeśli unique_indices to true, implementacja może założyć, że wszystkie indeksy result_index, które są rozproszone, są unikalne. Jeśli unique_indices to true, ale indeksy, do których jest rozproszony, nie są unikalne, działanie jest niezdefiniowane.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24)
(I2) scatter_indices tensor typu liczba całkowita (C4), (C15), (C19), (C22)
(I3) updates liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor (C3-C6), (C8)
(I4) update_window_dims 1-wymiarowa stała tensora typu si64 (C2), (C4), (C7-C8)
(I5) inserted_window_dims 1-wymiarowa stała tensora typu si64 (C2), (C4), (C9-C11)
(I6) input_batching_dims 1-wymiarowa stała tensora typu si64 (C2), (C4), (C9), (C12-13), (C17-18), (C20)
(I7) scatter_indices_batching_dims 1-wymiarowa stała tensora typu si64 (C14-C18)
(I8) scatter_dims_to_operand_dims 1-wymiarowa stała tensora typu si64 (C19-C21)
(I9) index_vector_dim stała typu si64 (C4), (C16), (C19), (C22)
(I10) indices_are_sorted stała typu i1
(I11) unique_indices stała typu i1
(I12) update_computation funkcja (C23)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C24-C25)

Ograniczenia

  • (C1) same(shape(inputs...)).
  • (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
    • size(input_batching_dims)`.
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes), gdzie:
    • update_scatter_dim_sizes = shape(scatter_indices), z tym że nie uwzględnia wymiaru scatter_indices odpowiadającego wymiarowi index_vector_dim.
    • update_window_dim_sizes <= shape(inputs[0]), z tym że nie uwzględnia wymiarów inputs[0] odpowiadających wymiarom inserted_window_dimsinput_batching_dims.
    • Funkcja combine umieszcza pole update_scatter_dim_sizes na osiach odpowiadających podziałowi update_scatter_dims i update_window_dim_sizes na osiach odpowiadających wartości update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(concatenate(inserted_window_dims, input_batching_dims))
  • (C10) is_sorted(inserted_window_dims).
  • (C11) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C12) is_sorted(input_batching_dims).
  • (C13) 0 <= input_batching_dims < rank(inputs[0])).
  • (C14) is_unique(scatter_indices_batching_dims).
  • (C15) 0 <= scatter_indices_batching_dims < rank(scatter_indices).
  • (C16) index_vector_dim not in scatter_indices_batching_dims.
  • (C17) size(input_batching_dims) == size(scatter_indices_batching_dims).
  • (C18) dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...).
  • (C19) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C20) is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)).
  • (C21) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C22) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C23) update_computation ma typ (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), gdzie is_promotable(element_type(inputs[i]), Ei).
  • (C24) shape(inputs...) = shape(results...).
  • (C25) element_type(results[i]) = Ei dla wszystkich i[0,N).

Przykłady

// %input: [
//          [
//           [[1, 2], [3, 4], [5, 6], [7, 8]],
//           [[9, 10],[11, 12], [13, 14], [15, 16]],
//           [[17, 18], [19, 20], [21, 22], [23, 24]]
//          ],
//          [
//           [[25, 26], [27, 28], [29, 30], [31, 32]],
//           [[33, 34], [35, 36], [37, 38], [39, 40]],
//           [[41, 42], [43, 44], [45, 46], [47, 48]]
//          ]
//         ]
// %scatter_indices: [
//                    [
//                     [[0, 0], [1, 0], [2, 1]],
//                     [[0, 1], [1, 1], [0, 9]]
//                    ],
//                    [
//                     [[0, 0], [2, 1], [2, 2]],
//                     [[1, 2], [0, 1], [1, 0]]
//                    ]
//                   ]
// %update: [
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ],
//           [
//            [[1, 1], [1, 1], [1, 1]],
//            [[1, 1], [1, 1], [1, 1]]
//           ]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [3, 4],
    inserted_window_dims = [1],
    input_batching_dims = [0],
    scatter_indices_batching_dims = [1],
    scatter_dims_to_operand_dims = [2, 1],
    index_vector_dim = 3>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
//           [
//            [[3, 4], [6, 7], [6, 7], [7, 8]],
//            [[9, 10],[11, 12], [15, 16], [17, 18]],
//            [[17, 18], [19, 20], [22, 23], [24, 25]]
//           ],
//           [
//            [[25, 26], [28, 29], [30, 31], [31, 32]],
//            [[35, 36], [38, 39], [38, 39], [39, 40]],
//            [[41, 42], [44, 45], [46, 47], [47, 48]]
//           ]
//          ]

Więcej przykładów

wybierz

Semantyka

Tworzy tensor result, w którym każdy element jest wybierany z tensora on_true lub on_false na podstawie wartości odpowiadającego elementu tensora pred. W formie bardziej oficjalnej: result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], gdzie pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. W przypadku typów kwantowych wybiera dequantize_select_quantize(pred, on_true, on_false, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) pred tensor typu i1 (C1)
(I2) on_true tensor lub tensor zagregowany z tensorów (C1-C2)
(I3) on_false tensor lub tensor zagregowany z tensorów (K2)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C2)

Ograniczenia

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Przykłady

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

 Więcej przykładów

select_and_scatter

Semantyka

Rozprasza wartości z tensora source za pomocą funkcji scatter na podstawie wyniku reduce_window z tensora input za pomocą funkcji select i tworzy tensor result.

Na diagramie poniżej widać, jak elementy w elementach result są obliczane na podstawie elementów operandsource.

select_and_scatter

Bardziej formalnie:

  • selected_values = reduce_window_without_init(...) z tymi danymi wejściowymi:

    • inputs = [operand].
    • window_dimensions, window_strides i padding, które są używane w takiej postaci, w jakiej zostały przesłane.
    • base_dilations = windows_dilations = 1.
    • body jest zdefiniowana jako:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    gdzie E = element_type(operand) i reduce_window_without_init działają dokładnie tak jak reduce_window, z tą różnicą, że schedule wartości podstawowej reduce (patrz redukcja) nie zawiera wartości init. Obecnie nie określono, co się stanie, jeśli odpowiednie okno nie zawiera wartości (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter), gdzie:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index jeśli selected_values[source_index] ma element operandoperand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub tensor zagregowany z tensorów (C1-C4), (C6), (C8-C11)
(I2) source kwantowy tensor lub tensor kwantowy (C1), (C2)
(I3) init_value 0-wymiarowy tensor lub tensor kwantyzowany na podstawie tensora (C3)
(I4) window_dimensions 1-wymiarowa stała tensora typu si64 (C2), (C4), (C5)
(I5) window_strides 1-wymiarowa stała tensora typu si64 (C2), (C6), (C7)
(I6) padding 2-wymiarowa stała tensora typu si64 (C2), (C8)
(I7) select funkcja (C9)
(I8) scatter funkcja (C10)

Wyniki

Nazwa Typ Ograniczenia
result kwantowy tensor lub tensor kwantowy (C11-C12)

Ograniczenia

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows, gdzie:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select ma typ (tensor<E>, tensor<E>) -> tensor<i1>, gdzie E = element_type(operand).
  • (C10) scatter ma typ (tensor<E>, tensor<E>) -> tensor<E>, gdzie is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Przykłady

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = array<i64: 3, 1>,
  window_strides = array<i64: 2, 1>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Więcej przykładów

wyślij

Semantyka

Wysyła kod inputs do kanału channel_id i tworzy token result.

Jeśli is_host_transfer to true, operacja przenosi dane do hosta. W przeciwnym razie dane zostaną przeniesione na inne urządzenie. Co to oznacza, zależy od implementacji. Ta flaga powiela informacje podane w flagach channel_type, dlatego w przyszłości planujemy zachować tylko jedną z nich (#666).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub zagęszczonych tensorów
(I2) token token
(I3) channel_id stała typu si64
(I4) channel_type enum DEVICE_TO_DEVICEDEVICE_TO_HOST (C1)
(I5) is_host_transfer stała typu i1 (C1)

Wyniki

Nazwa Typ
result token

Ograniczenia

  • (C1) channel_type jest zdefiniowany jako:
    • DEVICE_TO_HOST jeżeli is_host_transfer = true,
    • DEVICE_TO_DEVICE w innych przypadkach.

Przykłady

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

 Więcej przykładów

shift_left

Semantyka

Przeprowadza elementową operację przesunięcia w lewo na tensorze lhs o liczbę bitów rhs i tworzy tensor result.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu liczba całkowita (C1)
(I2) rhs tensor typu liczby całkowitej (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu liczba całkowita (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

 Więcej przykładów

shift_right_arithmetic

Semantyka

Przesuwa elementowo w prawo o określoną liczbę bitów (rhs) elementarną operację arytmetyczną na tensorze lhs i tworzy tensor result.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu liczba całkowita (C1)
(I2) rhs tensor typu liczby całkowitej (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu liczba całkowita (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

 Więcej przykładów

shift_right_logical

Semantyka

Wykonuje logiczne operację przesunięcia w prawo na tensorze lhs o rhs bitów i tworzy tensor result.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu liczba całkowita (C1)
(I2) rhs tensor typu liczby całkowitej (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu liczba całkowita (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Więcej przykładów

podpisywanie

Semantyka

Zwraca znak operand element po elemencie i tworzy tensor result. Bardziej formalnie, w przypadku każdego elementu x semantyka może być wyrażona za pomocą składni Pythona w ten sposób:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(sign, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

 Więcej przykładów

sinus

Semantyka

Wykonuje elementarną operację sinusa na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: sin z IEEE-754.
  • W przypadku liczb zespolonych: sinus zespolony.
  • W przypadku typów skwantowanych: dequantize_op_quantize(sine, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

 Więcej przykładów

wycinek

Semantyka

Wyodrębnia wycinek z operand, używając statycznie obliczonych indeksów początkowych i tworzy tensor result. start_indices zawiera indeksy początkowe przekroju dla każdego wymiaru, limit_indices zawiera indeksy końcowe (wykluczające) przekroju dla każdego wymiaru, a strides zawiera kroki dla każdego wymiaru.

Bardziej formalnie: result[result_index] = operand[operand_index], gdzie operand_index = start_indices + result_index * strides.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand kwantowy tensor lub tensor kwantowy (C1-C3), (C5)
(I2) start_indices 1-wymiarowa stała tensora typu si64 (C2), (C3), (C5)
(I3) limit_indices 1-wymiarowa stała tensora typu si64 (C2), (C3), (C5)
(I4) strides 1-wymiarowa stała tensora typu si64 (C2), (C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub tensor zagregowany z tensorów (C1), (C5)

Ograniczenia

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Przykłady

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = array<i64: 1, 2>,
  limit_indices = array<i64: 3, 4>,
  strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

 Więcej przykładów

sortuj

Semantyka

Sortuje jednowymiarowe wycinki obiektu inputs wzdłuż wymiaru dimension według wartości comparator i tworzy results.

W przeciwieństwie do podobnych danych wejściowych w innych operacjach funkcja dimension umożliwia stosowanie wartości ujemnych. Wybrane wartości mają następującą interpretację: W przyszłości możemy zablokować tę funkcję ze względu na spójność (problem #1377).

Jeśli is_stable ma wartość prawda, sortowanie jest stabilne, co oznacza, że zachowana jest względna kolejność elementów uznanych przez porównywarkę za równe. W przypadku pojedynczego wejścia dwa elementy e1e2 są uznawane za równe przez porównywacz, jeśli i tylko jeśli comparator(e1, e2) = comparator(e2, e1) = false. Zobacz formalny opis poniżej, aby dowiedzieć się, jak to uogólniać do wielu danych wejściowych.

Więcej formalnie dla wszystkich result_index w index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1], gdzie riN to pojedyncze elementy w komórce result_index, a : jest wstawiony w miejscu adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • gdzie sort sortuje jednowymiarowy wycinek w kolejności niemalejącej, oczekując, że funkcja comparator_together zwróci true, jeśli argument po lewej stronie jest mniejszy niż argument drugiej strony po prawej stronie.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) inputs zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów (C1-C5)
(I2) dimension stała typu si64 (C4)
(I3) is_stable stała typu i1
(I4) comparator funkcja (C5)

Wyniki

Nazwa Typ Ograniczenia
results liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor (C2), (C3)

Ograniczenia

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, gdzie R = rank(inputs[0]).
  • (C5) comparator ma typ (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, gdzie Ei = element_type(inputs[i]).

Przykłady

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

 Więcej przykładów

sqrt

Semantyka

Wykonuje operację pierwiastka kwadratowego uwzględniającą elementy na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: squareRoot z IEEE-754.
  • W przypadku liczb zespolonych: pierwiastek kwadratowy z liczby zespolonej.
  • W przypadku typów skwantowanych: dequantize_op_quantize(sqrt, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Więcej przykładów

odejmij

Semantyka

Wykonuje element po elemencie odejmowanie dwóch tensorów lhsrhs oraz generuje tensor result. W zależności od typu elementu:

  • W przypadku liczb całkowitych: odejmowanie liczb całkowitych.
  • Dla liczb zmiennoprzecinkowych: subtraction z IEEE-754.
  • W przypadku liczb zespolonych: odejmowanie liczb zespolonych.
  • W przypadku typów skwantowanych:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)
(I2) rhs tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora według tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora (C1)

Ograniczenia

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Przykłady

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Więcej przykładów

tan

Semantyka

Wykonuje elementarną operację pochodnej cząstkowej na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: tan z IEEE-754.
  • W przypadku liczb zespolonych: tangens zespolony.
  • W przypadku typów kwantowych: dequantize_op_quantize(tan, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
//           [0.0, 1.63312e+16],
//           [0.0, 5.44375e+15]
//          ]

 Więcej przykładów

tanh

Semantyka

Wykonuje elementarną operację tangensu hiperbolicznego na tensorze operand i tworzy tensor result. W zależności od typu elementu:

  • Dla liczb zmiennoprzecinkowych: tanh z IEEE-754.
  • W przypadku liczb zespolonych: tangens hiperboliczny zespolony.
  • W przypadku typów skwantowanych:
    • dequantize_op_quantize(tanh, operand, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_type(operand) = baseline_type(result).

Przykłady

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

 Więcej przykładów

transponować

Semantyka

Zamienia wymiary tensora operand za pomocą permutation i tworzy tensor result. Bardziej formalnie: result[result_index] = operand[operand_index]gdzie result_index[d] = operand_index[permutation[d]].

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor lub kwantyzowany tensor (C1-C4)
(I2) permutation 1-wymiarowa stała tensora typu si64 (C2-C4)

Wyniki

Nazwa Typ Ograniczenia
result tensor lub kwantyzowany tensor (C1), (C3-C4)

Ograniczenia

  • (C1) element_type(result) otrzymuje:
    • element_type(operand), jeśli !is_per_axis_quantized(operand).
    • element_type(operand) z tym, że quantization_dimension(operand)quantization_dimension(result) mogą się różnić.
  • (C2) permutation jest permutacją range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Jeśli is_per_axis_quantized(result), to quantization_dimension(operand) = permutation(quantization_dimension(result)).

Przykłady

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

 Więcej przykładów

triangular_solve

Semantyka

Rozwiązuje partie układów równań liniowych z macierzowymi współczynnikami o dolnej lub górnej trójkątności.

Bardziej formalnie, przy założeniu ab, result[i0, ..., iR-3, :, :] jest rozwiązaniem op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :], gdy left_side jest równe true lub x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :], gdy left_side jest równe false, przy czym zmienna x jest rozwiązaniem op(a), a jej wartość jest określana przez transpose_a, który może być równy:

  • NO_TRANSPOSE: wykonaj operację, używając a w postaci domyślnej.
  • TRANSPOSE: wykonaj operację na przekształceniu macierzowym a.
  • ADJOINT: wykonaj operację na sprzężonym przekształceniu macierzowym a.

Dane wejściowe są odczytywane tylko z dolnego trójkąta elementu a, jeśli lower to true, lub z górnego trójkąta elementu a, jeśli nie. Dane wyjściowe są zwracane w tym samym trójkącie, a wartości w drugim trójkącie są zdefiniowane przez implementację.

Jeśli unit_diagonal ma wartość true, implementacja może założyć, że elementy diagonalne funkcji a są równe 1. W przeciwnym razie działanie jest nieokreślone.

W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) a tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1-C3)
(I2) b tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora (C1-C4)
(I3) left_side stała typu i1 (K3)
(I4) lower stała typu i1
(I5) unit_diagonal stała typu i1
(I6) transpose_a enum NO_TRANSPOSE, TRANSPOSEADJOINT

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora (C1)

Ograniczenia

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) Związek między shape(a)shape(b) jest zdefiniowany w ten sposób:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Przykłady

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tablice

Semantyka

Tworzy kwaternion result z wartości val.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) val zmienna liczba wartości (C1)

Wyniki

Nazwa Typ Ograniczenia
result krotka (C1)

Ograniczenia

  • (C1) result ma typ tuple<E0, ..., EN-1>, gdzie Ei = type(val[i]).

Przykłady

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Więcej przykładów

uniform_dequantize

Semantyka

Przekształca element po elemencie kwantyzowany tensor operand na tensor zmiennoprzecinkowy result zgodnie z parametrami kwantyzacji zdefiniowanymi przez typ operand.

Więcej formalnie: result = dequantize(operand).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor kwantowy (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu zmiennoprzecinkowego (C1), (C2)

Ograniczenia

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Przykłady

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Semantyka

Wykonuje konwersję tensora zmiennoprzecinkowego lub kwantyzowanego tensora operand na kwantyzowany tensor result zgodnie z parametrami kwantyzacji zdefiniowanych przez typ result.

Bardziej formalnie,

  • Jeśli is_float(operand):
    • result = quantize(operand, type(result)).
  • Jeśli is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand tensor typu zmiennoprzecinkowego lub kwantyzowanego (C1), (C2)

Wyniki

Nazwa Typ Ograniczenia
result tensor kwantyzowany (C1), (C2)

Ograniczenia

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Przykłady

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

podczas gdy

Semantyka

Wyświetla dane wyjściowe z wykonania funkcji body co najmniej 0 razy, gdy funkcja cond zwraca wartość true. Bardziej oficjalnie semantyka można wyrazić za pomocą składni Pythona:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Zachowanie nieskończonego pętli jest nieznane (#383).

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) operand zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów (C1-C3)
(I2) cond funkcja (C1)
(I3) body funkcja (C2)

Wyniki

Nazwa Typ Ograniczenia
results zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów (C3)

Ograniczenia

  • (C1) cond ma typ (T0, ..., TN-1) -> tensor<i1>, gdzie Ti = type(operand[i]).
  • (C2) body ma typ (T0, ..., TN-1) -> (T0, ..., TN-1), gdzie Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Przykłady

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

 Więcej przykładów

xor

Semantyka

Wykonuje elementarną operację XOR na dwóch tensorach lhsrhs oraz zwraca tensor result. W zależności od typu elementu wykonuje te działania:

  • W przypadku wartości logicznych: XOR logiczny.
  • W przypadku liczb całkowitych: XOR bitowy.

Dane wejściowe

Etykieta Nazwa Typ Ograniczenia
(I1) lhs tensor typu logicznego lub całkowitego (C1)
(I2) rhs tensor typu logicznego lub całkowitego (C1)

Wyniki

Nazwa Typ Ograniczenia
result tensor typu logicznego lub całkowitego (C1)

Ograniczenia

  • (C1) type(lhs) = type(rhs) = type(result).

Przykłady

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Więcej przykładów

Interoperacyjność dialektów

Obecnie programy StableHLO w naturze czasami zawierają operacje, które nie są zdefiniowane przez StableHLO.

Moduł, funkcja, wywołanie i zwracanie

StableHLO używa operacji MLIR z upstream do operacji ModuleOp, FuncOp, CallOp i ReturnOp. Zrobiliśmy to, aby zapewnić lepszą współpracę z dotychczasowymi mechanizmami MLIR, ponieważ wiele przydatnych przejść jest napisanych z uwzględnieniem operacji FuncOp i ModuleOp, a wiele ścieżek kompilacji oczekuje obecności tych operacji. Do tych operacji stosowane są gwarancje pełnej zgodności. Jeśli cokolwiek w tych działaniach zmieni się w niekompatybilny sposób (np. zostanie usunięty), zostaną dodane odpowiedniki StableHLO, aby zachować zgodność.

CHLO

Opset CHLO zawiera operacje wyższego poziomu, które rozkładają się na StableHLO. Obecnie nie ma żadnych gwarancji zgodności w przypadku CHLO. Aby zapewnić zgodność, przed serializacją należy użyć przejścia chlo-legalize-to-stablehlo.

Operacje kształtu

W społeczności często stosuje się w programach dynamicznych StableHLO pewne operacje z podstawowych dialektów MLIR do wykonywania obliczeń kształtu. Najczęściej są to operacje shape, takie jak shape_of lub num_elements, operacje tensor, takie jak dim lub from_elements, oraz wbudowany typ index.

Dynamism RFC > O2 wskazuje, że te typy wykraczają poza zakres, ale ze względu na interoperacyjność uwzględniono w nim niektóre typy index. W przypadku tych typów i wersji nie ma gwarancji zgodności. Za pomocą przejścia shape-legalize-to-stablehlo można przekształcić te operacje w w pełni obsługiwane operacje StableHLO.

Wycofane operacje

Istnieje kilka operacji StableHLO, które zostały odziedziczone z MHLO, są wycofane i zostaną usunięte z StableHLO. Szczegółowe informacje na ten temat znajdziesz w dokumencie StableHLO v1.0 Cleanup #2283 (Czyszczenie StableHLO w wersji 1.0). W przypadku tych wycofanych rozwiązań występuje problem z lokalizatorem to #2340.

Operacje te można podzielić na kilka kategorii:

  • Kategoria operacji StableHLO „Nie w HLO” – początkowo były one częścią zestawu operacji StableHLO, ale później uznano, że nie pasują do niego: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (#3).
  • Nieużywane operacje – te operacje mogły być przydatne w danym momencie, ale były niedostatecznie rozwinięte lub ścieżki korzystające z tych operacji zostały przebudowane tak, aby ich już nie wymagały. Dotyczy to funkcji map, tuple (#598), get_tuple_element, rng, complex (#560) oraz funkcji window_reversal (#1181).

Niektóre z tych operacji można łatwo usunąć, ponieważ można je wyrazić za pomocą istniejących operacji (broadcast, create_token, cross-replica-sum, dot, unary_einsum) i zostaną usunięte po upływie obecnego okna zgodności (6 miesięcy). Inne są nadal analizowane do usunięcia (porównania: einsum, get_tuple_element, map, rng torch_index_select, tuple, complex, window_reversal). Oczekujemy na opinię społeczności. Te operacje zostaną usunięte lub dodane do specyfikacji z pełną obsługą. Dopóki nie poznamy tych funkcji, gwarantujemy zgodność tylko przez 6 miesięcy.

Wykonanie

Wykonywanie sekwencyjne

Program StableHLO jest wykonywany przez podanie wartości wejściowych do funkcji main i obliczenie wartości wyjściowych. Wartości wyjściowe funkcji są obliczane przez wykonanie grafu operacji z korzenia w odpowiednim elemencie return.

Kolejność wykonywania jest określana przez implementację, o ile jest zgodna z przepływem danych, czyli jeśli operacje są wykonywane przed ich użyciem. W StableHLO wszystkie operacje o efektach ubocznych zużywają 1 token i generują 1 token (wiele tokenów można zmultipleksować w jeden token za pomocą funkcji after_all), więc kolejność wykonywania efektów ubocznych jest również zgodna z dataflow. Na przykład w programie poniżej są 2 możliwe zamówienia: %0%1%2return i %1%0%2return.

func.func @main() -> tensor<f64> {
  %0 = stablehlo.constant dense<1.0> : tensor<f64>
  %1 = stablehlo.constant dense<2.0> : tensor<f64>
  %2 = stablehlo.add %0, %1 : tensor<f64>
  return %2 : tensor<f64>
}

Bardziej oficjalnie proces StableHLO składa się z tych elementów: 1) program StableHLO, 2) stanów operacji (jeszcze niewykonany, już wykonany) oraz 3) wartości pośrednich, nad którymi dany proces pracuje. Proces zaczyna się od wartości wejściowych funkcji main, przechodzi przez wykres operacji aktualizowania stanów operacji i wartości pośrednich, a kończy się wartościami wyjściowymi. Dalsze formalizowanie jest jeszcze do ustalenia (#484).

równoległe wykonanie,

Programy StableHLO mogą być wykonywane równolegle, podzielone na siatkę procesów 2D num_replicas według num_partitions, z których obydwa mają typ ui32.

W siatce procesów StableHLO jednocześnie wykonywanych jest num_replicas * num_partitions procesów StableHLO. Każdy proces ma unikalny element process_id = (replica_id, partition_id), gdzie replica_idreplica_ids = range(num_replicas)partition_idpartition_ids = range(num_partitions) mają typ ui32.

Rozmiar siatki procesów jest znany statycznie w przypadku każdego programu (w przyszłości planujemy, aby był on częścią programów StableHLO #650), a pozycja w siatce procesów jest znana statycznie w przypadku każdego procesu. Każdy proces ma dostęp do swojej pozycji w siatce procesów za pomocą operacji replica_idpartition_id.

W siatce procesów programy mogą być takie same (w stylu „Jedno program, wiele danych”), różne (w stylu „Wiele programów, wiele danych”) lub mieścić się gdzieś pośrodku. W przyszłości planujemy wprowadzić obsługę innych idiomów definiowania równoległych programów StableHLO, w tym GSPMD (#619).

W ramach siatki procesów procesy są w większości niezależne od siebie – mają oddzielne stany operacji, oddzielne wartości wejścia/pośrednie/wyjścia i większość operacji jest wykonywana oddzielnie w ramach procesów, z wyjątkiem niewielkiej liczby operacji zbiorczych opisanych poniżej.

Ponieważ większość operacji wykonuje się tylko z użyciem wartości z tego samego procesu, zwykle odwołania do tych wartości za pomocą ich nazw są jednoznaczne. Jednak opis semantyki działań zbiorowych jest niewystarczający. Powoduje to powstanie zapisu name@process_id, który odwołuje się do wartości name w konkretnym procesie. (z tego punktu widzenia niekwalifikowana wartość name może być traktowana jako skrót od wartości name@(replica_id(), partition_id())).

Kolejność wykonywania procesów jest określana przez implementację, z wyjątkiem synchronizacji wprowadzonej przez komunikację punkt-punkt i operacje zbiorcze, jak opisano poniżej.

Komunikacja punkt-punkt

Procesy StableHLO mogą komunikować się ze sobą za pomocą kanałów StableHLO. Kanał jest reprezentowany przez dodatni identyfikator typu si64. Za pomocą różnych operacji można wysyłać wartości do kanałów i odbierać je z kanałów.

Dalsze formalizowanie, np. skąd pochodzą te identyfikatory kanałów, jak procesy programów je rozpoznają i jakie synchronizacja jest przez nie wprowadzana, jest jeszcze do ustalenia (#484).

Komunikacja strumieniowa

Każdy proces StableHLO ma dostęp do 2 interfejsów strumieniowego przesyłania danych:

  • Infeed, z których można odczytać treści.
  • Outfeed, na którym można zapisać dane.

W przeciwieństwie do kanałów, które służą do komunikacji między procesami, a więc mają procesy po obu końcach, dane wejściowe i dane wyjściowe mają zdefiniowany drugi koniec w ramach implementacji.

Dalszą formalizację, np. o tym, jak strumieniowa komunikacja wpływa na kolejność wykonywania i jaki rodzaj synchronizacji jest przez nie wprowadzana, to do ustalenia (#484).

Operacje zbiorcze

W StableHLO jest 6 zbiorowych operacji: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permutereduce_scatter. Wszystkie te operacje dzielą procesy w sieci procesów StableHLO na grupy procesów StableHLO i wykonują wspólne obliczenia w każdej z nich niezależnie od innych grup procesów.

W ramach każdej grupy procesów operacje zbiorcze mogą wprowadzać barierę synchronizacji. Dalsza formalizacja, np. określenie, kiedy dokładnie ma miejsce synchronizacja, jak dokładnie procesy docierają do tej bariery i co się dzieje, jeśli tego nie zrobią, jest jeszcze nierozstrzygnięta (#484).

Jeśli grupa procesów obejmuje komunikację między partycjami, czyli w grupie procesów są procesy o różnych identyfikatorach partycji, wykonanie operacji zbiorczej wymaga kanału, a operacja zbiorcza musi zawierać dodatnią wartość channel_id typu si64. Komunikacja między replikami nie wymaga kanałów.

Obliczenia wykonywane przez zbiorcze operacje są specyficzne dla poszczególnych operacji i są opisane w poszczególnych sekcjach dotyczących operacji. Jednak strategie, według których siatka procesów jest dzielona na grupy procesów, są wspólne dla tych operacji i opisane w tej sekcji. W formalnym ujęciu StableHLO obsługuje 4 strategię:

cross_replica

W ramach każdej grupy procesów odbywa się tylko komunikacja między replikami. Ta strategia przyjmuje replica_groups, czyli listę list identyfikatorów replik, i wylicza iloczyn kartezjański replica_groupspartition_ids. replica_groupsmusi zawierać unikalne elementy i objmować wszystkie replica_ids. Bardziej formalnie, używając składni Pythona:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Na przykład w przypadku wartości replica_groups = [[0, 1], [2, 3]]num_partitions = 2 funkcja cross_replica zwróci wartość [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

W każdej grupie procesów odbywa się tylko komunikacja między partycjami. Ta strategia wykorzystuje partition_groups – listę list identyfikatorów partycji – i oblicza kartezjański iloczyn wartości partition_groups przez replica_ids. partition_groups musi zawierać unikalne elementy i objmować wszystkie partition_ids. Bardziej oficjalnie, używając składni Pythona:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Na przykład w przypadku wartości partition_groups = [[0, 1]]num_replicas = 4 funkcja cross_partition zwróci wartość [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

W każdej grupie procesów mogą występować zarówno komunikaty między replikami, jak i między partycjami. Ta strategia wykorzystuje replica_groups – listę list identyfikatorów replik – i oblicza iloczy kartezjańskie każdego elementu replica_group według parametru partition_ids. replica_groups musi zawierać unikalne elementy i obejmować wszystkie replica_ids. Bardziej oficjalnie, używając składni Pythona:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Na przykład w przypadku wartości replica_groups = [[0, 1], [2, 3]]num_partitions = 2 funkcja cross_replica_and_partition zwróci wartość [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Ta strategia przyjmuje flattened_id_groups – listę list „spłaszczonego” identyfikatora procesu w formie replica_id * num_partitions + partition_id – i przekształca je w identyfikatory procesów. flattened_id_groups musi zawierać unikalne elementy i pokrywać wszystkie elementy process_ids. Bardziej formalnie, używając składni Pythona:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Na przykład w przypadku flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4num_partitions = 2, flattened_ids zwróci wartość [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Dokładność

Obecnie StableHLO nie gwarantuje dokładności liczbowej, ale może się to zmienić w przyszłości (#1156).

Semantyka wykonania kwantowanej operacji

Interpretacja kwantowanych operacji StableHLO może się różnić w zależności od wymagań i możliwości sprzętowych. Na przykład niektóre urządzenia mogą interpretować operacje kwantyzacji za pomocą strategii „dequantyzacja, wykonanie operacji zmiennoprzecinkowej i na koniec ponowna kwantyzacja”. Inne mogą wykonywać całe obliczenia z użyciem arytmetyki całkowitej. Dlatego interpretacja kwantyzowanych operacji StableHLO zależy wyłącznie od konkretnego wdrożenia. Interpretacja kwantyzacji hybrydowej (#1575) powinna opierać się na jej semantyce zgodnie ze specyfikacją (na stronie 1792).

Błędy

Programy StableHLO są weryfikowane za pomocą obszernego zbioru ograniczeń dotyczących poszczególnych operacji, co wyklucza wiele klas błędów przed czasem wykonywania. Nadal jednak mogą wystąpić błędy, np. przepełnienie liczb całkowitych, dostęp poza zakresem itp. O ile nie określono inaczej, wszystkie te błędy powodują zachowanie określone przez implementację, ale może się to w przyszłości zmienić (#1157).

Wyjątki dotyczące liczb zmiennoprzecinkowych

Wyjątkiem od tej reguły są wyjątki o typie zmiennoprzecinkowym w programach StableHLO, które mają dobrze zdefiniowane działanie. Operacje, które powodują wyjątki zdefiniowane przez standard IEEE-754 (nieprawidłowa operacja, dzielenie przez 0, przepełnienie, niedopełnienie lub nieprecyzyjne wyjątki) dają wyniki domyślne (zdefiniowane przez standard) i kontynuują wykonywanie bez podnoszenia odpowiedniej flagi stanu; podobnie jak obsługa wyjątku raiseNoFlag ze standardu. Wyjątki dotyczące operacji niestandardowych (np. złożonej arytmetyki i niektórych funkcji transcendentalnych) są zdefiniowane przez implementację.

niezgodności kształtów,

StableHLO obsługuje tensory o dynamicznej wielkości. Jednak kształty muszą być zgodne w czasie wykonywania, w przeciwnym razie zachowanie jest nieokreślone. StableHLO nie udostępnia bezpośrednio operacji, która może potwierdzić, że tensor ma określony kształt w czasie działania. Generowanie prawidłowego kodu jest obowiązkiem producenta.

Przykładem prawidłowego programu jest program poniżej. Jednak w czasie działania dokładne kształty obiektów %arg0 i %arg1 muszą być takie same. W przeciwnym razie działanie programu będzie niezdefiniowane:

func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
    %0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
    return %0 : tensor<?xi32>
}

Notacja

Do opisania składni w tym dokumencie użyto zmodyfikowanego rodzaju składni EBNF (ISO/IEC 14977:1996, Wikipedia). 1) regułę zdefiniowano za pomocą metody ::=, a nie =:

2) konkatenacja jest wyrażana za pomocą juxtaposition, a nie ,.

Do opisu semantyki (np. w sekcjach „Typy”, „Stałe” i „Operacje”) używamy formuł opartych na składni Pythona rozszerzonej o obsługę zwięzłego wyrażania operacji na tablicach, jak opisano poniżej. Takie rozwiązanie sprawdza się w przypadku małych fragmentów kodu, ale w rzadkich przypadkach, gdy potrzebne są większe fragmenty kodu, używamy składni Pythona, która jest zawsze wprowadzana bezpośrednio.

Wzory

Na przykładzie ze specyfikacji dot_general omówimy, jak działają formuły. Jedno z ograniczeń tej operacji wygląda tak: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Nazwy używane w tej formule pochodzą z 2 źródeł: 1) funkcji globalnych, np. dim, 2) definicji elementów członkowskich odpowiadającego elementu programu, np. lhs, lhs_batching_dimensions, rhsrhs_batching_dimensions, zdefiniowanych w sekcji „Wejścia” w funkcji dot_general.

Jak już wspomnieliśmy, składnia tej formuły jest oparta na Pythonie z kilkoma rozszerzeniami ułatwiającymi zwiężenie kodu. Aby lepiej zrozumieć tę formułę, przekształcimy ją w tradycyjną składnię Pythona.

A) W tych formułach używamy wyrażenia = do reprezentowania równości, więc pierwszym krokiem do uzyskania składni Pythona jest zastąpienie = przez == w ten sposób: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Formuły te obsługują też wielokropki (...), które zamieniają wyrażenia skalarne w wyrażenia tensorowe. Krótko mówiąc, f(xs...) oznacza mniej więcej „dla każdego wektora x w tensorze xs oblicz wektor f(x), a potem zwracaj wszystkie te wyniki wektorów jako wynik tensora”. W standardowej składni Pythona nasza przykładowa formuła zamienia się na: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Dzięki elipsom często można uniknąć pracy na poziomie poszczególnych skalarów. W niektórych trudnych przypadkach można jednak użyć nieformalnej składni na niższym poziomie, np. w formule start_indices[bi0, ..., :, ..., biN] ze specyfikacji gather. W trosce o zwiększenie zwięzłości nie podajemy dokładnego formalizmu, w którym można przetłumaczyć taką składnię na zwykły Python. Mamy nadzieję, że w każdym przypadku będzie ona intuicyjna. Jeśli zauważysz, że niektóre formuły są nieczytelne, daj nam znać, a spróbujemy je ulepszyć.

Zauważysz też, że formuły używają wielokropków do rozwijania wszelkich rodzajów list, w tym tensorów, list tensorów (które np. mogą powstać na podstawie różnej liczby tensorów) itp. W tym innym obszarze nie ma dokładnego formalności (np. listy nie są nawet częścią intuicyjnego systemu StaableHLO).

C) Ostatnim wartym uwagi sposobem zapisu, którego używamy, jest domyślne przesyłanie. Chociaż przesunięcie StableHLO nie obsługuje jawnego nadawania, formuły już tak, co również służy zwiększeniu zwięzłości. Krótko mówiąc, jeśli w kontekście, w którym oczekuje się tensora, używany jest element skalarny, element skalarny jest rozprowadzany do oczekiwanego kształtu.

Aby kontynuować przykład z dot_general, zastosuj tu kolejne ograniczenie: 0 <= lhs_batching_dimensions < rank(lhs). Zgodnie z definicją w specyfikacji dot_general element lhs_batching_dimensions jest tensorem, ale zarówno 0, jak i rank(lhs) są skalarami. Gdy zastosujemy przekazywanie niejawne, formuła zmieni się na [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Po zastosowaniu do konkretnej operacji dot_general ta formuła oceni tensor wartości logicznych. Gdy formuły są używane jako ograniczenia, ograniczenie jest spełnione, jeśli formuła zwraca wartość true lub tensor zawierający tylko elementy true.

Nazwy

W formułach zakres leksykalny obejmuje: 1) funkcje globalne, 2) definicje członków,

3) definicje lokalne. Poniżej znajdziesz listę funkcji globalnych. Lista definicji elementów zależy od elementu programu, do którego zastosowano notację:

  • W przypadku operacji definicje elementów obejmują nazwy wprowadzone w sekcjach „Wejścia” i „Wyjścia”.
  • W przypadku innych elementów definicje członków obejmują części strukturalne elementu programu, nazwane na podstawie odpowiednich nieterminali EBNF. W większości przypadków nazwy tych części strukturalnych są uzyskiwane przez konwersję nazw nieterminali na format snake case (np. IntegerLiteral => integer_literal), ale czasami nazwy są w tym procesie skracane (np. QuantizationStorageType => storage_type). W takim przypadku nazwy są wprowadzane w sposób jawny, podobnie jak sekcje „Wejścia” i „Wyjścia” w specyfikacjach operacji.
  • Definicje członków zawsze zawierają element self, który odnosi się do odpowiedniego elementu programu.

Wartości

Podczas obliczania formuł są one używane do obsługi tych typów wartości: 1) Value (rzeczywiste wartości, np. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; zawsze znają swój typ), 2) Placeholder (przyszłe wartości, np. lhs, rhs lub result; ich rzeczywiste wartości nie są jeszcze znane, znane są tylko ich typy), 3) Type (typy zdefiniowane w sekcji „Typy”), 4) Function (funkcje globalne zdefiniowane w sekcji „Funkcje”).

W zależności od kontekstu nazwy mogą się odnosić do różnych wartości. W szczególności sekcja „Semantyka” w przypadku operacji (i odpowiednich sekcji w przypadku innych elementów programu) definiuje logikę czasu wykonywania, dzięki czemu wszystkie dane wejściowe są dostępne jako Value. Z kolei sekcja „Ograniczenia” odnosząca się do operacji (i ich odpowiedników) definiuje logikę „czas kompilowania”, tj.coś, co jest zwykle wykonywane przed uruchomieniem. Dlatego jako Value dostępne są tylko stałe dane wejściowe, a inne dane wejściowe – tylko jako Placeholder.

Nazwy W sekcji „Semantyka” W sekcji „Ograniczenia”
Funkcje globalne Function Function
stałe wejścia, Value Value
Nieciągłe dane wejściowe Value Placeholder
Wyniki Value Placeholder
Definicje lokalne Zależy od definicji Zależy od definicji

Przeanalizujmy przykładową operację transpose:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

W przypadku tej operacji permutation jest stałą, więc jest dostępna jako Value zarówno w semantyce, jak i w ograniczeniach. Z kolei operandresult są dostępne jako Value w semantyce, ale tylko jako Placeholder w ograniczeniach.

Funkcje

Konstrukcja typów

Nie ma funkcji, których można używać do tworzenia typów. Zamiast tego używamy bezpośrednio składni typu, ponieważ jest ona zazwyczaj bardziej zwięzła. Na przykład (tensor<E>, tensor<E>) -> (tensor<E>) zamiast function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Funkcje w typach

  • element_type jest zdefiniowany na typach tensorów i typach zaokrąglonych tensorów oraz zwraca odpowiednio część TensorElementType lub QuantizedTensorElementType odpowiadającego TensorType lub QuantizedTensorType.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value to skrót do: is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value to skrót od is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool sprawdza, czy typ x można podnieść do typu y. Jeśli xy to QuantizedTensorElementType, promocja jest stosowana tylko do storage_type. Ta konkretna wersja promocji jest obecnie używana w kontekście obliczeń redukcji (więcej informacji znajdziesz w RFC).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value to skrót do is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Dostępne dla wszystkich typów. Na przykład is_float(x) zwraca true, jeśli x jest wartością typu FloatType. Jeśli x jest wartością lub miejscem zastępczym, ta funkcja jest skrótem funkcji is_type_name(type(x)).

  • max_value(x: Type) -> Value zwraca maksymalną wartość TensorElementType. Jeśli x nie jest TensorElementType, zwraca None.

  • min_value(x: Type) -> Value zwraca minimalną możliwą wartość TensorElementType. Jeśli x nie jest TensorElementType, zwraca None.

  • member_name(x: Value | Placeholder | Type) -> Any. Dostępne dla wszystkich definicji członków member_name wszystkich typów. Na przykład tensor_element_type(x) zwraca część TensorElementType odpowiadającego elementu TensorType. Jeśli x jest wartością lub zmienną, ta funkcja jest skrótem do member_name(type(x)). Jeśli x nie jest typem, który ma odpowiedni element, wartość lub zastępnik tego typu, zwraca None.

  • is_empty_algorithm(*args: Type) sprawdza, czy wszystkie pola algorytmu dot są ustawione na None. Jest to konieczne, ponieważ algorytmy dot mają określone przez implementację domyślne zachowania, więc podanie domyślnej wartości byłoby nieprawidłowe.

Budowa wartości

  • operation_name(*xs: Value | Type) -> Value. Dostępne w przypadku wszystkich operacji. Na przykład funkcja add(lhs, rhs) przyjmuje 2 wartości tensora lhs i rhs oraz zwraca wynik oceny operacji add z tymi danymi wejściowymi. W przypadku niektórych operacji, np. broadcast_in_dim, typy ich danych wyjściowych są „nośne”, czyli potrzebne do oceny operacji. W tym przypadku funkcja przyjmuje te typy jako argumenty.

Funkcje wartości

  • Dostępne są wszystkie operatory i funkcje Pythona. Na przykład w Pythonie dostępne są zarówno subskrypcje, jak i wycinki, które można indeksować w tensorach, kwantowych tensorach i tuplach.

  • to_destination_type(x: Value, destination_type: Type) -> Value jest zdefiniowany na tensorach i zwraca przekonwertowaną wartość x na podstawie wartości type(x) i destination_type w następujący sposób:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Trwają wstępne dyskusje na temat połączenia operacji convert, uniform_quantize i uniform_dequantize (#1576). Po scaleniu nie potrzebujesz powyższej funkcji i możemy zamiast niej użyć nazwy operacji dla funkcji convert.

  • Funkcja is_nan(x: Value) -> Value jest zdefiniowana na tensorach i zwraca wartość true, jeśli wszystkie elementy x są równe NaN, a w przeciwnym razie zwraca wartość false. Jeśli x nie jest tensorem, zwraca None.

  • Funkcja is_sorted(x: Value) -> Value jest zdefiniowana na tensorach i zwraca wartość true, jeśli elementy x są posortowane w kolejności rosnącej według rosnącej kolejności leksykograficznej ich indeksów lub false w innym przypadku. Jeśli x nie jest tensorem, zwraca None.

  • Funkcja is_unique(x: Value) -> Value jest zdefiniowana na tensorach i zwraca wartość true, jeśli x nie zawiera podwójnych elementów, a w przeciwnym razie zwraca wartość false. Jeśli x nie jest tensorem, zwraca None.

  • member_name(x: Value) -> Any jest zdefiniowany dla wszystkich definicji elementów member_name wszystkich wartości. Na przykład real_part(x) zwraca element RealPart należący do odpowiadającego mu elementu ComplexConstant. Jeśli x nie jest wartością, która ma odpowiedni element, zwraca None.

  • Funkcja same(x: Value) -> Value jest zdefiniowana na tensorach i zwraca wartość true, jeśli wszystkie elementy tensora x są sobie równe, a w przeciwnym razie zwraca wartość false. Jeśli tensor nie ma elementów, liczy się to jako „wszystkie są sobie równe”, tzn. funkcja zwraca true. Jeśli x nie jest tensorem, zwraca None.

  • split(x: Value, num_results: Value, axis: Value) -> Value jest zdefiniowany na tensorach i zwraca num_results przekroje x wzdłuż osi axis. Jeśli x nie jest tensorem ani dim(x, axis) % num_results != 0, zwraca None.

  • is_defined_in_parent_scope(x: Value) -> Value jest zdefiniowana na ciągach znaków i zwraca true, jeśli x to nazwa funkcji zdefiniowanej w tym samym zakresie co funkcja nadrzędna dla odpowiedniej op.

  • Argument is_namespaced_op_name(x: Value) -> Value jest zdefiniowany na ciągach znaków i zwraca wartość true, jeśli x jest prawidłową nazwą op, czyli spełnia to wyrażenie regularne: [a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+

Obliczenia kształtu

  • axes(x: Value | Placeholder | Type) -> Value to skrót do opcji range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value to skrót do shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List to skrót do list(map(lambda axis: dim(x, axis), axes)).

  • Funkcja index_space(x: Value | Placeholder | Type) -> Value jest zdefiniowana na tensorach i zwraca indeksy size(x) dla odpowiednich TensorType posortowanych w rosnącym porządku alfabetycznym, np. [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Jeśli x nie jest typem tensora, skonwertowanego typu tensora, wartości lub zastępnika jednego z tych typów, zwraca None.

  • rank(x: Value | Placeholder | Type) -> Value to skrót do size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value jest zdefiniowana w sekcji „Funkcje według typów” za pomocą member_name.

  • size(x: Value | Placeholder | Type) -> Value to skrót do reduce(lambda x, y: x * y, shape(x)).

Obliczenia kwantyzacji

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type to skrót od element_type(baseline_type(x)).

  • Funkcja baseline_type jest zdefiniowana dla typów tensorów i typów zaokrąglonych tensorów. Przekształca je w „wartość bazową”, czyli typ o tym samym kształcie, ale z parametrami zaokrąglenia typu elementu skonfigurowanymi na wartości domyślne. To przydatna sztuczka, która pozwala jednolicie porównać zarówno tensor, jak i kwantyzowane typy tensorów, co jest potrzebne dość często. W przypadku typów kwantyzowanych umożliwia to porównywanie typów ignorujących parametry kwantyzacji, czyli shape, storage_type, expressed_type, storage_min, storage_max i quantization_dimension (w przypadku typu kwantyzowanego na oś), ale wartości scales i zero points mogą się różnić.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize jest zdefiniowany na podstawie zliczanych typów tensorów i przekształca je w typy tensorów zmiennoprzecinkowych. Polega to na konwertowaniu elementów dyskretnych, które reprezentują wartości całkowite typu magazynowania, na odpowiadające im wartości zmiennoprzecinkowe typu wyrażonego za pomocą punktu zerowego i skali powiązanej z elementem dyskretnym.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize jest zdefiniowany na typach tensorów zmiennoprzecinkowych i przekształca je w typy skonwertowanych tensorów. Polega to na konwersji wartości zmiennoprzecinkowych wyrażonego typu na odpowiadające wartości całkowite typu magazynu za pomocą punktu zerowego i skali powiązanej z kwantowanym typem elementu.
def quantize(x: Value, result_type: Type) -> Value:
  assert is_float(x) and is_quantized(result_type)
  zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
  converted_zero_points = convert(zero_points, expressed_type(result_type))
  converted_min = convert(storage_min(result_type), expressed_type(result_type))
  converted_max = convert(storage_max(result_type), expressed_type(result_type))

  x_scaled = x / compute_scales(result_type, type(x))
  x_scaled_add_zp = x_scaled + converted_zero_points
  x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
  x_rounded = round_nearest_even(x_clamped)
  return convert(x_rounded, result_type)
  • dequantize_op_quantize służy do określania obliczeń elementarnych na kwantyzowaniach tensorów. Dequantuje, czyli zamienia elementy poddane kwantyzacji na ich wyrażone typy, a potem wykonuje operację, a następnie ponownie kwantyzuje, czyli zamienia wyniki z powrotem na ich typy magazynowania. Obecnie ta funkcja działa tylko w przypadku kwantyzacji natężenia. Trwa kwantyzacja według osi (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)
  • hybrid_dequantize_then_op służy do określania kwantyzacji tylko wagi dla operacji hybrydowej, która przyjmuje lewą stronę w typie zmiennoprzecinkowym, a prawą – w typie skwantowanym. Dekwantuje zagregowane dane wejściowe w ich wyrażonych typach i wykonuje obliczenia w typie float. Typ elementu wektora lewego argumentu typu float i wyrażony typ zagregowanego prawego argumentu powinny być identyczne.
def hybrid_dequantize_then_op(op, lhs, rhs):
  assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
  return op(lhs, dequantize(rhs))

Obliczenia siatki

  • cross_partition(replica_groups: Value) -> Value. Zobacz sekcję „cross_replica” powyżej.

  • cross_replica(replica_groups: Value) -> Value. Zobacz sekcję „cross_replica” powyżej.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Zobacz sekcję „cross_replica_and_partition” powyżej.

  • flattened_ids(replica_groups: Value) -> Value. Zobacz sekcję „flattened_ids” powyżej.

Dynamizm

Wartości StableHLO mogą mieć dynamiczne rozmiary wymiarów, np. tensor<?xi64>. Wartości StableHLO nie mogą jednak mieć dynamicznej liczby wymiarów (nieuporządkowany dynamizm, np. tensor<*xi64>). Obliczenia i wyniki mogą używać dynamicznych rozmiarów wymiarów, nawet jeśli istnieją ograniczenia rozmiarów. Jeśli to możliwe, ograniczenia są weryfikowane statycznie. W przeciwnym razie są odraczane do działania w czasie działania i niezgodności powodują niezdefiniowane zachowanie. Przykłady znajdziesz poniżej.

Niezgodność kształtów w przypadku operacji jednoelementowych

Rozważ ten program:

func.func @foo(%arg0: tensor<?xf64>) {
  %0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
  return
}

Taki program jest nietypowy, ponieważ zwykle wiadomo, jak wygląda wynik, ale nie wiadomo, jak wyglądają dane wejściowe. To jednak prawidłowy program StableHLO. W tym programie nie można statycznie zweryfikować operacji abs, ponieważ nie znamy dokładnego kształtu operandu. Kształty z pewnością są jednak zgodne i można to sprawdzić statycznie: okazało się, że w czasie działania obiekt ? to 2 i nie będzie żadnych problemów. Jednak ? może też okazać się inną liczbą całkowitą, w którym przypadku działanie jest nieokreślone.

Pamiętaj, że jeśli rozmiar wymiaru jest dynamiczny, nie może wystąpić niezdefiniowane działanie. Rzeczywiście nie ma „oczekiwanego” rozmiaru, więc nie może być niezgodności.

Niezgodność kształtów w przypadku operacji binarnych element po elemencie

Rozważ ten program:

func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
  %0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
  return
}

W przypadku operacji binarnych elementowych kształty danych wejściowych i wyniku muszą być zgodne w czasie wykonywania. Podczas kompilowania wymiary statyczne muszą być równe. W przeciwnym razie wystarczy, że będą tylko zgodne. Jeśli cokolwiek w danych wejściowych jest wymiarem dynamicznym, może to spowodować nieokreślone działanie w czasie wykonywania, ponieważ rozmiar dynamiczny może nie pasować do odpowiadającego mu rozmiaru w innym operandzie (czyli stałego lub dynamicznego). Jeśli wszystkie dane wejściowe są statyczne, nie ma znaczenia, czy wynik jest dynamiczny: wymiary znane statycznie będą sprawdzane statycznie, a wymiary dynamiczne nie narzucają żadnych ograniczeń.

Niezgodności kształtów w operacjach, które przyjmują kształt wyjściowy jako operand

Rozważ ten program:

func.func @foo(%arg0: tensor<2xi32>) {
  %0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
  return
}

Wartości w operandzie kształtu w czasie wykonywania programu muszą odpowiadać kształtowi wyniku. W przeciwnym razie zachowanie jest nieokreślone. Oznacza to, że w czasie wykonywania %arg0 musi mieć wartość dense<[3, 4]> : tensor<2xi32>. Jeśli operand kształtu jest stały, można to zweryfikować statycznie. Jeśli kształt wyniku jest w pełni dynamiczny, nie może wystąpić rozbieżność.