StableHLO to zestaw operacji wysokiego poziomu (HLO) w modelach systemów uczących się. StableHLO działa jako warstwa przenoszenia między różnymi platformami ML i kompilatorami ML: platformy ML tworzące programy StableHLO są zgodne z kompilatorami ML, które korzystają z programów StableHLO.
Naszym celem jest uproszczenie i przyspieszenie programowania ML przez zapewnienie większej interoperacyjności między różnymi platformami ML (np. TensorFlow, JAX i PyTorch) i kompilatorami ML (np. XLA i IREE). W tym celu zamieściliśmy specyfikację języka programowania StableHLO.
Specyfikacja zawiera 3 główne sekcje. Sekcja Programy opisuje strukturę programów StableHLO, które składają się z funkcji StableHLO, które składają się z operacji StableHLO. W tej strukturze sekcja Operacje określa semantykę poszczególnych operacji. Sekcja Wykonanie zawiera semantykę wszystkich tych operacji wykonywanych razem w programie. W sekcji Notacja omawiamy też zapis stosowany w całej specyfikacji.
Aby wyświetlić specyfikację z poprzedniej wersji StableHLO, otwórz repozytorium w oznaczonym wydaniu. Na przykład specyfikacja StableHLO w wersji 0.19.0. Aby wyświetlić zmiany wprowadzone w każdej małej wersji StableHLO, zapoznaj się z logiem wersji w pliku VhloDialect.td.
Programy
Program ::= {Func}
Programy StableHLO składają się z dowolnej liczby funkcji StableHLO.
Poniżej znajduje się przykładowy program z funkcją @main, która ma 3 parametry wejściowe (%image, %weights i %bias) oraz 1 wyjście. Treść funkcji zawiera 6 operacji.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funkcje
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Funkcje StableHLO (nazywane też funkcjami nazwanymi) mają identyfikator, wejścia/wyjścia i ciało. W przyszłości planujemy wprowadzić dodatkowe metadane funkcji, aby zwiększyć zgodność z HLO (#425, #626, #740, #744).
Identyfikatory
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Identyfikatory StableHLO są podobne do identyfikatorów w wielu językach programowania, ale mają 2 szczególne cechy: 1) wszystkie identyfikatory mają sygnatury, które odróżniają różne rodzaje identyfikatorów, 2) identyfikatory wartości mogą być całkowicie numeryczne, aby uprościć generowanie programów StableHLO.
Typy
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Typy StableHLO są podzielone na typy wartości (nazywane też typami najwyższej klasy), które reprezentują wartości StableHLO, oraz typy bez wartości, które opisują inne elementy programu. Typy StableHLO są podobne do typów w wielu językach programowania, a ich główną osobliwością jest specyfika danej domeny, która powoduje pewne nietypowe wyniki (np. typy skalarne nie są typami wartości).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Typy Tensor reprezentują tensory, czyli tablice wielowymiarowe. Mają one kształt i typ elementu, gdzie kształt reprezentuje nieujemne lub nieznane rozmiary wymiarów w kolejności rosnącej odpowiadających im wymiarów (zwanych też osiami), których numery wahają się od 0 do R-1. Liczba wymiarów R to ranking. Na przykład tensor<2x3xf32> to typ tensora o kształcie 2x3 i typie elementu f32. Ma 2 wymiary (czyli 2 osi) – 0 i 1, których rozmiary wynoszą odpowiednio 2 i 3. Jego pozycja to 2.
Kształty mogą być częściowo lub całkowicie nieznane (dynamiczne), np. tensor<?x2xf64> jest częściowo nieznany, a tensor<?x?xf64> jest całkowicie nieznany. Dynamiczne rozmiary wymiarów są reprezentowane za pomocą ?. Kształty nie mogą być niesklasyfikowane.
W przyszłości planujemy rozszerzyć typy tensorów poza rozmiary wymiarów i typy elementów, np. o układy (#629) i sparsity (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| Nazwa | Typ | Ograniczenia |
|---|---|---|
storage_type |
typ liczba całkowita | (C1-C3), (C8) |
storage_min |
stała liczbowa | (C1), (C3), (C7) |
storage_max |
stała liczbowa | (C2), (C3), (C7) |
expressed_type |
typ zmiennoprzecinkowy, | (C4) |
quantization_dimension |
opcjonalna liczba całkowita | (C10-C12) |
scales |
zmienna liczba stałych zmiennoprzecinkowych | (C4-C6), (C9), (C10), (C13) |
zero_points |
zmienna liczba stałych całkowitych | (C7-C9) |
Typy elementów z kwantyzacją reprezentują wartości całkowite typu magazynowania w zakresie od storage_min do storage_max (łącznie) odpowiadające wartościom zmiennoprzecinkowym typu wyrażonego. Dla danej wartości całkowitej i odpowiadającą wartość zmiennoprzecinkową f można obliczyć jako f = (i - zero_point) * scale, gdzie scale i zero_point to parametry kwantyzacji. Parametry storage_min i storage_max są opcjonalne w gramatyce, ale mają wartości domyślne odpowiednio min_value(storage_type) i max_value(storage_type). Elementy typu „kwantowany” mają te ograniczenia:
- (C1)
type(storage_min) = storage_type. - (C2)
type(storage_max) = storage_type. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type). - (C4)
type(scales...) = expressed_type. - (C5)
0 < scales. - (C6)
is_finite(scales...). - (C7)
storage_min <= zero_points <= storage_max. - (C8)
type(zero_points...) = storage_type. - (C9)
size(scales) = size(zero_points). - (C10) Jeśli
is_empty(quantization_dimension), tosize(scales) = 1. - (C11)
0 <= quantization_dimension.
Obecnie QuantizationScale jest stałą zmiennoprzecinkową, ale istnieje duże zainteresowanie skalami opartymi na liczbach całkowitych, reprezentowanymi przez mnożniki i przesunięcia. W najbliższej przyszłości planujemy wziąć to pod uwagę (#1404).
Trwa dyskusja na temat semantyki funkcji QuantizationZeroPoint, w tym jej typu, wartości i tego, czy w skwantyzowanym typie tensora może istnieć tylko jeden punkt zerowy, czy może on mieć większą liczbę punktów zerowych. Na podstawie wyników tej dyskusji specyfikacja dotycząca punktów 0 może w przyszłości ulec zmianie (#1405).
Kolejna trwająca dyskusja dotyczy semantyki QuantizationStorageMin i QuantizationStorageMax, aby określić, czy należy nałożyć jakieś ograniczenia na te wartości i wartości skończonych tensorów (#1406).
Na koniec planujemy zgłębić temat przedstawiania nieznanych skal i punktów zerowych w sposób podobny do tego, w jaki planujemy przedstawiać nieznane rozmiary wymiarów (#1407).
Typy kwantyzowanych tensorów reprezentują tensory z kwantyzowanymi elementami. Te tensory są dokładnie takie same jak zwykłe tensory, z tą różnicą, że ich elementy mają kwantowane typy elementów zamiast zwykłych typów elementów.
W przypadku skończonych tensorów kwantyzowanie może być na poziomie tensora, co oznacza, że dla całego tensora jest jeden element scale i zero_point, lub na poziomie osi, co oznacza, że dla danej osi danego wymiaru quantization_dimension jest wiele elementów scales i zero_points. Bardziej formalnie, w tensorze t z kwantyzacją na osi występują dim(t, quantization_dimension) krojenia quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] itd. Wszystkie elementy w itym krojeniu używają wartości scales[i] i zero_points[i] jako parametrów kwantowania. Typy zaokrąglonych tensorów mają te ograniczenia:
- W przypadku kwantyzacji na centrów:
- Brak dodatkowych ograniczeń.
- W przypadku kwantyzacji na oś:
- (C12)
quantization_dimension < rank(self). - (C13)
dim(self, quantization_dimension) = size(scales).
- (C12)
TokenType ::= 'token'
Typy tokenów to tokeny, czyli nieprzezroczyste wartości tworzone i wykorzystywane przez niektóre operacje. Tokeny są używane do określania kolejności wykonywania operacji zgodnie z opisem w sekcji Wykonywanie.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Typy tupla reprezentują tuple, czyli listy niejednorodne. Kropki to starsza funkcja,
która zapewnia zgodność z HLO. W HLO tuple służą do reprezentowania zmiennych danych wejściowych i wyjściowych. W StableHLO obsługiwane są natywny wejścia i wyjścia, a jedynym zastosowaniem tupletów w StableHLO jest kompleksowe reprezentowanie HLO ABI, w którym np. T, tuple<T> i tuple<tuple<T>> mogą się znacznie różnić w zależności od konkretnej implementacji. W przyszłości planujemy wprowadzić zmiany w interfejsie HLO ABI, które mogą pozwolić nam usunąć typy tuple z StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Typy elementów reprezentują elementy typu tensora. W odróżnieniu od wielu języków programowania te typy nie są typu first class w StableHLO. Oznacza to, że programy StableHLO nie mogą bezpośrednio przedstawiać wartości tych typów (w efekcie są to wartości skalarne typu T z wartościami 0-wymiarowymi tensora typu tensor<T>).
- Typ logiczny reprezentuje wartości logiczne
trueifalse. - Typy całkowite mogą być typu ze znakiem (
si) lub bez znaku (ui) i mieć jedną z obsługiwanych szerokości bitów (2,4,8,16,32lub64). Podpisane typysiNreprezentują liczby całkowite z zakresu od-2^(N-1)do2^(N-1)-1włącznie, a bez znakuuiN– wartości całkowite z zakresu od0do2^N-1. - Typy zmiennoprzecinkowe mogą być następujące:
f8E3M4,f8E4M3if8E5M28-bitowe liczby zmiennoprzecinkową zgodnie z konwencjami IEEE-754.- Typy
f8E4M3FNif8E5M2odpowiadające odpowiednio kodowaniomE4M3iE5M2formatu FP8 opisanego w artykule Formaty FP8 do uczenia głębokiego. - Typy
f8E4M3FNUZif8E5M2FNUZodpowiadające kodowaniomE4M3iE5M2formatów FP8 opisanych w artykule 8-bitowe formaty liczbowe na potrzeby głębokich sieci neuronowych. - Typ
f8E4M3B11FNUZodpowiadający kodowaniuE4M3formatów FP8 opisanych w artykule Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks (Trenowanie i wykonywanie wnioskowania w sieciach głębokich neuronowych przy użyciu hybrydowego formatu 8-bitowego liczby zmiennoprzecinkowej (HFP8)). - Typ
bf16odpowiadający formatowibfloat16opisanemu w artykule BFloat16: sekret wysokiej wydajności w Cloud TPU. - Typy
f16,f32if64odpowiadające odpowiednio formatombinary16(„półpełnej precyzji”),binary32(„pełnej precyzji”) ibinary64(„podwójnej precyzji”) opisanym w standardzie IEEE 754. - Typ
tf32odpowiada formatowi TensorFloat32 i ma ograniczoną obsługę w ramach StableHLO. - Typy
f4E2M1FN,f6E2M3FN,f6E3M2FNif8E8M0FNUMX (mikroskalowanie) opisane w specyfikacji formatów mikroskalowania OCP.
- Typy zespolone reprezentują wartości zespolone, które mają część rzeczywistą i część urojoną tego samego typu elementu. Obsługiwane złożone typy to
complex<f32>(obie części są typuf32) icomplex<f64>(obie części są typuf64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Typy funkcji reprezentują zarówno funkcje nazwane, jak i anonimowe. Mają typy wejścia (lista typów po lewej stronie ->) i typy wyjścia (lista typów po prawej stronie ->). W wielu językach programowania typy funkcji są typu first class, ale nie w StableHLO.
StringType ::= 'string'
Typ ciągu tekstowego reprezentuje sekwencje bajtów. W przeciwieństwie do wielu języków programowania typ ciągu znaków nie jest pierwszą klasą w StableHLO i służy tylko do określania statycznych metadanych elementów programu.
Operacje
Operacje StableHLO (nazywane też operacjami) stanowią zamknięty zbiór operacji wysokiego poziomu w modelach uczenia maszynowego. Jak już wspomnieliśmy, składnia StableHLO jest mocno inspirowana MLIR, co niekoniecznie jest najbardziej ergonomiczną alternatywą, ale prawdopodobnie najlepiej pasuje do celu StableHLO, którym jest zwiększenie interoperacyjności między frameworkami i kompilatorami uczenia maszynowego.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Operacje StableHLO (nazywane też operacjami) mają nazwę, dane wejściowe/wyjściowe i podpis. Nazwa składa się z prefiksu stablehlo. i mnemotechniki, która jednoznacznie identyfikuje jedno z obsługiwanych operacji. Pełną listę wszystkich obsługiwanych operacji znajdziesz poniżej.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Operacje wykorzystują dane wejściowe i generują dane wyjściowe. Dane wejściowe są podzielone na wartości wejściowe (obliczone podczas wykonywania), funkcje wejściowe (podawane statycznie, ponieważ w StableHLO funkcje nie są wartościami najwyższej klasy) oraz atrybuty wejściowe (również podawane statycznie). Rodzaj danych wejściowych i wyjściowych zużywanych oraz wytwarzanych przez operatora zależy od jego skrótu. Na przykład funkcja add
op pobiera 2 wartości wejściowe i zwraca 1 wartość wyjściową. Natomiast operator select_and_scatter wymaga 3 wartości wejściowych, 2 funkcji wejściowych i 3 atrybutów wejściowych.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Funkcje wejściowe (nazywane też funkcjami anonimowymi) są bardzo podobne do funkcji nazwanych, z tym że: 1) nie mają identyfikatora (stąd nazwa „anonimowa”) i 2) nie deklarują typów danych wyjściowych (typy danych są wywnioskowane z opcji return w ramach funkcji).
Składnia funkcji wejściowych zawiera obecnie nieużywaną część (patrz Unused produkcja powyżej), która jest tam ze względu na zgodność z MLIR. W MLIR występuje bardziej ogólna koncepcja „regionów”, które mogą mieć wiele „bloków” operacji połączonych ze sobą za pomocą operacji przeskoku. Te bloki mają identyfikatory odpowiadające Unused produkcji, dzięki czemu można je odróżnić od siebie.
StableHLO nie obsługuje wykonywania operacji przejść, więc odpowiadająca mu część składni MLIR nie jest używana (ale nadal jest).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Atrybuty wejściowe mają nazwę i wartość, która jest jedną z obsługiwanych stałych. Stanowią one podstawowy sposób określania metadanych statycznych elementów programu. Na przykład operator concatenate używa atrybutu dimension do określania wymiaru, wzdłuż którego wartości wejściowe są łączone. Podobnie operacja slice używa wielu atrybutów, takich jak start_indices i limit_indices, aby określić granice używane do wycinania wartości wejściowej.
Obecnie programy StableHLO działające w środowisku naturalnym mogą czasami zawierać atrybuty, które nie są opisane w tym dokumencie. W przyszłości planujemy uwzględnić te atrybuty w opcji StableHLO lub zabronić ich wykorzystywania w programach StableHLO. Oto lista tych atrybutów:
layout(#629).mhlo.frontend_attributes(#628).mhlo.sharding(#619).output_operand_aliases(#740).- metadane lokalizacji (#594);
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Podpis operacji składa się z typów wszystkich wartości wejściowych (lista typów po lewej stronie ->) oraz typów wszystkich wartości wyjściowych (lista typów po prawej stronie ->). Ściśle mówiąc, typy wejścia są zbędne, a typy wyjścia prawie zawsze są zbędne (ponieważ w przypadku większości operacji StableHLO typy wyjścia można wywnioskować z danych wejściowych). Mimo to op
signature jest celowo częścią składni StableHLO, aby zapewnić zgodność z MLIR.
Poniżej znajduje się przykład operacji, której skrót to select_and_scatter. Przyjmuje 3 wartości wejściowe (%operand, %source i %init_value), 2 funkcje wejściowe i 3 atrybuty wejściowe (window_dimensions, window_strides i padding).
Zwróć uwagę, że podpis operacji obejmuje tylko typy wartości wejściowych (ale nie typy funkcji i atrybutów wejściowych podane w tekście).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Stałe
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Stałe StableHLO mają literał i typ, które razem reprezentują wartość StableHLO. Typ jest zwykle częścią składni stałej, z wyjątkiem sytuacji, gdy jest jednoznaczny (np. stała logiczna ma jednoznacznie typ i1, podczas gdy stała całkowita może mieć kilka możliwych typów).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Stałe logiczne reprezentują wartości logiczne true i false. Stałe logiczne mają typ i1.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Stałe liczbowe reprezentują wartości liczb całkowitych za pomocą ciągów tekstowych, które używają zapisu dziesiętnego lub szesnastkowego. Inne systemy liczbowe, np. binarny czy octal, nie są obsługiwane. Stałe liczbowe są objęte tymi ograniczeniami:
- (C1)
is_wellformed(integer_literal, integer_type).
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Stałe zmiennoprzecinkowe reprezentują wartości zmiennoprzecinkowe za pomocą ciągów znaków, które używają zapisu dziesiętnego lub wykładniczego. Dodatkowo przy użyciu notacji szesnastkowej można bezpośrednio określać bazowe bity w formacie zmiennoprzecinkowym odpowiedniego typu. Stałe zmiennoprzecinkowe są objęte tymi ograniczeniami:
- (C1) Jeśli używany jest zapis w formacie innym niż szesnastkowy,
is_wellformed(float_literal, float_type). - (C2) Jeśli używana jest notacja szesnastkowa,
size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Stałe zespolone reprezentują wartości zespolone za pomocą list części rzeczywistej (pojawia się jako pierwsza) i części urojonej (pojawia się jako druga). Na przykład (1.0, 0.0) : complex<f32> oznacza 1.0 + 0.0i, a (0.0, 1.0) : complex<f32> oznacza 0.0 + 1.0i. Kolejność, w jakiej te części są przechowywane w pamięci, jest zdefiniowana przez implementację. Stałe złożone mają te ograniczenia:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type)). - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Stałe tensora reprezentują wartości tensora za pomocą zagnieżdżonych list określonych za pomocą notacji NumPy. Na przykład dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> reprezentuje wartość tensora z tym mapowaniem indeksów na elementy: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. Kolejność, w jakiej te elementy są przechowywane w pamięci, jest zdefiniowana przez implementację. Stałe tensora mają następujące ograniczenia:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type)), gdzie:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
- (C2)
has_shape(tensor_literal, shape(tensor_type)), gdzie:has_shape(element_literal: Syntax, []) = true.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).- w przeciwnym razie:
false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Zwartości skonwertowanego tensora reprezentują wartości skonwertowanego tensora za pomocą tej samej notacji co stałe tensora, a ich elementy są określone jako stałe ich typu magazynu. Poddane kwantyzacji stałe tensora mają następujące ograniczenia:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)). - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Literały łańcuchowe składają się z bajtów określonych za pomocą znaków ASCII i sekwencji ucieczki. Są one niezależne od kodowania, więc interpretacja tych bajtów jest zdefiniowana przez implementację. Literały łańcuchowe mają typ string.
Operacje
abs
Semantyka
Wykonuje elementarną operację abs na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- W przypadku liczb całkowitych ze znakiem: moduł liczby całkowitej.
- Dla liczb zmiennoprzecinkowych:
absz IEEE-754. - W przypadku liczb zespolonych: moduł zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(abs, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora | (C1-C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu signed integer lub zmiennoprzecinkowego albo tensor kwantyzowany na podstawie tensora | (K1–C2) |
Ograniczenia
- (C1)
shape(result) = shape(operand). - (C2)
baseline_element_type(result)zdefiniowano jako:complex_element_type(element_type(operand))jeżeliis_complex(operand).baseline_element_type(operand)w innych przypadkach.
Przykłady
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
dodaj
Semantyka
Wykonuje dodawanie elementów dwóch tensorów lhs i rhs i tworzy tensor result. W zależności od typu elementu:
- W przypadku wartości logicznych: operator logiczny LUB.
- W przypadku liczb całkowitych: dodawanie liczb całkowitych.
- Dla liczb zmiennoprzecinkowych:
additionz IEEE-754. - W przypadku liczb zespolonych: dodawanie zespolone.
- W przypadku typów skwantowanych:
dequantize_op_quantize(add, lhs, rhs, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor lub kwantyzowany tensor | (C1-C6) |
| (I2) | rhs |
tensor lub kwantyzowany tensor | (C1-C5), (C7) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor (tensor kwantowy) | (C1-C7) |
Ograniczenia
- Jeśli operacja używa niespakowanych tensorów:
- (C1)
type(lhs) = type(rhs) = type(result).
- (C1)
- Jeśli operacja używa kwantowanych tensorów:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result). - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result). - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result). - (C6) Jeśli
is_per_axis_quantized(lhs), toquantization_dimension(lhs) = quantization_dimension(result). - (C7) Jeśli
is_per_axis_quantized(rhs), toquantization_dimension(rhs) = quantization_dimension(result).
- (C2)
Przykłady
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantyka
Zapewnia, że operacje generujące inputs są wykonywane przed operacjami zależnymi od result. Wykonywanie tej operacji nie powoduje żadnych zmian. Służy ona tylko do ustanowienia zależności danych z poziomu result do inputs.
Dane wejściowe
| Etykieta | Nazwa | Typ |
|---|---|---|
| (I1) | inputs |
zmienna liczba token |
Wyniki
| Nazwa | Typ |
|---|---|
result |
token |
Przykłady
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantyka
W każdej grupie procesów w siatce procesów StableHLO konkatenuje wartości tensorów operands z każdego procesu wzdłuż wymiaru all_gather_dim i tworzy tensory results.
Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)jeślichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Następnie w każdym wierszu process_group:
operands...@receiver = [operand@sender for sender in process_group]za wszystkiereceiverwprocess_group.results...@process = concatenate(operands...@process, all_gather_dim)za wszystkieprocesswprocess_group.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operands |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1), (C6) |
| (I2) | all_gather_dim |
stała typu si64 |
(C1), (C6) |
| (I3) | replica_groups |
2-wymiarowa stała tensora typu si64 |
(C2-C4) |
| (I4) | channel_id |
stała typu si64 |
(C5) |
| (I5) | use_global_device_ids |
stała typu i1 |
(K5) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C6) |
Ograniczenia
- (C1)
0 <= all_gather_dim < rank(operands...). - (C2)
is_unique(replica_groups). - (C3) Parametr
size(replica_groups)jest zdefiniowany jako:num_replicas, jeśli używana jest właściwośćcross_replica.num_replicas, jeśli używana jest właściwośćcross_replica_and_partition.num_processes, jeśli używana jest właściwośćflattened_ids.
- (C4)
0 <= replica_groups < size(replica_groups). - (C5) Jeśli
use_global_device_ids = true, tochannel_id > 0. - (C6)
type(results...) = type(operands...)z wyjątkiem:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).
Przykłady
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantyka
W ramach każdej grupy procesów w siatce procesów StableHLO stosuje funkcję redukcji computation do wartości tensorów operands z każdego procesu i generuje tensory results.
Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)jeślichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Następnie w każdym wierszu process_group:
results...@process[result_index] = exec(schedule)dla niektórych drzew binarnychschedulegdzie:exec(node)=computation(exec(node.left), exec(node.right)).exec(leaf)=leaf.value.
scheduleto drzewo binarne zdefiniowane przez implementację, którego traversal w kolejności toto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operands |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C5), (C6) |
| (I2) | replica_groups |
zmienna liczba stałych tensora jednowymiarowego typu si64 |
(K1–C3) |
| (I3) | channel_id |
stała typu si64 |
(C4) |
| (I4) | use_global_device_ids |
stała typu i1 |
(C4) |
| (I5) | computation |
funkcja | (C5) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C6-C7) |
Ograniczenia
- (C1)
is_unique(replica_groups). - (C2)
size(replica_groups)zdefiniowano jako:num_replicas, jeśli używana jest właściwośćcross_replica.num_replicas, jeśli używana jest właściwośćcross_replica_and_partition.num_processes, jeśli używana jest właściwośćflattened_ids.
- (C3)
0 <= replica_groups < size(replica_groups). - (C4) Jeśli
use_global_device_ids = true, tochannel_id > 0. - (C5)
computationma typ(tensor<E>, tensor<E>) -> (tensor<E>), gdzieis_promotable(element_type(operand), E). - (C6)
shape(results...) = shape(operands...). - (C7)
element_type(results...) = E.
Przykłady
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantyka
W ramach każdej grupy procesów w siatce procesów StableHLO dzieli wartości tensorów operands wzdłuż split_dimension na części, rozprasza te części między procesami, konkatenuje rozproszone części wzdłuż concat_dimension i tworzy tensory results.
Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)jeżelichannel_id <= 0.cross_partition(replica_groups)jeżelichannel_id > 0.
Następnie w każdym wierszu process_group:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)dla wszystkichsenderwprocess_group.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]gdziereceiver_index = process_group.index(receiver).results...@process = concatenate(scattered_parts...@process, concat_dimension).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operands |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1-C3), (C9) |
| (I2) | split_dimension |
stała typu si64 |
(C1), (C2), (C9) |
| (I3) | concat_dimension |
stała typu si64 |
(C3), (C9) |
| (I4) | split_count |
stała typu si64 |
(C2), (C4), (C8), (C9) |
| (I5) | replica_groups |
2-wymiarowa stała tensora typu si64 |
(C5-C8) |
| (I6) | channel_id |
stała typu si64 |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C9) |
Ograniczenia
- (C1)
0 <= split_dimension < rank(operands...). - (C2)
dim(operands..., split_dimension) % split_count = 0. - (C3)
0 <= concat_dimension < rank(operands...). - (C4)
0 < split_count. - (C5)
is_unique(replica_groups). - (C6)
size(replica_groups)jest zdefiniowany jako:num_replicas, jeśli używana jest właściwośćcross_replica.num_partitions, jeśli używana jest właściwośćcross_partition.
- (C7)
0 <= replica_groups < size(replica_groups). - (C8)
dim(replica_groups, 1) = split_count. - (C9)
type(results...) = type(operands...)oprócz tych, któresplit_dimension != concat_dimension:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.
Przykłady
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
i
Semantyka
Wykonuje z punktu widzenia elementu ORAZ dwa tensory lhs i rhs, tworzy intensywność result. W zależności od typu elementu:
- W przypadku wartości logicznych: operator logiczny OR.
- W przypadku liczb całkowitych: bitowe ORAZ.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu logicznego lub całkowitego | (C1) |
| (I2) | rhs |
tensor typu logicznego lub całkowitego | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu logicznego lub całkowitego | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result).
Przykłady
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantyka
Wykonuje elementową operację atan2 na tensorach lhs i rhs, tworząc tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
atan2z IEEE-754. - W przypadku liczb zespolonych: complex atan2.
- W przypadku typów skwantowanych:
dequantize_op_quantize(atan2, lhs, rhs, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
| (I2) | rhs |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Przykłady
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantyka
Oblicza gradienty kilku wejść batch_norm_training z backpropagatinggrad_output i tworzy tensory grad_operand, grad_scale i grad_offset. Bardziej formalnie ta operacja może zostać wyrażona jako dekompozycja istniejących operacji StableHLO z użyciem składni Pythona:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
W przypadku typów skwantowanych wykonuje działanie dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1-C3), (C5) |
| (I2) | scale |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4), (C5) |
| (I3) | mean |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4) |
| (I4) | variance |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4) |
| (I5) | grad_output |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C2), (C3) |
| (I6) | epsilon |
stała typu f32 |
|
| (I7) | feature_index |
stała typu si64 |
(C1), (C5) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
grad_operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C2), (C3) |
grad_scale |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4) |
grad_offset |
Jednowymiarowy tensor typu kwantyzowanego typu zmiennoprzecinkowego lub na tensor | (C2), (C4) |
Ograniczenia
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,mean,variance,grad_output,grad_operand,grad_scaleigrad_offsetmają te same wartościbaseline_element_type. - (C3)
operand,grad_outputigrad_operandmają ten sam kształt. - (C4) Znaczniki
scale,mean,variance,grad_scaleigrad_offsetmają ten sam kształt. - (C5)
size(scale) = dim(operand, feature_index).
Przykłady
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantyka
Normalizuje tensor operand we wszystkich wymiarach z wyjątkiem wymiaru feature_index i tworzy tensor result. Bardziej formalnie tę operację można wyrazić jako dekompozycję istniejących operacji StableHLO za pomocą składni Pythona w ten sposób:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
W przypadku typów kwantowych wybiera dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub kwantyzowany tensor na poziomie procesora | (C1-C7) |
| (I2) | scale |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C3) |
| (I3) | offset |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4) |
| (I4) | mean |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (K5) |
| (I5) | variance |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C6) |
| (I6) | epsilon |
stała typu f32 |
|
| (I7) | feature_index |
stała typu si64 |
(C1), (C3-C6) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C2), (C7) |
Ograniczenia
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,mean,varianceiresultmają te same ustawieniabaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(mean) = dim(operand, feature_index). - (C6)
size(variance) = dim(operand, feature_index). - (C7)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantyka
Oblicza średnią i wariancję we wszystkich wymiarach oprócz wymiaru feature_index oraz normalizuje tensor operand, tworząc tensory output, batch_mean i batch_var. Bardziej formalnie ta operacja może być wyrażona jako dekompozycja istniejących operacji StableHLO z użyciem składni Pythona:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
W przypadku typów skwantowanych wykonuje działanie dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
| (I2) | scale |
Jednowymiarowy tensor kwantyzowany (zmiennoprzecinkowy lub na tensor) | (C2), (C3) |
| (I3) | offset |
1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora | (C2), (C4) |
| (I4) | epsilon |
stała typu f32 |
(C1), (C3-C6) |
| (I5) | feature_index |
stała typu si64 |
(C1), (C3-C6) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
output |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C7) |
batch_mean |
1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora | (C2), (C5) |
batch_var |
1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora | (C2), (C6) |
Ograniczenia
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,batch_mean,batch_varioutputmają ten sambaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(batch_mean) = dim(operand, feature_index). - (C6)
size(batch_var) = dim(operand, feature_index). - (C7)
baseline_type(output) = baseline_type(operand).
Przykłady
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantyka
Wykonuje operację bitcast na tensorze operand i tworzy tensor result, w którym bity całego tensora operand są ponownie interpretowane przy użyciu typu tensora result.
Bardziej formalnie, jeśli E = element_type(operand), E' = element_type(result) i R = rank(operand):
- Jeśli
num_bits(E') < num_bits(E),bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]). - Jeśli
num_bits(E') > num_bits(E),bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]). - Jeśli
num_bits(E') = num_bits(E),bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).
Funkcja bits zwraca reprezentację danej wartości w pamięci, a jej działanie jest definiowane przez implementację, ponieważ dokładna reprezentacja tensorów jest zdefiniowana przez implementację, a dokładna reprezentacja typów elementów również jest zdefiniowana w implementacji.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub kwantyzowany tensor | (C1-C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub kwantyzowany tensor | (K1–C2) |
Ograniczenia
- (C1) Przy założeniu, że
E = is_quantized(operand) ? storage_type(operand) : element_type(operand),E' = is_quantized(result) ? storage_type(result) : element_type(result)iR = rank(operand):- Jeśli
num_bits(E') = num_bits(E),shape(result) = shape(operand). - Jeśli
num_bits(E') < num_bits(E): rank(result) = R + 1.dim(result, i) = dim(operand, i)dla wszystkich0 <= i < R.dim(result, R) * num_bits(E') = num_bits(E).- Jeśli
num_bits(E') > num_bits(E): rank(result) = R - 1.dim(result, i) = dim(operand, i)za wszystkie0 <= i < R.dim(operand, R - 1) * num_bits(E) = num_bits(E').
- Jeśli
- (C2) Jeśli
is_complex(operand) or is_complex(result), tois_complex(operand) and is_complex(result).
Przykłady
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantyka
Rozwija wymiary lub rangę wejściowego tensora przez powielanie danych w tensorze operand i tworzenie tensora result. Formalnie:
result[result_index] = operand[operand_index] dla wszystkich d w ramach axes(operand):
operand_index[d] = 0jeżelidim(operand, d) = 1.operand_index[d] = result_index[broadcast_dimensions[d]]w innych przypadkach.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub kwantyzowany tensor | (C1-C2), (C5-C6) |
| (I2) | broadcast_dimensions |
Jednowymiarowa stała tensora typu si64 |
(C2-C6) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor (tensor kwantowy) | (C1), (C3), (C5-C6) |
Ograniczenia
- (C1)
element_type(result)otrzymuje:element_type(operand), jeśli!is_per_axis_quantized(operand).element_type(operand), z tym żequantization_dimension(operand),scales(operand)izero_points(operand)mogą się różnić odquantization_dimension(result),scales(result)izero_points(result)odpowiednio.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) W przypadku wszystkich
dwaxes(operand):dim(operand, d) = 1lubdim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Jeśli
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Jeśli
dim(operand, quantization_dimension(operand)) = 1, toscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
Przykłady
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
etui
Semantyka
Generuje dane wyjściowe w wyniku wykonania dokładnie 1 funkcji z funkcji branches w zależności od wartości index. Bardziej formalnie: result = selected_branch()gdzie:
selected_branch = branches[index]jeżeli0 <= index < size(branches).selected_branch = branches[-1]w innych przypadkach.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | index |
Tensor 0-wymiarowy typu si32 |
|
| (I2) | branches |
zmienna liczba funkcji | (C1-C4) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C4) |
Ograniczenia
- (C1)
0 < size(branches). - (C2)
input_types(branches...) = []. - (C3)
same(output_types(branches...)). - (C4)
type(results...) = output_types(branches[0]).
Przykłady
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
Cbrt
Semantyka
Wykonuje elementarną operację pierwiastka sześciennego na tensorze operand i tworzy tensor result. W zależności od typu elementu wykonuje te działania:
- W przypadku jednostek zmiennoprzecinkowych:
rootn(x, 3)w standardzie IEEE-754. - Liczby zespolone: pierwiastek sześcienny zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(cbrt, operand, type(result))
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantyka
Przeprowadza element po elemencie zaokrąglenie w górę tensora operand i generuje tensor result.
Realizuje operację roundToIntegralTowardPositive według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(ceil, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Semantyka
Oblicza rozkład Choleskiego dla zbioru macierzy.
W bardziej formalnym ujęciu dla wszystkich i w index_space(result) result[i0, ..., iR-3, :, :] jest rozkładem Cholesky'ego a[i0, ..., iR-3, :, :] w postaci dolnej macierzy trójkątnej (jeśli lower jest true) lub górnej macierzy trójkątnej (jeśli lower jest false).
Wartości wyjściowe w trójkącie przeciwległym, tj. odpowiednio ścisły górny trójkąt lub odpowiednio ścisły trójkąt dolny, są definiowane na podstawie implementacji.
Jeśli istnieje element i, w którym macierz wejściowy nie jest macierzową o dodatnie dodatnim, działanie jest niezdefiniowane.
W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | a |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1-C3) |
| (I2) | lower |
Stałe tensora 0-wymiarowego typu i1 |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(a) = baseline_type(result). - (C2)
2 <= rank(a). - (C3)
dim(a, -2) = dim(a, -1).
Przykłady
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
ograniczać (zakres)
Semantyka
Łączy każdy element tensora operand między wartością minimalną a maksymalną, tworząc tensor result. Bardziej formalnie: result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element), gdzie min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(clamp, min, operand, max, type(result)).
Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | min |
tensor lub tensor zagregowany z tensorów | (C1), (C3) |
| (I2) | operand |
tensor lub tensor zagregowany z tensorów | (C1-C4) |
| (I3) | max |
tensor lub tensor zagregowany z tensorów | (C2), (C3) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C4) |
Ograniczenia
- (C1)
rank(min) = 0 or shape(min) = shape(operand). - (C2)
rank(max) = 0 or shape(max) = shape(operand). - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max). - (C4)
baseline_type(operand) = baseline_type(result).
Przykłady
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantyka
W każdej grupie procesów w siatce procesów StableHLO wyślij wartość tensora operand z procesu źródłowego do procesów docelowych i utwórz tensor result.
Ta operacja dzieli siatkę procesów StableHLO na siatkę procesów process_groups, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)jeżelichannel_id <= 0.cross_partition(replica_groups)jeżelichannel_id > 0.
Następnie result@process jest obliczana jako:
operand@process_groups[i, 0], jeśli istniejei, który określa, że proces znajduje się w regionieprocess_groups[i].broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))w innym przypadku.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub tensor zagregowany z tensorów | (C3) |
| (I2) | replica_groups |
zmienna liczba stałych tensora jednowymiarowego typu si64 |
(C1), (C2) |
| (I3) | channel_id |
stała typu si64 |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C3) |
Ograniczenia
- (C1)
is_unique(replica_groups). - (C2)
0 <= replica_groups < N, gdzieNjest zdefiniowany jako:num_replicas, jeśli używana jest właściwośćcross_replica.num_partitions, jeśli używana jest właściwośćcross_partition.
- (C3)
type(result) = type(operand).
Przykłady
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantyka
W każdej grupie procesów w siatce procesów StableHLO wysyła wartość tensora operand z procesu źródłowego do procesu docelowego i tworzy tensor result.
Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:
cross_replica(source_target_pairs), jeślichannel_id <= 0.cross_partition(source_target_pairs)jeżelichannel_id > 0.
Później wartość result@process jest określana przez:
operand@process_groups[i, 0], jeśli istniejeitakie, żeprocess_groups[i, 1] = process.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
kwantowy tensor lub tensor kwantowy | (K5) |
| (I2) | source_target_pairs |
2-wymiarowa stała tensora typu si64 |
(C1-C4) |
| (I3) | channel_id |
stała typu si64 |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, gdzieNjest zdefiniowany jako:num_replicas, jeśli używana jest właściwośćcross_replica.num_partitions, jeśli używana jest właściwośćcross_partition.
- (C5)
type(result) = type(operand).
Przykłady
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
porównaj
Semantyka
Porównuje elementy tensorów lhs i rhs zgodnie z definicjami comparison_direction i compare_type, tworząc tensor result.
Wartości comparison_direction i compare_type mają następującą semantykę:
W przypadku typów elementów wartości logicznych i liczb całkowitych:
EQ:lhs = rhs.NE:lhs != rhs.GE:lhs >= rhs.GT:lhs > rhs.LE:lhs <= rhs.LT:lhs < rhs.
W przypadku typów elementów zmiennoprzecinkowych z compare_type = FLOAT operator op implementuje te operacje IEEE-754:
EQ:compareQuietEqual.NE:compareQuietNotEqual.GE:compareQuietGreaterEqual.GT:compareQuietGreater.LE:compareQuietLessEqual.LT:compareQuietLess.
W przypadku typów elementów zmiennoprzecinkowych z wartością compare_type = TOTALORDER operator używa kombinacji operacji totalOrder i compareQuietEqual z normy IEEE-754.
W przypadku złożonych typów elementów przeprowadzane jest porównanie leksykograficzne par (real, imag) za pomocą podanych wartości comparison_direction i compare_type.
Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych, gdy comparison_direction = GE, GT, LE lub LT (#560).
W przypadku typów poddanych kwantyzacji wykonuje działanie dequantize_compare(lhs, rhs,
comparison_direction).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor lub tensor zagregowany z tensorów | (C1-C3) |
| (I2) | rhs |
tensor lub tensor zagregowany z tensorów | (C1-C2) |
| (I3) | comparison_direction |
enum EQ, NE, GE, GT, LE i LT |
|
| (I4) | compare_type |
enum FLOAT, TOTALORDER, SIGNED i UNSIGNED |
(C3) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu logicznego | (K2) |
Ograniczenia
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs). - (C2)
shape(lhs) = shape(rhs) = shape(result). - (C3)
compare_typejest zdefiniowany jako:SIGNEDjeżeliis_signed_integer(element_type(lhs)).UNSIGNEDjeżeliis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).FLOATlubTOTALORDER, jeśliis_float(element_type(lhs)).FLOATjeżeliis_complex(element_type(lhs)).
Przykłady
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
złożone
Semantyka
Przeprowadza konwersję element po elemencie na wartość zespoloną z pary wartości rzeczywistych i urojonych, lhs i rhs, oraz generuje tensor result.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu f32 lub f64 |
(C1-C3) |
| (I2) | rhs |
tensor typu f32 lub f64 |
(C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zespolonego | (C2), (C3) |
Ograniczenia
- (C1)
type(lhs) = type(rhs). - (C2)
shape(result) = shape(lhs). - (C3)
element_type(result)ma typcomplex<E>, gdzieE = element_type(lhs).
Przykłady
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
wieloskładnikowa
Semantyka
Zawiera operację złożoną z innych operacji StableHLO, przyjmując inputs i composite_attributes oraz zwraca results. Semantyka operacji jest implementowana przez atrybut decomposition. Opcję composite można zastąpić jej rozkładem bez zmiany semantyki programu. Jeśli wstawienie dekompozycji w kod źródłowy nie zapewnia tej samej semantyki op, użyj custom_call.
Pole version (domyślnie 0) służy do informowania o zmianie semantyki elementu złożonego.
Dane wejściowe
| Etykieta | Nazwa | Typ |
|---|---|---|
| (I1) | inputs |
zmienna liczba wartości |
| (I2) | name |
stała typu string |
| (I3) | composite_attributes |
słownik atrybutów |
| (I4) | decomposition |
stała typu string |
| (I5) | version |
stała typu si32 |
Wyniki
| Nazwa | Typ |
|---|---|
results |
zmienna liczba wartości |
Ograniczenia
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
Przykłady
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
konkatenacja
Semantyka
Łączy elementy inputs wzdłuż wymiaru dimension w takim samym porządku jak podane argumenty i tworzy tensor result. Bardziej formalnie:
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], gdzie:
id = d0 + ... + dk-1 + kd.djest równydimension, ad0, ... todrozmiaryinputs.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (K1–C6) |
| (I2) | dimension |
stała typu si64 |
(C2), (C4), (C6) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C5–C6) |
Ograniczenia
- (C1)
same(element_type(inputs...)). - (C2)
same(shape(inputs...))z wyjątkiemdim(inputs..., dimension). - (C3)
0 < size(inputs). - (C4)
0 <= dimension < rank(inputs[0]). - (C5)
element_type(result) = element_type(inputs[0]). - (C6)
shape(result) = shape(inputs[0])oprócz:dim(result, dimension) = dim(inputs[0], dimension) + ....
Przykłady
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
stała
Semantyka
Tworzy tensor output na podstawie stałej value.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | value |
stała | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
output |
tensor lub kwantyzowany tensor | (C1) |
Ograniczenia
- (C1)
type(value) = type(output).
Przykłady
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
dokonają konwersji
Semantyka
Przeprowadza konwersję elementów z jednego typu na inny w tensorze operand i tworzy tensor result.
W przypadku konwersji boolean-to-any-supported-type wartość false jest zamieniana na 0, a wartość true na 1. W przypadku konwersji typu any-supported-type-to-boolean wartość 0 jest konwertowana na false, a wartości inne niż zero – na true. Poniżej znajdziesz informacje o tym, jak to działa w przypadku typów złożonych.
W przypadku konwersji zawierających liczby całkowite na liczbę całkowitą, liczbę zmiennoprzecinkową lub zmiennoprzecinkową na liczbę zmiennoprzecinkową wartość źródłową może być dokładnie reprezentowana w typie miejsca docelowego. W przeciwnym razie zachowanie jest nieokreślone (#180).
W przypadku konwersji zawierających floating-point-to-integer część ułamkowa jest skracana. Jeśli skrócona wartość nie może być przedstawiona w przypadku typu miejsca docelowego, zachowanie jest nieokreślone (#180).
Konwersja z typu zespolonego na typ zespolony działa tak samo jak konwersja z typu zmiennoprzecinkowego na typ zmiennoprzecinkowy w przypadku konwersji części rzeczywistej i części urojonej.
W przypadku konwersji z typu złożonego na dowolny inny typ i z dowolnego innego typu na złożony wartość wyobrażona źródła jest ignorowana lub docelowa wartość wyobrażona jest ustawiana na 0. Konwersja części rzeczywistej następuje zgodnie z konwersją na liczby zmiennoprzecinkowe.
Zasadniczo ta operacja może wyrażać dekwantyzację (przekształcanie tensorów regularnych w tensory regularne), kwantyzację (konwersja z tensorów regularnych na tensory kwantyzowane) i rekwantyzację (przekształcanie między kwantyzowanymi procesorami), ale obecnie mamy dla nich specjalne operacje – uniform_dequantize dla pierwszego przypadku użycia i uniform_quantize w drugim i trzecim przypadku użycia. W przyszłości te 2 operacje mogą zostać połączone w convert (#1576).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor | (C1) |
Ograniczenia
- (C1)
shape(operand) = shape(result).
Przykłady
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
splotu
Semantyka
Oblicza iloczyny skalarne między oknami lhs i wycinkami rhs oraz generuje result. Na diagramie poniżej widać, jak elementy w elementach result są obliczane na podstawie elementów lhs i rhs.
Bardziej formalnie rozważ następujące zmiany w danych wejściowych w postaci elementu lhs, aby można było tworzyć przedziały czasu lhs:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).lhs_window_strides = lhs_shape(1, window_strides, 1).lhs_padding = lhs_shape([0, 0], padding, [0, 0]).lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).
Ta zmiana kadrowania wykorzystuje te funkcje pomocnicze:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]gdziej[d] = i[permutation[d]].
Jeśli feature_group_count = 1 i batch_group_count = 1, to dla wszystkich output_spatial_index w index_space(dim(result, output_spatial_dimensions...)):result[result_shape(:, output_spatial_index, :)] = dot_product gdzie:
padding_value = constant(0, element_type(lhs)).padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Ta funkcja wydaje się nieużywana, więc planujemy ją usunąć w przyszłości (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).
Jeśli feature_group_count > 1:
lhses = split(lhs, feature_group_count, input_feature_dimension).rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Jeśli batch_group_count > 1:
lhses = split(lhs, batch_group_count, input_batch_dimension).rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).result = concatenate(results, output_feature_dimension).
W przypadku typów kwantowych wybiera dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result)).
W przypadku typów hybrydowych kwantyzowana wartość wylicza hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor lub tensor zagregowany z tensorów | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
| (I2) | rhs |
tensor lub kwantyzowany tensor | (C1), (C14–C16), (C25), (C27–C29), (C31–C34) |
| (I3) | window_strides |
Jednowymiarowa stała tensora typu si64 |
(C2-C3), (C25) |
| (I4) | padding |
2-wymiarowa stała tensora typu si64 |
(C4), (C25) |
| (I5) | lhs_dilation |
Jednowymiarowa stała tensora typu si64 |
(C5-C6), (C25) |
| (I6) | rhs_dilation |
1-wymiarowa stała tensora typu si64 |
(C7-C8), (C25) |
| (I7) | window_reversal |
1-wymiarowa stała tensora typu i1 |
(C9) |
| (I8) | input_batch_dimension |
stała typu si64 |
(C10), (C13), (C25) |
| (I9) | input_feature_dimension |
stała typu si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C12), (C13), (C25) |
| (I11) | kernel_input_feature_dimension |
stała typu si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
stała typu si64 |
(C15–C16), (C18), (C25), (C29) |
| (I13) | kernel_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C17-C18), (C25) |
| (I14) | output_batch_dimension |
stała typu si64 |
(C20), (C25) |
| (I15) | output_feature_dimension |
stała typu si64 |
(C20), (C25), (C30) |
| (I16) | output_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C19-C20), (C25) |
| (I17) | feature_group_count |
stała typu si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
stała typu si64 |
(C10), (C15), (C22), (C23), (C25) |
| (I19) | precision_config |
zmienna liczba typów enumeracji DEFAULT, HIGH i HIGHEST |
(C24) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor (tensor kwantowy) | (C25-C28), (C30), (C32-34) |
Ograniczenia
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Zgodnie z
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Podany
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Zgodnie z
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)jest zdefiniowany jako:dim(lhs, input_batch_dimension) / batch_group_countjeżeliresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)jeżeliresult_dim = output_feature_dimension.num_windowsw przeciwnym razie, gdzie:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Jeśli operacja używa tensorów niekwantowych:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Jeśli operacja używa kwantowanych tensorów:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Jeśli
is_per_axis_quantized(rhs), toquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Jeśli
is_per_axis_quantized(result), toquantization_dimension(result) = output_feature_dimension. - Jeśli
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Jeśli
is_per_tensor_quantized(rhs), tois_per_tensor_quantized(result). - Jeśli
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Przykłady
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
cosinus
Semantyka
Wykonuje elementarną operację cosinusa na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
cosz IEEE-754. - W przypadku liczb zespolonych: cosinus zespolony.
- W przypadku typów kwantowych:
dequantize_op_quantize(cosine, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantyka
Wykonuje element po elemencie zliczanie liczby początkowych zer w tensorze operandi tworzy tensor result.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu liczba całkowita | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(operand) = type(result).
Przykłady
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantyka
Zawiera zdefiniowaną przez implementację operację call_target_name, która przyjmuje argumenty inputs i called_computations oraz zwraca results. Atrybuty has_side_effect, backend_config i api_version mogą służyć do udostępniania dodatkowych metadanych zdefiniowanych przez implementację.
Obecnie ta operacja zawiera dość nieuporządkowaną kolekcję metadanych, która odzwierciedla organiczną ewolucję jej odpowiednika w kompilatorze XLA. W przyszłości planujemy ujednolicić te metadane (#741).
Dane wejściowe
| Etykieta | Nazwa | Typ |
|---|---|---|
| (I1) | inputs |
zmienna liczba wartości |
| (I2) | call_target_name |
stała typu string |
| (I3) | has_side_effect |
stała typu i1 |
| (I4) | backend_config |
stała typu string lub słownik atrybutów |
| (I5) | api_version |
stała typu si32 |
| (I6) | called_computations |
liczba zmiennoprzecinkowa typu string |
Wyniki
| Nazwa | Typ |
|---|---|
results |
liczba zmiennoprzecinkowa |
Przykłady
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
dzielenie
Semantyka
Wykonuje element po elemencie dzielenie tensorów dzielnika lhs i dzielenia rhs oraz zwraca tensor result. W zależności od typu elementu:
- W przypadku liczb całkowitych: dzielenie liczb całkowitych, które zwraca iloraz algebraiczny z pominięciem części ułamkowej.
- Dla liczb zmiennoprzecinkowych:
divisionz IEEE-754. - W przypadku liczb zespolonych: dzielenie zespolone.
- W przypadku typów skwantowanych:
dequantize_op_quantize(divide, lhs, rhs, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
| (I2) | rhs |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Przykłady
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantyka
Oblicza iloczyn skalarny między wycinkami lhs i rhs, uzyskując tensor result.
Bardziej formalnie: result[result_index] = dot_product, gdzie:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].result_batching_index + result_lhs_index + result_rhs_index = result_indexgdziesize(result_batching_index) = size(lhs_batching_dimensions),size(result_lhs_index) = size(lhs_result_dimensions)isize(result_rhs_index) = size(rhs_result_dimensions).transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).
W przypadku typów skwantowanych wykonuje operację dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).
W przypadku hybrydowych typów kwantyzacji wykonuje działanie hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs).
precision_config kontroluje kompromis między szybkością a dokładnością obliczeń na backendzie akceleratora. Może to być jedna z tych wartości (obecnie semantyka tych wartości jest niewystarczająco sprecyzowana, ale planujemy rozwiązać ten problem w bładze zgłoszenia 755):
DEFAULT: najszybsze obliczenia, ale najmniej dokładne przybliżenie do pierwotnego wyniku.HIGH: wolniejsze obliczenia, ale dokładniejsze przybliżenie do pierwotnej wartości.HIGHEST: najwolniejsze obliczenia, ale najbardziej dokładne przybliżenie do pierwotnej wartości.
DotAlgorithm definiuje główne właściwości algorytmu używanego do implementacji operacji kropki, która określa też dokładność. Jeśli pola atrybutów algorytmu są ustawione, precision_config musi mieć wartość DEFAULT. DotAlgorithms nie mają wartości domyślnej, ponieważ parametry domyślne są definiowane przez implementację. Dlatego wszystkie pola algorytmu kropki mogą być ustawione na None, aby określić pusty algorytm kropki, który zamiast tego użyje wartości precision_config.
Pola DotAlgorithm:
lhs_precision_typeirhs_precision_type, dokładność, do której zaokrąglane są wartości po lewej i po prawej stronie operacji. Typy dokładności są niezależne od typów magazynowania danych wejściowych i wyjściowych.accumulation_typedokładność użyta do skumulowania.lhs_component_count,rhs_component_countinum_primitive_operationsstosuje się, gdy używamy algorytmu, który rozkłada lewą lub prawą stronę na kilka komponentów i wykonuje na tych wartościach wiele „prostych” operacji dot – zwykle w celu emulacji większej precyzji (np. Korzystanie z typu danych bfloat16 do obliczeń o większej precyzji: bf16_6x tf32_3x itp.). W przypadku algorytmów bez dekompozycji te wartości powinny wynosić1.allow_imprecise_accumulation, aby określić, czy akumulacja z mniejszą dokładnością jest dozwolona w przypadku niektórych kroków (np.CUBLASLT_MATMUL_DESC_FAST_ACCUM).
Przykładowe atrybuty DotAlgorithm:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
To od implementacji zależy, które kombinacje są obsługiwane. Ogólnie nie ma gwarancji, że każdy algorytm jest obsługiwany przez każdy typ akceleratora przez użytkownika StableHLO. Jeśli dany algorytm nie jest obsługiwany, należy zgłosić błąd zamiast korzystać z alternatywnego rozwiązania. Weryfikacja StableHLO zapewni najlepszą weryfikację, zapobiegając działaniu algorytmów, które nie są obsługiwane na żadnym sprzęcie.
Niektóre obsługiwane wartości algorytmu znajdziesz w sekcji xla_data.proto > Algorithm. zgłoszenie #2483 zawiera plan stworzenia scentralizowanego dokumentu na temat obsługiwanych algorytmów na zapleczu.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor lub tensor zagregowany z tensorów | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
| (I2) | rhs |
tensor lub kwantyzowany tensor | (C7-C10), (C12-C20) |
| (I3) | lhs_batching_dimensions |
1-wymiarowa stała tensora typu si64 |
(C1), (C3), (C5), (C9), (C12) |
| (I4) | rhs_batching_dimensions |
1-wymiarowa stała tensora typu si64 |
(C1), (C4), (C7), (C9) |
| (I5) | lhs_contracting_dimensions |
1-wymiarowa stała tensora typu si64 |
(C2), (C3), (C6), (C10) |
| (I6) | rhs_contracting_dimensions |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C8), (C10), (C16) |
| (I7) | precision_config |
zmienna liczba typów enumeracji DEFAULT, HIGH i HIGHEST |
(C11), (C21) |
| (I8) | lhs_precision_type |
FloatType lub TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType lub TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType lub TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
stała typu si32 |
(C21), (C22) |
| (I12) | rhs_component_count |
stała typu si32 |
(C21), (C23) |
| (I13) | num_primitive_operations |
stała typu si32 |
(C21), (C24) |
| (I14) | allow_imprecise_accumulation |
stała typu bool |
(C21) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub kwantyzowany tensor | (C12), (C14), (C18–C20) |
Ograniczenia
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions). - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions). - (C5)
0 <= lhs_batching_dimensions < rank(lhs). - (C6)
0 <= lhs_contracting_dimensions < rank(lhs). - (C7)
0 <= rhs_batching_dimensions < rank(rhs). - (C8)
0 <= rhs_contracting_dimensions < rank(rhs). - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...). - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...). - (C11)
size(precision_config) = 2. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions). - Jeśli operacja używa niespakowanych tensorów:
- (C13)
element_type(lhs) = element_type(rhs).
- (C13)
- Jeśli operacja używa kwantowanych tensorów:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C15)
zero_points(rhs) = 0. - (C16) Jeśli
is_per_axis_quantized(rhs), toquantization_dimension(rhs)nie jest wrhs_contracting_dimensions. - Jeśli
is_quantized(lhs): - (C17)
storage_type(lhs) = storage_type(rhs). - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C19) Jeśli
is_per_tensor_quantized(rhs), tois_per_tensor_quantized(result). - Jeśli
!is_quantized(lhs): - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C14)
- Jeśli
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation):- (C21)
precision_config... = DEFAULT. - (C22)
0 < lhs_component_count. - (C23)
0 < rhs_component_count. - (C24)
0 < num_primitive_operations.
- (C21)
Przykłady
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z operacją broadcast_in_dim, ale kształt wyniku jest określany dynamicznie za pomocą parametru output_dimensions.
Operacja akceptuje też opcjonalne atrybuty known_expanding_dimensions i known_nonexpanding_dimensions, które służą do wyrażania statycznej wiedzy o zachowaniu rozszerzania wymiarów.
Jeśli nie są określone, przyjmuje się, że wszystkie wymiary mogą się rozszerzać.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor (tensor kwantowy) | (C1–C2), (C5–C6), (C9) |
| (I2) | output_dimensions |
Jednowymiarowy tensor typu liczby całkowitej | (C7) |
| (I3) | broadcast_dimensions |
1-wymiarowy tensor stały typu całkowitego | (C2-C6) |
| (I4) | known_expanding_dimensions |
Jednowymiarowy tensor stały typu liczby całkowitej | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
1-wymiarowy tensor stały typu całkowitego | (C8–C9) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub kwantyzowany tensor | (C1), (C3), (C5-C7) |
Ograniczenia
- (C1)
element_type(result)otrzymuje:element_type(operand), jeśli!is_per_axis_quantized(operand).element_type(operand), z tym żequantization_dimension(operand),scales(operand)izero_points(operand)mogą się różnić odquantization_dimension(result),scales(result)izero_points(result)odpowiednio.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) W przypadku wszystkich
dwaxes(operand):dim(operand, d) = 1lubdim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Jeśli
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Jeśli
dim(operand, quantization_dimension(operand)) = 1, toscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
- (C7)
size(output_dimensions) = rank(result). - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions). - (C9)
0 <= known_expanding_dimensions < rank(operand). - (C10)
0 <= known_nonexpanding_dimensions < rank(operand).
Przykłady
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z operacją convolution, ale wypełnienie jest podawane dynamicznie za pomocą padding.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor lub tensor zagregowany z tensorów | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
| (I2) | rhs |
tensor lub kwantyzowany tensor | (C1), (C14-C16), (C26-C28), (C30-C33) |
| (I3) | padding |
2-wymiarowy tensor typu całkowitego | (C4) |
| (I4) | window_strides |
Jednowymiarowa stała tensora typu si64 |
(C2-C3) |
| (I5) | lhs_dilation |
Jednowymiarowa stała tensora typu si64 |
(C5-C6) |
| (I6) | rhs_dilation |
1-wymiarowa stała tensora typu si64 |
(C7-C8) |
| (I7) | window_reversal |
1-wymiarowa stała tensora typu i1 |
(C9) |
| (I8) | input_batch_dimension |
stała typu si64 |
(C10), (C13) |
| (I9) | input_feature_dimension |
stała typu si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C12), (C13) |
| (I11) | kernel_input_feature_dimension |
stała typu si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
stała typu si64 |
(C15-C16), (C18), (C28) |
| (I13) | kernel_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C17-C18) |
| (I14) | output_batch_dimension |
stała typu si64 |
(C20) |
| (I15) | output_feature_dimension |
stała typu si64 |
(C20), (C29) |
| (I16) | output_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C19-C20) |
| (I17) | feature_group_count |
stała typu si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
stała typu si64 |
(C10), (C15), (C22), (C23) |
| (I19) | precision_config |
zmienna liczba typów enumeracji DEFAULT, HIGH i HIGHEST |
(C24) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub kwantyzowany tensor | (C25-C27), (C29), (C31-C33) |
Ograniczenia
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Zgodnie z
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Podany
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Zgodnie z
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)jest zdefiniowany jako:dim(lhs, input_batch_dimension) / batch_group_countjeżeliresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)jeżeliresult_dim = output_feature_dimension.num_windowsw przeciwnym razie, gdzie:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Jeśli operacja używa tensorów niekwantowych:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Jeśli operacja używa kwantowanych tensorów:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Jeśli
is_per_axis_quantized(rhs), toquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Jeśli
is_per_axis_quantized(result), toquantization_dimension(result) = output_feature_dimension. - Jeśli
is_quantized(lhs): - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Jeśli
is_per_tensor_quantized(rhs), tois_per_tensor_quantized(result). - Jeśli
!is_quantized(lhs): - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Przykłady
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z gather
op, z tym że slice_sizes jest dynamicznie określany jako wartość.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1), (C7), (C10-C12), (C14) |
| (I2) | start_indices |
tensor typu liczba całkowita | (C2), (C3), (C13) |
| (I3) | slice_sizes |
1-wymiarowy tensor typu całkowitego | (C8), (C11–C13) |
| (I4) | offset_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C4–C5), (C13) |
| (I5) | collapsed_slice_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C6-C8), (C13) |
| (I6) | start_index_map |
1-wymiarowa stała tensora typu si64 |
(C3), (C9), (C10) |
| (I7) | index_vector_dim |
stała typu si64 |
(C2), (C3), (C13) |
| (I8) | indices_are_sorted |
stała typu i1 |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C5), (C13–C14) |
Ograniczenia
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims). - (C7)
0 <= collapsed_slice_dims < rank(operand). - (C8)
slice_sizes[collapsed_slice_dims...] <= 1. - (C9)
is_unique(start_index_map). - (C10)
0 <= start_index_map < rank(operand). - (C11)
size(slice_sizes) = rank(operand). - (C12)
0 <= slice_sizes <= shape(operand). - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)gdzie:batch_dim_sizes = shape(start_indices), z tym że nie uwzględnia wymiarustart_indicesodpowiadającego wymiarowiindex_vector_dim.offset_dim_sizes = shape(slice_sizes), z tym że nie uwzględnia wymiarówslice_sizesodpowiadających wymiarowicollapsed_slice_dims.- Funkcja
combineumieszczabatch_dim_sizesna osi odpowiadającejbatch_dims, a funkcjaoffset_dim_sizes– na osi odpowiadającejoffset_dims.
- (C14)
element_type(operand) = element_type(result).
Przykłady
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z operacją iota, ale kształt wyniku jest dynamicznie określany za pomocą parametru output_shape.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | output_shape |
1-wymiarowy tensor typu całkowitego | (C1), (C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C2) |
Ograniczenia
- (C1)
0 <= iota_dimension < size(output_shape). - (C2)
rank(result) = size(output_shape).
Przykłady
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantyka
Ta operacja jest funkcjonalnie identyczna z pad, ale z wartościami edge_padding_low, edge_padding_high i interior_padding określonymi dynamicznie.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensor 0-wymiarowy lub kwantowy tensor na tensor | (C1) |
| (I3) | edge_padding_low |
Jednowymiarowy tensor typu liczby całkowitej | (C1), (C4) |
| (I4) | edge_padding_high |
1-wymiarowy tensor typu całkowitego | (C1), (C4) |
| (I5) | interior_padding |
Jednowymiarowy tensor typu liczby całkowitej | (C2-C4) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C3-C6) |
Ograniczenia
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Przykłady
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z operacją reshape, ale kształt wyniku jest dynamicznie określany za pomocą argumentu output_shape.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub kwantyzowany tensor | (K1–C3) |
| (I2) | output_shape |
1-wymiarowy tensor typu całkowitego | (C4) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub kwantyzowany tensor | (C1–C4) |
Ograniczenia
- (C1)
element_type(result)otrzymuje:element_type(operand), jeśli!is_per_axis_quantized(operand).element_type(operand)– oprócz tychquantization_dimension(operand)iquantization_dimension(result)mogą się różnić.
- (C2)
size(operand) = size(result). - (C3) Jeśli
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
- (C4)
size(output_shape) = rank(result).
Przykłady
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantyka
Wyodrębnia wycinek z operand, używając dynamicznie obliczanych indeksów początkowych i tworzy tensor result. start_indices zawiera indeksy początkowe przekroju dla każdego wymiaru, który może zostać dostosowany, a slice_sizes zawiera rozmiary przekroju dla każdego wymiaru. Bardziej oficjalnie:
result[result_index] = operand[operand_index], gdzie:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).operand_index = adjusted_start_indices + result_index.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1), (C2), (C4) |
| (I2) | start_indices |
zmienna liczba 0-wymiarowych tensorów typu całkowitego | (C2), (C3) |
| (I3) | slice_sizes |
Jednowymiarowa stała tensora typu si64 |
(C2), (C4), (C5) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1), (C5) |
Ograniczenia
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(slice_sizes) = rank(operand). - (C3)
same(type(start_indices...)). - (C4)
0 <= slice_sizes <= shape(operand). - (C5)
shape(result) = slice_sizes.
Przykłady
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantyka
Tworzy tensor result, który jest równy tensorowi operand, ale wycinek zaczynający się od start_indices jest aktualizowany o wartości w update.
Bardziej formalnie result[result_index] jest zdefiniowana jako:
update[update_index]jeśli0 <= update_index < shape(update)gdzie:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).update_index = result_index - adjusted_start_indices.
- W przeciwnym razie:
operand[result_index].
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1-C4), (C6) |
| (I2) | update |
tensor lub tensor zagregowany z tensorów | (C2), (C3), (C6) |
| (I3) | start_indices |
zmienna liczba 0-wymiarowych tensorów typu całkowitego | (C4), (C5) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
type(operand) = type(result). - (C2)
element_type(update) = element_type(operand). - (C3)
rank(update) = rank(operand). - (C4)
size(start_indices) = rank(operand). - (C5)
same(type(start_indices...)). - (C6)
0 <= shape(update) <= shape(operand).
Przykłady
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
wykładniczo
Semantyka
Wykonuje elementarną operację wykładniczą na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
expz IEEE-754. - Liczby zespolone: wykładnik zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(exponential, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantyka
Wykonuje elementarną operację wykładniczą minus 1 na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
expm1z IEEE-754. - W przypadku liczb zespolonych: zespolona wykładnicza wartość minus 1.
- W przypadku typów skwantowanych:
dequantize_op_quantize(exponential_minus_one, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
FFT
Semantyka
Wykonuje bezpośrednie i odwrotne transformacje Fouriera w przypadku rzeczywistych i zespolonych wejść/wyjść.
fft_type może mieć jedną z tych wartości:
FFT: FFT kompleksowe w przód.IFFT: odwrotna transformacja FFT z kompleksowej na złożoną.RFFT: przesuwanie widoku w kierunku rzeczywistym do złożonego.IRFFT: odwrotna transformacja Fouriera z realnego na zespolony (czyli przyjmuje zespolony, zwraca rzeczywisty).
Bardziej formalnie, jeśli dana funkcja fft przyjmuje jako dane wejściowe 1-wymiarowe tensory złożonych typów, to zwraca 1-wymiarowe tensory tego samego typu co dane wyjściowe i wylicza dyskretną transformację Fouriera:
W przypadku fft_type = FFT wartość result jest zdefiniowana jako końcowy wynik serii obliczeń L, gdzie L = size(fft_length). Na przykład w przypadku L = 3:
result1[i0, ..., :] = fft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Ponadto funkcja ifft, która ma ten sam podpis typu i oblicza odwrotność fft:
W przypadku fft_type = IFFT wartość result jest zdefiniowana jako odwrotna wartość obliczeń dla fft_type = FFT. Na przykład w przypadku L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = ifft(result2[i0, ..., :]).
Funkcja rfft przyjmuje 1-wymiarowe tensory typu zmiennoprzecinkowego i tworzy 1-wymiarowe tensory typu zespolonego o tej samej semantyce zmiennoprzecinkowej. Działa ona w ten sposób:
rfft(real_operand) = truncated_resultgdziecomplex_operand... = (real_operand..., 0.0).complex_result = fft(complex_operand).truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].
(gdy dyskretna transformacja Fouriera jest obliczana dla rzeczywistych operandów, pierwsze
N/2 + 1 elementy wyniku jednoznacznie określają pozostałą część wyniku,
dlatego wynik rfft jest obcinany, aby uniknąć obliczania zbędących elementów).
W przypadku funkcji fft_type = RFFT wartość result jest definiowana jako końcowy wynik serii obliczeń L, gdzie L = size(fft_length). Na przykład w przypadku L = 3:
result1[i0, ..., :] = rfft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Na koniec, jeśli dana funkcja irfft ma tę samą deklarację typu i oblicza odwrotność funkcji rfft:
W przypadku fft_type = IRFFT wartość result jest zdefiniowana jako odwrotna wartość obliczeń dla fft_type = RFFT. Na przykład w przypadku L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = irfft(result2[i0, ..., :]).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego | (C1), (C2), (C4), (C5) |
| (I2) | fft_type |
enum FFT, IFFT, RFFT i IRFFT |
(C2), (C5) |
| (I3) | fft_length |
1-wymiarowa stała tensora typu si64 |
(C1), (C3), (C4) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego | (C2), (C4), (C5) |
Ograniczenia
- (C1)
size(fft_length) <= rank(operand). - (C2) Związek między typami elementów
operandiresultjest różny:- Jeśli
fft_type = FFT,element_type(operand)ielement_type(result)mają ten sam typ złożony. - Jeśli
fft_type = IFFT,element_type(operand)ielement_type(result)mają ten sam typ złożony. - Jeśli
fft_type = RFFT,element_type(operand)jest typem zmiennoprzecinkowym, aelement_type(result)jest złożonym typem o tej samej semantyce zmiennoprzecinkowej. - Jeśli
fft_type = IRFFT,element_type(operand)jest typem złożonym, aelement_type(result)jest typem zmiennoprzecinkowym o tej samej semantyce zmiennoprzecinkowej.
- Jeśli
- (C3)
1 <= size(fft_length) <= 3. - (C4) Jeśli wśród elementów
operandiresultznajduje się tensorrealtypu zmiennoprzecinkowego, toshape(real)[-size(fft_length):] = fft_length. - (C5)
shape(result) = shape(operand)z wyjątkiem:- Jeśli
fft_type = RFFT,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1. - Jeśli
fft_type = IRFFT,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.
- Jeśli
Przykłady
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
piętro
Semantyka
Przeprowadza element po elemencie zaokrąglenie w dół tensora operand i generuje tensor result.
Realizuje operację roundToIntegralTowardNegative według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(floor, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
zbierać
Semantyka
Gromadzi wycinki z tensora operand z przesunięć określonych w start_indices i tworzy tensor result.
Ten diagram pokazuje na konkretnym przykładzie, jak elementy w result są mapowane na elementy w operand. Diagram wybiera kilka przykładowych result indeksów i szczegółowo wyjaśnia, którym indeksom operand odpowiadają indeksy.
Bardziej formalnie: result[result_index] = operand[operand_index], gdzie:
batch_dims = [d for d in axes(result) and d not in offset_dims].batch_index = result_index[batch_dims...].start_indexjest zdefiniowany jako:start_indices[bi0, ..., :, ..., biN], gdziebito poszczególne elementy wbatch_index, a element:jest wstawiany w indeksieindex_vector_dim, jeśliindex_vector_dim<rank(start_indices).[start_indices[batch_index]]w innych przypadkach.
- W przypadku
d_operandwaxes(operand):full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])ifd_operand = start_index_map[d_start].- W przeciwnym razie:
full_start_index[d_operand] = 0.
- W przypadku
d_operandwaxes(operand):full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]jeślid_operand = operand_batching_dims[i_batching]id_start = start_indices_batching_dims[i_batching].full_batching_index[d_operand] = 0w innych przypadkach.
offset_index = result_index[offset_dims...].full_offset_index = [oi0, ..., 0, ..., oiN], gdzieoito poszczególne elementy w tablicyoffset_index, a element0jest wstawiany na pozycjach z tabliccollapsed_slice_dimsioperand_batching_dims.operand_index = full_start_index + full_batching_index + full_offset_index.
Jeśli indices_are_sorted ma wartość true, implementacja może zakładać, że dane start_indices są posortowane według argumentu start_index_map. W przeciwnym razie działanie jest niezdefiniowane. Bardziej formalnie, w przypadku wszystkich i1 < i2 z indices(result):full_start_index(i1) <= full_start_index(i2).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
| (I2) | start_indices |
tensor typu liczba całkowita | (C2–C3), (C14), (C17), (C22) |
| (I3) | offset_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C4-C5), (C22) |
| (I4) | collapsed_slice_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C6-C9), (C22) |
| (I5) | operand_batching_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
| (I6) | start_indices_batching_dims |
1-wymiarowa stała tensora typu si64 |
(C13-C17) |
| (I7) | start_index_map |
1-wymiarowa stała tensora typu si64 |
(C3), (C18-C19) |
| (I8) | index_vector_dim |
stała typu si64 |
(C2-C3), (C15), (C22) |
| (I9) | slice_sizes |
1-wymiarowa stała tensora typu si64 |
(C9), (C12), (C20–C22) |
| (I10) | indices_are_sorted |
stała typu i1 |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C5), (C22-C23) |
Ograniczenia
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims). - (C8)
0 <= collapsed_slice_dims < rank(operand). - (C9)
slice_sizes[collapsed_slice_dims...] <= 1. - (C10)
is_sorted(operand_batching_dims). - (C11)
0 <= operand_batching_dims < rank(operand). - (C12)
slice_sizes[operand_batching_dims...] <= 1. - (C13)
is_unique(start_indices_batching_dims). - (C14)
0 <= start_indices_batching_dims < rank(start_indices). - (C15)
index_vector_dim not in start_indices_batching_dims. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims). - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...). - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims)). - (C19)
0 <= start_index_map < rank(operand). - (C20)
size(slice_sizes) = rank(operand). - (C21)
0 <= slice_sizes <= shape(operand). - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)gdzie:batch_dim_sizes = shape(start_indices), z tym że nie uwzględnia wymiarustart_indicesodpowiadającego wymiarowiindex_vector_dim.offset_dim_sizes = slice_sizes, z tą różnicą, że rozmiary wymiarów w poluslice_sizesodpowiadające wartościomcollapsed_slice_dimsioperand_batching_dimsnie są uwzględniane.- Funkcja
combineumieszczabatch_dim_sizesna osiach odpowiadających wartościbatch_dimsioffset_dim_sizesna osiach odpowiadających wartościoffset_dims.
- (C23)
element_type(operand) = element_type(result).
Przykłady
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantyka
Zwraca rozmiar danego dimension w ramach operand. Bardziej formalnie:
result = dim(operand, dimension). Semantyka dotyczy tylko komponentu kształtu danego typu. Typ elementu może być dowolny.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub kwantyzowany tensor | (C1) |
| (I2) | dimension |
stała typu si64 |
(C1) |
Wyniki
| Nazwa | Typ |
|---|---|
result |
Tensor 0-wymiarowy typu si32 |
Ograniczenia
- (C1)
0 <= dimension < rank(operand).
Przykłady
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantyka
Wyodrębnia element w pozycji index krotki operand i tworzy result. Więcej formalnie: result = operand[index].
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tablice | (C1), (C2) |
| (I2) | index |
stała typu si32 |
(C1), (C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
dowolny obsługiwany typ | (C2) |
Ograniczenia
- (C1)
0 <= index < size(operand). - (C2)
type(result) = tuple_element_types(operand)[index].
Przykłady
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
jeśli
Semantyka
Wynikiem jest wynik wykonania dokładnie jednej funkcji z true_branch lub false_branch w zależności od wartości elementu pred. W bardziej formalnym ujęciu: result =
pred ? true_branch() : false_branch().
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | pred |
0-wymiarowy tensor typu i1 |
|
| (I2) | true_branch |
funkcja | (C1-C3) |
| (I3) | false_branch |
funkcja | (C1), (C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C3) |
Ograniczenia
- (C1)
input_types(true_branch) = input_types(false_branch) = []. - (C2)
output_types(true_branch) = output_types(false_branch). - (C3)
type(results...) = output_types(true_branch).
Przykłady
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
Semantyka
Wyodrębnia część urojona z elementów tensora operand i tworzy tensor result. W formalnej postaci dla każdego elementu x:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego | (C1), (C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego | (C1), (C2) |
Ograniczenia
- (C1)
shape(result) = shape(operand). - (C2) Parametr
element_type(result)jest zdefiniowany jako:complex_element_type(element_type(operand))jeżeliis_complex(operand).- W przeciwnym razie:
element_type(operand).
Przykłady
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
infeed
Semantyka
Odczytuje dane z infeedu i generuje results.
Semantyka elementu infeed_config jest zdefiniowana przez implementację.
results składa się z wartości ładunku, które występują na początku, oraz tokena, który występuje na końcu. W przyszłości planujemy podzielić ładunek i token na 2 osobne dane wyjściowe, aby zwiększyć przejrzystość (#670).
Dane wejściowe
| Etykieta | Nazwa | Typ |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
stała typu string |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C1-C3) |
Ograniczenia
- (C1)
0 < size(results). - (C2)
is_empty(result[:-1])lubis_tensor(type(results[:-1])). - (C3)
is_token(type(results[-1])).
Przykłady
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semantyka
Wypełnia tensor output wartościami w rosnącej kolejności, zaczynając od 0 w wymiarze iota_dimension. Bardziej formalnie,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
output |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
0 <= iota_dimension < rank(output).
Przykłady
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantyka
Sprawdza, czy wartość w x jest skończona (tzn.nie jest skończona (tzn. nie jest liczbą +Inf, -Inf ani NaN) i generuje tensor y. Realizuje operację isFinitez specyfikacji IEEE-754. W przypadku typów kwantowych wynik to zawsze true.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | x |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
y |
tensor typu logicznego | (C1) |
Ograniczenia
- (C1)
shape(x) = shape(y).
Przykłady
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantyka
Wykonuje operację logarytmiczną z uwzględnieniem elementów na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
logz IEEE-754. - W przypadku liczb zespolonych: logarytm zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(log, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantyka
Wykonuje logarytm z elementami oraz 1 operację na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
logp1z IEEE-754. - W przypadku liczb zespolonych: logarytm zespolony plus 1.
- W przypadku typów skwantowanych:
dequantize_op_quantize(log_plus_one, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistyczna
Semantyka
Wykonuje operacje logistyczne związane z elementami na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
division(1, addition(1, exp(-x)))z IEEE-754. - W przypadku liczb zespolonych: złożona logistyczna.
- W przypadku typów skwantowanych:
dequantize_op_quantize(logistic, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mapa
Semantyka
Stosuje funkcję mapy computation do inputs na dimensions i tworzy tensor result.
Więcej formalnie: result[result_index] = computation(inputs...[result_index]).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1–C4) |
| (I2) | dimensions |
1-wymiarowa stała tensora typu si64 |
(K3) |
| (I3) | computation |
funkcja | (K4) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1), (C4) |
Ograniczenia
- (C1)
shape(inputs...) = shape(result). - (C2)
0 < size(inputs) = N. - (C3)
dimensions = range(rank(inputs[0])). - (C4)
computationma typ(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, gdzieEi = element_type(inputs[i])iE' = element_type(result).
Przykłady
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maksimum
Semantyka
Wykonuje operację elementarnego maksimum na tensorach lhs i rhs oraz zwraca tensor result. W zależności od typu elementu wykonuje te działania:
- W przypadku wartości logicznych: logiczny LUB.
- W przypadku liczb całkowitych: maksymalna liczba całkowita.
- Dla liczb zmiennoprzecinkowych:
maximumz IEEE-754. - W przypadku liczb zespolonych: maksymalna wartość leksykograficzna pary
(real, imaginary). Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560). - W przypadku typów skwantowanych:
dequantize_op_quantize(maximum, lhs, rhs, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
kwantowy tensor lub tensor kwantowy | (C1) |
| (I2) | rhs |
tensor lub tensor zagregowany z tensorów | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Przykłady
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Semantyka
Wykonuje elementarną operację min na tensorach lhs i rhs oraz zwraca tensor result. W zależności od typu elementu wykonuje te działania:
- W przypadku wartości logicznych: operator logiczny OR.
- W przypadku liczb całkowitych: minimalna liczba całkowita.
- Dla liczb zmiennoprzecinkowych:
minimumz IEEE-754. - W przypadku liczb zespolonych: leksykograficzne minimum dla pary
(real, imaginary). Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560). - W przypadku typów kwantowych:
dequantize_op_quantize(minimum, lhs, rhs, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
kwantowy tensor lub tensor kwantowy | (C1) |
| (I2) | rhs |
tensor lub tensor zagregowany z tensorów | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Przykłady
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
pomnóż
Semantyka
Wykonuje element po elemencie iloczyn dwóch tensorów lhs i rhs oraz generuje tensor result. W zależności od typu elementu:
- W przypadku wartości logicznych: logiczne I.
- W przypadku liczb całkowitych: mnożenie liczb całkowitych.
- W przypadku jednostek zmiennoprzecinkowych:
multiplicationw standardzie IEEE-754. - Liczby zespolone: mnożenie zespolone.
- W przypadku typów skwantowanych:
dequantize_op_quantize(multiply, lhs, rhs, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
kwantowy tensor lub tensor kwantowy | (C1) |
| (I2) | rhs |
tensor lub tensor zagregowany z tensorów | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negacja
Semantyka
Wykonuje negację tensora operand z uwzględnieniem elementów i generuje tensor result. W zależności od typu elementu:
- W przypadku liczb całkowitych ze znakiem: negacja liczby całkowitej.
- W przypadku liczb bez znaku: bitowy zapis jako liczba ze znakiem, negacja liczby, bitowy zapis z powrotem jako liczba bez znaku.
- Dla liczb zmiennoprzecinkowych:
negatez IEEE-754. - Liczby zespolone: negacja zespolona.
- W przypadku typów kwantowych:
dequantize_op_quantize(negate, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
nie
Semantyka
Przeprowadza elementarną operację NOT na tensorze operand i tworzy tensor result.
W zależności od typu elementu:
- W przypadku wartości logicznych: NOT logiczna.
- W przypadku liczb całkowitych: bitowa operacja NOT.
Argumenty
| Nazwa | Typ | Ograniczenia |
|---|---|---|
operand |
tensor typu logicznego lub całkowitego | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu logicznego lub całkowitego | (C1) |
Ograniczenia
- (C1)
type(operand) = type(result).
Przykłady
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantyka
Zapewnienie, że operacje, które generują operand, są wykonywane przed operacjami, które zależą od result, oraz zapobieganie przemieszczaniu operacji przez barierę przez przekształcenia kompilatora. W przeciwnym razie operacja jest
tożsamością, czyli result = operand.
Argumenty
| Nazwa | Typ | Ograniczenia |
|---|---|---|
operand |
zmienna liczba tensorów, tensorów skwantowanych na podstawie tensora lub tokenów | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
zmienna liczba tensorów, tensorów skwantowanych na podstawie tensora lub tokenów | (C1) |
Ograniczenia
- (C1)
type(operand...) = type(result...).
Przykłady
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
lub
Semantyka
Wykonuje elementarną operację OR na dwóch tensorach lhs i rhs, tworząc tensor result. W zależności od typu elementu:
- W przypadku wartości logicznych: operator logiczny LUB.
- W przypadku liczb całkowitych: operator bitowy LUB.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu liczba całkowita lub logiczna | (C1) |
| (I2) | rhs |
tensor typu liczba całkowita lub logiczna | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu liczba całkowita lub logiczna | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result).
Przykłady
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
outfeed
Semantyka
Zapisuje inputs w outfiedzie i generuje token result.
Semantyka elementu outfeed_config jest zdefiniowana przez implementację.
Dane wejściowe
| Etykieta | Nazwa | Typ |
|---|---|---|
| (I1) | inputs |
zmienna liczba tensorów lub zagęszczonych tensorów |
| (I2) | token |
token |
| (I3) | outfeed_config |
stała typu string |
Wyniki
| Nazwa | Typ |
|---|---|
result |
token |
Przykłady
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
pad
Semantyka
Rozszerza operand przez wypełnienie przestrzeni wokół tensora oraz między elementami tensora za pomocą podanego padding_value.
Parametry edge_padding_low i edge_padding_high określają ilość wypełnień dodawanych na niskich wartościach (obok indeksu 0) i na wysokich wartościach (obok najwyższego indeksu) w każdym wymiarze. Ilość wypełnienia może być ujemna, a wartość bezwzględna ujemnego wypełnienia wskazuje liczbę elementów do usunięcia z wybranego wymiaru.
interior_padding określa dopełnienie między dwoma dowolnymi elementami w każdym wymiarze, które może nie być ujemne. Wypełnienie wewnętrzne występuje przed wypełnieniem krawędzi, dzięki czemu wypełnienie krawędzi o ujemnej wartości spowoduje usunięcie elementów z operanda z wypełnieniem wewnętrznym.
Bardziej formalnie result[result_index] jest zdefiniowana jako:
operand[operand_index]jeśliresult_index = edge_padding_low + operand_index * (interior_padding + 1).- W przeciwnym razie:
padding_value.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensor 0-wymiarowy lub kwantowy tensor na tensor | (C1) |
| (I3) | edge_padding_low |
1-wymiarowa stała tensora typu si64 |
(C1), (C4) |
| (I4) | edge_padding_high |
1-wymiarowa stała tensora typu si64 |
(C1), (C4) |
| (I5) | interior_padding |
1-wymiarowa stała tensora typu si64 |
(C2-C4) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C3-C6) |
Ograniczenia
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Przykłady
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantyka
Tworzy partition_id bieżącego procesu.
Wyniki
| Nazwa | Typ |
|---|---|
result |
0-wymiarowy tensor typu ui32 |
Przykłady
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Semantyka
Wykonuje za pomocą elementu liczby bitów ustawione w tensorze operand i tworzy tensor result.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu liczba całkowita | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(operand) = type(result).
Przykłady
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
moc
Semantyka
Wykonuje elementową potęgowanie tensora lhs przez tensor rhs i tworzy tensor result. W zależności od typu elementu:
- W przypadku liczb całkowitych: wykładnik całkowity.
- Dla liczb zmiennoprzecinkowych:
powz IEEE-754. - W przypadku liczb zespolonych: wykładnik zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(power, lhs, rhs, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
| (I2) | rhs |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora według tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semantyka
Wyodrębnia rzeczywistą część z operand element po elemencie i tworzy tensor result. W formalnej postaci dla każdego elementu x:
real(x) = is_complex(x) ? real_part(x) : x.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego | (C1), (C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego | (C1), (C2) |
Ograniczenia
- (C1)
shape(result) = shape(operand). - (C2) Parametr
element_type(result)jest zdefiniowany jako:complex_element_type(element_type(operand))jeżeliis_complex(operand).- W przeciwnym razie:
element_type(operand).
Przykłady
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semantyka
Pobiera dane z kanału z parametrem channel_id i tworzy results.
Jeśli is_host_transfer to true, operacja przenosi dane z hosta. W przeciwnym razie dane są przenoszone z innego urządzenia. Co to oznacza, zależy od implementacji. Ta flaga powiela informacje podane w flagach channel_type, dlatego w przyszłości planujemy zachować tylko jedną z nich (#666).
results składa się z wartości ładunku, które występują na początku, oraz tokena, który występuje na końcu. W przyszłości planujemy podzielić ładunek i token na 2 osobne dane wyjściowe, aby zwiększyć przejrzystość (#670).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | token |
token |
(C4) |
| (I2) | channel_id |
stała typu si64 |
|
| (I3) | channel_type |
enum DEVICE_TO_DEVICE i HOST_TO_DEVICE |
(C1) |
| (I4) | is_host_transfer |
stała typu i1 |
(C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C2–C4) |
Ograniczenia
- (C1)
channel_typejest zdefiniowany jako:HOST_TO_DEVICEjeżeliis_host_transfer = true,DEVICE_TO_DEVICEw innych przypadkach.
- (C2)
0 < size(results). - (C3)
is_empty(result[:-1])lubis_tensor(type(results[:-1])). - (C4)
is_token(type(results[-1])).
Przykłady
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
zmniejszyć
Semantyka
Stosuje funkcję redukcji body do inputs i init_values wzdłuż dimensions i zwraca tensory results.
Kolejność redukcji jest zdefiniowana w implementacji, co oznacza, że body i init_values muszą utworzyć monoid, aby zagwarantować, że operacja da takie same wyniki przy wszystkich danych wejściowych we wszystkich implementacjach. Jednak w przypadku wielu popularnych rabatów to założenie nie jest spełnione. Przykładowo dodawanie liczb zmiennoprzecinkowych w przypadku wartości body i zera w przypadku wartości init_values nie tworzy monoidu, ponieważ dodawanie liczb zmiennoprzecinkowych nie jest skojarzone.
Bardziej formalnie: results...[j0, ..., jR-1] = reduce(input_slices_converted), gdzie:
input_slices = inputs...[j0, ..., :, ..., jR-1], gdzie:są wstawiane w miejscudimensions.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).reduce(input_slices_converted) = exec(schedule)dla niektórych drzew binarnychschedulegdzie:exec(node) = body(exec(node.left), exec(node.right)).exec(leaf) = leaf.value.
scheduleto pełne drzewo binarne zdefiniowane przez implementację, którego przeglądanie w kolejności zgodnej z ich poziomem w drzewie obejmuje:- wartości
input_slices_converted...[index]dla wszystkichindexw tablicyindex_space(input_slices_converted)w rosnącym porządku leksykograficznymindex. - Wplecione z określoną przez implementację ilością znaków
init_values_convertedw określonych przez implementację pozycjach.
- wartości
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1-C4), (C6), (C7) |
| (I2) | init_values |
zmienna liczba 0-wymiarowych tensorów lub tensorów skwantowanych na tensor | (C2), (C3) |
| (I3) | dimensions |
1-wymiarowa stała tensora typu si64 |
(C4), (C5), (C7) |
| (I4) | body |
funkcja | (C6) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C3), (C7), (C8) |
Ograniczenia
- (C1)
same(shape(inputs...)). - (C2)
element_type(inputs...) = element_type(init_values...). - (C3)
0 < size(inputs) = size(init_values) = size(results) = N. - (C4)
0 <= dimensions < rank(inputs[0]). - (C5)
is_unique(dimensions). - (C6)
bodyma typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)gdzieis_promotable(element_type(inputs[i]), Ei). - (C7)
shape(results...) = shape(inputs...), z tym że nie uwzględnia wymiaru rozmiary w elementachinputs...odpowiadających elementomdimensions. - (C8)
element_type(results[i]) = Eidla wszystkichiw[0,N).
Przykłady
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantyka
Przeprowadza konwersję elementów operand na inny typ zmiennoprzecinkowy, który używa exponent_bits i mantissa_bits, a następnie z powrotem na pierwotny typ zmiennoprzecinkowy. Tworzy też tensor output.
Bardziej formalnie:
- Bity mantissy pierwotnej wartości są aktualizowane w celu zaokrąglenia pierwotnej wartości do najbliższej wartości, którą można przedstawić za pomocą funkcji
mantissa_bitsprzy użyciu semantykiroundToIntegralTiesToEven. - Następnie, jeśli
mantissa_bitsjest mniejsza od pierwotnej wartości, bity modliszki są obcinane domantissa_bits. - Następnie, jeśli bity wykładnika wyniku pośredniego nie mieszczą się w zakresie określonym przez
exponent_bits, wynik pośredni jest przepełniony do nieskończoności z użyciem znaku pierwotnego lub jest podpięty do zera z użyciem znaku pierwotnego. - W przypadku typów skwantowanych wykonuje operację
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
| (I2) | exponent_bits |
stała typu si32 |
(C2) |
| (I3) | mantissa_bits |
stała typu si32 |
(K3) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
output |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(output). - (C2)
1 <= exponent_bits. - (C3)
0 <= mantissa_bits.
Przykłady
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantyka
W ramach każdej grupy procesów w siatce procesów StableHLO wykonuje redukcję (za pomocą funkcji computations) wartości tensora operand z każdego procesu, dzieli wynik redukcji wzdłuż osi scatter_dimension na części, a potem rozprasza te części między procesami, aby wygenerować result.
Operacja dzieli siatkę procesu StableHLO na process_groups, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)jeślichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)ifchannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)ifchannel_id > 0 and use_global_device_ids = true.
Następnie w każdym wierszu process_group:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).result@receiver = parts@sender[receiver_index]dla wszystkichsenderwprocess_group, gdziereceiver_index = process_group.index(receiver).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1), (C2), (C7), (C8) |
| (I2) | scatter_dimension |
stała typu si64 |
(C1), (C2), (C8) |
| (I3) | replica_groups |
2-wymiarowa stała tensora typu si64 |
(C3-C5) |
| (I4) | channel_id |
stała typu si64 |
(C6) |
| (I5) | use_global_device_ids |
stała typu i1 |
(C6) |
| (I6) | computation |
funkcja | (C7) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C8-C9) |
Ograniczenia
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0. - (C2)
0 <= scatter_dimension < rank(operand). - (C3)
is_unique(replica_groups). - (C4)
size(replica_groups)jest zdefiniowana jako:num_replicas, jeśli używana jest właściwośćcross_replica.num_replicas, jeśli używana jest właściwośćcross_replica_and_partition.num_processes, jeśli używana jest właściwośćflattened_ids.
- (C5)
0 <= replica_groups < size(replica_groups). - (C6) Jeśli
use_global_device_ids = true, tochannel_id > 0. - (C7)
computationma typ(tensor<E>, tensor<E>) -> (tensor<E>), gdzieis_promotable(element_type(operand), E). - (C8)
shape(result) = shape(operand)z wyjątkiem:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
- (C9)
element_type(result) = E.
Przykłady
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantyka
Stosuje funkcję redukcji body do okien inputs i init_values oraz zwraca wartość results.
Ten diagram pokazuje, jak elementy w elementach results... są obliczane na podstawie elementów inputs... na konkretnym przykładzie.
Bardziej oficjalnie:
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(patrz reduce), gdzie:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).window_start = result_index * window_strides.window_end = window_start + (window_dimensions - 1) * window_dilations + 1.windows = slice(padded_inputs..., window_start, window_end, window_dilations).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | inputs |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C1–C4), (C6), (C8), (C10), (C12), (C13), (C15) |
| (I2) | init_values |
liczba zmiennoprzecinkowa tensorów 0-wymiarowych lub kwantyzowanych tensorów na tensor | (C1), (C13) |
| (I3) | window_dimensions |
Jednowymiarowa stała tensora typu si64 |
(C4), (C5), (C15) |
| (I4) | window_strides |
1-wymiarowa stała tensora typu si64 |
(C6), (C7), (C15) |
| (I5) | base_dilations |
1-wymiarowa stała tensora typu si64 |
(C8), (C9), (C15) |
| (I6) | window_dilations |
1-wymiarowa stała tensora typu si64 |
(C10), (C11), (C15) |
| (I7) | padding |
2-wymiarowa stała tensora typu si64 |
(C12), (C15) |
| (I8) | body |
funkcja | (C13) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C1), (C14–C16) |
Ograniczenia
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N. - (C2)
same(shape(inputs...)). - (C3)
element_type(inputs...) = element_type(init_values...). - (C4)
size(window_dimensions) = rank(inputs[0]). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(inputs[0]). - (C7)
0 < window_strides. - (C8)
size(base_dilations) = rank(inputs[0]). - (C9)
0 < base_dilations. - (C10)
size(window_dilations) = rank(inputs[0]). - (C11)
0 < window_dilations. - (C12)
shape(padding) = [rank(inputs[0]), 2]. - (C13)
bodyma typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)gdzieis_promotable(element_type(inputs[i]), Ei). - (C14)
same(shape(results...)). - (C15)
shape(results[0]) = num_windowsgdzie:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
- (C16)
element_type(results[i]) = Eidla wszystkichiw[0,N).
Przykłady
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
reszta
Semantyka
Wykonuje element po elemencie resztę dzielnika lhs i dzielonki rhs tensorów oraz zwraca tensor result.
Bardziej oficjalnie znak wyniku jest wyliczany z dywidendy, a wartość bezwzględna wyniku jest zawsze mniejsza od wartości bezwzględnej dzielnika.
Reszta jest obliczana według wzoru: lhs - d * rhs, gdzie d jest określona przez:
- W przypadku liczb całkowitych:
stablehlo.divide(lhs, rhs). - W przypadku liczb zmiennoprzecinkowych:
division(lhs, rhs)z IEEE-754 z atrybutem zaokrągleniaroundTowardZero. - W przypadku liczb zespolonych: do ustalenia (#997).
- W przypadku typów skwantowanych:
dequantize_op_quantize(remainder, lhs, rhs, type(result)).
W przypadku elementów zmiennoprzecinkowych ta operacja jest niezgodna z operacją remainder ze specyfikacji IEEE-754, w której d jest wartością całkowitą zbliżoną do dokładnej wartości parametru lhs/rhs z opisem równomiernym.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
| (I2) | rhs |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantyka
Tworzy replica_id bieżącego procesu.
Wyniki
| Nazwa | Typ |
|---|---|
result |
0-wymiarowy tensor typu ui32 |
Przykłady
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
zmienić kształt
Semantyka
Przekształca tensor operand w tensor result. W zasadzie jest to zachowanie tego samego reprezentacji kanonicznej, ale potencjalnie zmiany kształtu, np. z tensor<2x3xf32> na tensor<3x2xf32> lub tensor<6xf32>.
Dokładniej rzecz ujmując: result[result_index] = operand[operand_index], gdzie result_index i operand_index mają to samo miejsce w kolejności leksykograficznej index_space(result) i index_space(operand).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub kwantyzowany tensor | (C1-C3) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor (tensor kwantowy) | (C1-C3) |
Ograniczenia
- (C1)
element_type(result)otrzymuje:element_type(operand), jeśli!is_per_axis_quantized(operand).element_type(operand)– oprócz tychquantization_dimension(operand)iquantization_dimension(result)mogą się różnić.
- (C2)
size(operand) = size(result). - (C3) Jeśli
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
Przykłady
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
odwróć
Semantyka
Odwraca kolejność elementów w operand wzdłuż podanego wymiaru dimensions i tworzy tensor result. Bardziej formalnie:
result[result_index] = operand[operand_index] gdzie:
operand_index[d] = dim(result, d) - result_index[d] - 1jeżelidwdimensions.operand_index[d] = result_index[d]w innych przypadkach.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1), (C3) |
| (I2) | dimensions |
1-wymiarowa stała tensora typu si64 |
(C2), (C3) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1), (C3) |
Ograniczenia
- (C1)
type(operand) = type(result). - (C2)
is_unique(dimensions). - (C3)
0 <= dimensions < rank(result).
Przykłady
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semantyka
Generuje liczby losowe za pomocą algorytmu rng_distribution i tworzy tensor result o kształcie shape.
Jeśli rng_distribution = UNIFORM, liczby losowe są generowane zgodnie z rozkładem stałym w przedziałzie [a, b). Jeśli a >= b, zachowanie jest nieokreślone.
Jeśli rng_distribution = NORMAL, liczby losowe są generowane zgodnie z rozkładem normalnym, ze średnią = a i odchyleniem standardowym = b.
Jeśli b < 0, zachowanie jest nieokreślone.
Dokładny sposób generowania liczb losowych jest zdefiniowany w implementacji. Mogą na przykład być deterministyczne, ale nie muszą.
W rozmowach z wielu interesariuszami okazało się, że ta opcja jest w istocie wycofana, dlatego w przyszłości planujemy jej usunięcie (#597).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | a |
0-wymiarowy tensor typu liczby całkowitej, wartości logicznej lub zmiennoprzecinkowej | (C1), (C2) |
| (I2) | b |
0-wymiarowy tensor typu liczby całkowitej, wartości logicznej lub zmiennoprzecinkowej | (C1), (C2) |
| (I3) | shape |
1-wymiarowa stała tensora typu si64 |
(C3) |
| (I4) | rng_distribution |
wyliczenie UNIFORM i NORMAL |
(C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu całkowita, logiczna lub zmiennoprzecinkowa | (C1-C3) |
Ograniczenia
- (C1)
element_type(a) = element_type(b) = element_type(result). - (C2) Jeśli
rng_distribution = NORMAL, tois_float(a). - (C3)
shape(result) = shape.
Przykłady
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantyka
Zwraca pole output wypełnione jednolitymi losowymi bitami i zaktualizowany stan wyjściowy output_state za pomocą algorytmu generatora liczb pseudorandom rng_algorithm z przypisanym stanem początkowym initial_state. Wyjście jest zdefiniowane jako funkcja deterministyczna initial_state, ale nie jest deterministyczna w różnych implementacjach.
rng_algorithm jest jedną z tych:
DEFAULT: algorytm zdefiniowany przez implementację.THREE_FRY: wariant algorytmu Threefry zdefiniowany przez implementację.*PHILOX: wariant algorytmu Philox zdefiniowany przez implementację.*
* Patrz: Salmon et al. SC 2011. Losowe liczby równoległe: to bardzo proste.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | rng_algorithm |
wyliczenie DEFAULT, THREE_FRY i PHILOX |
(C2) |
| (I2) | initial_state |
Jednowymiarowy tensor typu ui64 |
(C1), (C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
output_state |
1-wymiarowy tensor typu ui64 |
(C1) |
output |
tensor typu liczba całkowita lub zmiennoprzecinkowa |
Ograniczenia
- (C1)
type(initial_state) = type(output_state). - (C2)
size(initial_state)zdefiniowano jako:- zdefiniowaną przez implementację, jeśli
rng_algorithm = DEFAULT. 2jeżelirng_algorithm = THREE_FRY.2lub3, jeślirng_algorithm = PHILOX.
- zdefiniowaną przez implementację, jeśli
Przykłady
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantyka
Wykonuje zaokrąglanie z poziomu elementu do najbliższej liczby całkowitej, zrywając remisy od zera na tensorze operand i tworzy tensor result. Realizuje operację roundToIntegralTiesToAway zgodnie ze specyfikacją IEEE-754. W przypadku typów skompresowanych wykonuje działanie dequantize_op_quantize(round_nearest_afz, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantyka
Zaokrągla elementy tensora operand do najbliższej liczby całkowitej, rozstrzygając remisy na korzyść parzystej liczby całkowitej, i tworzy tensor result. Realizuje operację roundToIntegralTiesToEven według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(round_nearest_even, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantyka
Przeprowadza elementarną operację odwrotnego pierwiastka kwadratowego na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
rSqrtz IEEE-754. - Liczby zespolone: odwrotny pierwiastek kwadratowy z liczby zespolonej.
- W przypadku typów kwantowych:
dequantize_op_quantize(rsqrt, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
rozproszenie
Semantyka
Wyświetla tensory results, które są równe tensorom inputs, ale kilka wycinków określonych przez scatter_indices zostało zaktualizowanych wartościami updates przy użyciu metody update_computation.
Ten diagram pokazuje na konkretnym przykładzie, jak elementy w updates... są mapowane na elementy w results.... Diagram przedstawia kilka przykładowych indeksów updates... i szczegółowo wyjaśnia, z którymi indeksami results... są one powiązane.
Więcej formalnie dla wszystkich update_index w index_space(updates[0]):
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].update_scatter_index = update_index[update_scatter_dims...].start_indexto:scatter_indices[si0, ..., :, ..., siN], gdziesito poszczególne elementy wupdate_scatter_index, a element:jest wstawiany w indeksieindex_vector_dim, jeśliindex_vector_dim<rank(scatter_indices).[scatter_indices[update_scatter_index]]w innych przypadkach.
- Za
d_inputwaxes(inputs[0]):full_start_index[d_input] = start_index[d_start]jeślid_input = scatter_dims_to_operand_dims[d_start].full_start_index[d_input] = 0w innych przypadkach.
- W przypadku
d_inputwaxes(inputs[0]):full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]jeślid_input = input_batching_dims[i_batching]id_start = scatter_indices_batching_dims[i_batching].- W przeciwnym razie:
full_batching_index[d_input] = 0.
update_window_index = update_index[update_window_dims...].full_window_index = [wi0, ..., 0, ..., wiN], gdziewito poszczególne elementy w tablicyupdate_window_index, a element0jest wstawiany na pozycjach z tablicinserted_window_dimsiinput_batching_dims.result_index = full_start_index + full_batching_index + full_window_index.
W związku z tym results = exec(schedule, inputs), gdzie:
scheduleto określona przez implementację permutacja funkcjiindex_space(updates[0]).exec([update_index, ...], results) = exec([...], updated_results), gdzie:- Jeśli
result_indexmieści się w zakresieshape(results...) updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)updated_resultsto kopiaresultsz wartościąresults...[result_index]ustawioną naupdated_values....- W innym przypadku
updated_results = results.
- Jeśli
exec([], results) = results.
Jeśli indices_are_sorted to true, implementacja może założyć, że scatter_indices są posortowane zgodnie z scatter_dims_to_operand_dims. W przeciwnym razie zachowanie jest nieokreślone. Bardziej formalnie dotyczy wszystkich i1 < i2 od
indices(result), full_start_index(i1) <= full_start_index(i2).
Jeśli unique_indices to true, implementacja może założyć, że wszystkie indeksy result_index, które są rozproszone, są unikalne. Jeśli unique_indices to
true, ale indeksy, do których jest rozproszony, nie są unikalne, działanie jest niezdefiniowane.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
| (I2) | scatter_indices |
tensor typu liczba całkowita | (C4), (C15), (C19), (C22) |
| (I3) | updates |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C3-C6), (C8) |
| (I4) | update_window_dims |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C7-C8) |
| (I5) | inserted_window_dims |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C9-C11) |
| (I6) | input_batching_dims |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
| (I7) | scatter_indices_batching_dims |
1-wymiarowa stała tensora typu si64 |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
1-wymiarowa stała tensora typu si64 |
(C19-C21) |
| (I9) | index_vector_dim |
stała typu si64 |
(C4), (C16), (C19), (C22) |
| (I10) | indices_are_sorted |
stała typu i1 |
|
| (I11) | unique_indices |
stała typu i1 |
|
| (I12) | update_computation |
funkcja | (C23) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C24-C25) |
Ograniczenia
- (C1)
same(shape(inputs...)). - (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...)). - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes), gdzie:update_scatter_dim_sizes = shape(scatter_indices), z tym że nie uwzględnia wymiaruscatter_indicesodpowiadającego wymiarowiindex_vector_dim.update_window_dim_sizes <= shape(inputs[0]), z tym że nie uwzględnia wymiarówinputs[0]odpowiadających wymiarominserted_window_dimsiinput_batching_dims.- Funkcja
combineumieszcza poleupdate_scatter_dim_sizesna osiach odpowiadających podziałowiupdate_scatter_dimsiupdate_window_dim_sizesna osiach odpowiadających wartościupdate_window_dims.
- (C5)
0 < size(inputs) = size(updates) = N. - (C6)
element_type(updates...) = element_type(inputs...). - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims). - (C8)
0 <= update_window_dims < rank(updates[0]). - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims). - (C11)
0 <= inserted_window_dims < rank(inputs[0]). - (C12)
is_sorted(input_batching_dims). - (C13)
0 <= input_batching_dims < rank(inputs[0])). - (C14)
is_unique(scatter_indices_batching_dims). - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices). - (C16)
index_vector_dim not in scatter_indices_batching_dims. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims). - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...). - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)). - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0]). - (C22)
0 <= index_vector_dim <= rank(scatter_indices). - (C23)
update_computationma typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), gdzieis_promotable(element_type(inputs[i]), Ei). - (C24)
shape(inputs...) = shape(results...). - (C25)
element_type(results[i]) = Eidla wszystkichiw[0,N).
Przykłady
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
wybierz
Semantyka
Tworzy tensor result, w którym każdy element jest wybierany z tensora on_true lub on_false na podstawie wartości odpowiadającego elementu tensora pred.
W formie bardziej oficjalnej: result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index], gdzie pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]. W przypadku typów kwantowych wybiera dequantize_select_quantize(pred, on_true, on_false, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | pred |
tensor typu i1 |
(C1) |
| (I2) | on_true |
tensor lub tensor zagregowany z tensorów | (C1-C2) |
| (I3) | on_false |
tensor lub tensor zagregowany z tensorów | (K2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C2) |
Ograniczenia
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true). - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).
Przykłady
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantyka
Rozprasza wartości z tensora source za pomocą funkcji scatter na podstawie wyniku reduce_window z tensora input za pomocą funkcji select i tworzy tensor result.
Na diagramie poniżej widać, jak elementy w elementach result są obliczane na podstawie elementów operand i source.
Bardziej formalnie:
selected_values = reduce_window_without_init(...)z tymi danymi wejściowymi:inputs = [operand].window_dimensions,window_stridesipadding, które są używane w takiej postaci, w jakiej zostały przesłane.base_dilations = windows_dilations = 1.bodyjest zdefiniowana jako:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;gdzie
E = element_type(operand)ireduce_window_without_initdziałają dokładnie tak jakreduce_window, z tą różnicą, żeschedulewartości podstawowejreduce(patrz redukcja) nie zawiera wartości init. Obecnie nie określono, co się stanie, jeśli odpowiednie okno nie zawiera wartości (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter), gdzie:source_values = [source[source_index] for source_index in source_indices].selected_index(source_index) = operand_indexjeśliselected_values[source_index]ma elementoperandzoperand_index.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1-C4), (C6), (C8-C11) |
| (I2) | source |
kwantowy tensor lub tensor kwantowy | (C1), (C2) |
| (I3) | init_value |
0-wymiarowy tensor lub tensor kwantyzowany na podstawie tensora | (C3) |
| (I4) | window_dimensions |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C5) |
| (I5) | window_strides |
1-wymiarowa stała tensora typu si64 |
(C2), (C6), (C7) |
| (I6) | padding |
2-wymiarowa stała tensora typu si64 |
(C2), (C8) |
| (I7) | select |
funkcja | (C9) |
| (I8) | scatter |
funkcja | (C10) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C11-C12) |
Ograniczenia
- (C1)
element_type(operand) = element_type(source). - (C2)
shape(source) = num_windows, gdzie:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
- (C3)
element_type(init_value) = element_type(operand). - (C4)
size(window_dimensions) = rank(operand). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(operand). - (C7)
0 < window_strides. - (C8)
shape(padding) = [rank(operand), 2]. - (C9)
selectma typ(tensor<E>, tensor<E>) -> tensor<i1>, gdzieE = element_type(operand). - (C10)
scatterma typ(tensor<E>, tensor<E>) -> tensor<E>, gdzieis_promotable(element_type(operand), E). - (C11)
shape(operand) = shape(result). - (C12)
element_type(result) = E.
Przykłady
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
wyślij
Semantyka
Wysyła kod inputs do kanału channel_id i tworzy token result.
Jeśli is_host_transfer to true, operacja przenosi dane do hosta. W przeciwnym razie dane zostaną przeniesione na inne urządzenie. Co to oznacza, zależy od implementacji. Ta flaga powiela informacje podane w flagach channel_type, dlatego w przyszłości planujemy zachować tylko jedną z nich (#666).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | inputs |
zmienna liczba tensorów lub zagęszczonych tensorów | |
| (I2) | token |
token |
|
| (I3) | channel_id |
stała typu si64 |
|
| (I4) | channel_type |
enum DEVICE_TO_DEVICE i DEVICE_TO_HOST |
(C1) |
| (I5) | is_host_transfer |
stała typu i1 |
(C1) |
Wyniki
| Nazwa | Typ |
|---|---|
result |
token |
Ograniczenia
- (C1)
channel_typejest zdefiniowany jako:DEVICE_TO_HOSTjeżeliis_host_transfer = true,DEVICE_TO_DEVICEw innych przypadkach.
Przykłady
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantyka
Przeprowadza elementową operację przesunięcia w lewo na tensorze lhs o liczbę bitów rhs i tworzy tensor result.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu liczba całkowita | (C1) |
| (I2) | rhs |
tensor typu liczby całkowitej | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result).
Przykłady
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantyka
Przesuwa elementowo w prawo o określoną liczbę bitów (rhs) elementarną operację arytmetyczną na tensorze lhs i tworzy tensor result.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu liczba całkowita | (C1) |
| (I2) | rhs |
tensor typu liczby całkowitej | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result).
Przykłady
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantyka
Wykonuje logiczne operację przesunięcia w prawo na tensorze lhs o rhs bitów i tworzy tensor result.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu liczba całkowita | (C1) |
| (I2) | rhs |
tensor typu liczby całkowitej | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result).
Przykłady
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
podpisywanie
Semantyka
Zwraca znak operand element po elemencie i tworzy tensor result.
Bardziej formalnie, w przypadku każdego elementu x semantyka może być wyrażona za pomocą składni Pythona w ten sposób:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(sign, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Semantyka
Wykonuje elementarną operację sinusa na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
sinz IEEE-754. - W przypadku liczb zespolonych: sinus zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(sine, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
wycinek
Semantyka
Wyodrębnia wycinek z operand, używając statycznie obliczonych indeksów początkowych i tworzy tensor result. start_indices zawiera indeksy początkowe przekroju dla każdego wymiaru, limit_indices zawiera indeksy końcowe (wykluczające) przekroju dla każdego wymiaru, a strides zawiera kroki dla każdego wymiaru.
Bardziej formalnie: result[result_index] = operand[operand_index], gdzie operand_index = start_indices + result_index * strides.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1-C3), (C5) |
| (I2) | start_indices |
1-wymiarowa stała tensora typu si64 |
(C2), (C3), (C5) |
| (I3) | limit_indices |
1-wymiarowa stała tensora typu si64 |
(C2), (C3), (C5) |
| (I4) | strides |
1-wymiarowa stała tensora typu si64 |
(C2), (C4) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1), (C5) |
Ograniczenia
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand). - (C3)
0 <= start_indices <= limit_indices <= shape(operand). - (C4)
0 < strides. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides).
Przykłady
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
sortuj
Semantyka
Sortuje jednowymiarowe wycinki obiektu inputs wzdłuż wymiaru dimension według wartości comparator i tworzy results.
W przeciwieństwie do podobnych danych wejściowych w innych operacjach funkcja dimension umożliwia stosowanie wartości ujemnych. Wybrane wartości mają następującą interpretację: W przyszłości możemy zablokować tę funkcję ze względu na spójność (problem #1377).
Jeśli is_stable ma wartość prawda, sortowanie jest stabilne, co oznacza, że zachowana jest względna kolejność elementów uznanych przez porównywarkę za równe. W przypadku pojedynczego wejścia dwa elementy e1 i e2 są uznawane za równe przez porównywacz, jeśli i tylko jeśli comparator(e1, e2) = comparator(e2, e1) = false. Zobacz formalny opis poniżej, aby dowiedzieć się, jak to uogólniać do wielu danych wejściowych.
Więcej formalnie dla wszystkich result_index w index_space(results[0]):
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.result_slice = [ri0, ..., :, ..., riR-1], gdzieriNto pojedyncze elementy w komórceresult_index, a:jest wstawiony w miejscuadjusted_dimension.inputs_together = (inputs[0]..., ..., inputs[N-1]...).results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).- gdzie
sortsortuje jednowymiarowy wycinek w kolejności niemalejącej, oczekując, że funkcjacomparator_togetherzwrócitrue, jeśli argument po lewej stronie jest mniejszy niż argument drugiej strony po prawej stronie. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1-C5) |
| (I2) | dimension |
stała typu si64 |
(C4) |
| (I3) | is_stable |
stała typu i1 |
|
| (I4) | comparator |
funkcja | (C5) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C2), (C3) |
Ograniczenia
- (C1)
0 < size(inputs). - (C2)
type(inputs...) = type(results...). - (C3)
same(shape(inputs...) + shape(results...)). - (C4)
-R <= dimension < R, gdzieR = rank(inputs[0]). - (C5)
comparatorma typ(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, gdzieEi = element_type(inputs[i]).
Przykłady
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantyka
Wykonuje operację pierwiastka kwadratowego uwzględniającą elementy na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
squareRootz IEEE-754. - W przypadku liczb zespolonych: pierwiastek kwadratowy z liczby zespolonej.
- W przypadku typów skwantowanych:
dequantize_op_quantize(sqrt, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
odejmij
Semantyka
Wykonuje element po elemencie odejmowanie dwóch tensorów lhs i rhs oraz generuje tensor result. W zależności od typu elementu:
- W przypadku liczb całkowitych: odejmowanie liczb całkowitych.
- Dla liczb zmiennoprzecinkowych:
subtractionz IEEE-754. - W przypadku liczb zespolonych: odejmowanie liczb zespolonych.
- W przypadku typów skwantowanych:
dequantize_op_quantize(subtract, lhs, rhs, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
| (I2) | rhs |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora według tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Przykłady
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semantyka
Wykonuje elementarną operację pochodnej cząstkowej na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
tanz IEEE-754. - W przypadku liczb zespolonych: tangens zespolony.
- W przypadku typów kwantowych:
dequantize_op_quantize(tan, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semantyka
Wykonuje elementarną operację tangensu hiperbolicznego na tensorze operand i tworzy tensor result. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
tanhz IEEE-754. - W przypadku liczb zespolonych: tangens hiperboliczny zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(tanh, operand, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result).
Przykłady
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transponować
Semantyka
Zamienia wymiary tensora operand za pomocą permutation i tworzy tensor result. Bardziej formalnie: result[result_index] = operand[operand_index]gdzie result_index[d] = operand_index[permutation[d]].
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor lub kwantyzowany tensor | (C1-C4) |
| (I2) | permutation |
1-wymiarowa stała tensora typu si64 |
(C2-C4) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor lub kwantyzowany tensor | (C1), (C3-C4) |
Ograniczenia
- (C1)
element_type(result)otrzymuje:element_type(operand), jeśli!is_per_axis_quantized(operand).element_type(operand)z tym, żequantization_dimension(operand)iquantization_dimension(result)mogą się różnić.
- (C2)
permutationjest permutacjąrange(rank(operand)). - (C3)
shape(result) = dim(operand, permutation...). - (C4) Jeśli
is_per_axis_quantized(result), toquantization_dimension(operand) = permutation(quantization_dimension(result)).
Przykłady
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantyka
Rozwiązuje partie układów równań liniowych z macierzowymi współczynnikami o dolnej lub górnej trójkątności.
Bardziej formalnie, przy założeniu a i b, result[i0, ..., iR-3, :, :] jest rozwiązaniem op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :], gdy left_side jest równe true lub x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :], gdy left_side jest równe false, przy czym zmienna x jest rozwiązaniem op(a), a jej wartość jest określana przez transpose_a, który może być równy:
NO_TRANSPOSE: wykonaj operację, używającaw postaci domyślnej.TRANSPOSE: wykonaj operację na przekształceniu macierzowyma.ADJOINT: wykonaj operację na sprzężonym przekształceniu macierzowyma.
Dane wejściowe są odczytywane tylko z dolnego trójkąta elementu a, jeśli lower to true, lub z górnego trójkąta elementu a, jeśli nie. Dane wyjściowe są zwracane w tym samym trójkącie, a wartości w drugim trójkącie są zdefiniowane przez implementację.
Jeśli unit_diagonal ma wartość true, implementacja może założyć, że elementy diagonalne funkcji a są równe 1. W przeciwnym razie działanie jest nieokreślone.
W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | a |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1-C3) |
| (I2) | b |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1-C4) |
| (I3) | left_side |
stała typu i1 |
(K3) |
| (I4) | lower |
stała typu i1 |
|
| (I5) | unit_diagonal |
stała typu i1 |
|
| (I6) | transpose_a |
enum NO_TRANSPOSE, TRANSPOSE i ADJOINT |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_element_type(a) = baseline_element_type(b). - (C2)
2 <= rank(a) = rank(b) = R. - (C3) Związek między
shape(a)ashape(b)jest zdefiniowany w ten sposób:shape(a)[:-3] = shape(b)[:-3].dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
- (C4)
baseline_type(b) = baseline_type(result).
Przykłady
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tablice
Semantyka
Tworzy kwaternion result z wartości val.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | val |
zmienna liczba wartości | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
krotka | (C1) |
Ograniczenia
- (C1)
resultma typtuple<E0, ..., EN-1>, gdzieEi = type(val[i]).
Przykłady
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantyka
Przekształca element po elemencie kwantyzowany tensor operand na tensor zmiennoprzecinkowy result zgodnie z parametrami kwantyzacji zdefiniowanymi przez typ operand.
Więcej formalnie: result = dequantize(operand).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor kwantowy | (C1), (C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu zmiennoprzecinkowego | (C1), (C2) |
Ograniczenia
- (C1)
shape(operand) = shape(result). - (C2)
element_type(result) = expressed_type(operand).
Przykłady
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantyka
Wykonuje konwersję tensora zmiennoprzecinkowego lub kwantyzowanego tensora operand na kwantyzowany tensor result zgodnie z parametrami kwantyzacji zdefiniowanych przez typ result.
Bardziej formalnie,
- Jeśli
is_float(operand):result = quantize(operand, type(result)).
- Jeśli
is_quantized(operand):float_result = dequantize(operand).result = quantize(float_result, type(result)).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
tensor typu zmiennoprzecinkowego lub kwantyzowanego | (C1), (C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor kwantyzowany | (C1), (C2) |
Ograniczenia
- (C1)
shape(operand) = shape(result). - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).
Przykłady
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
podczas gdy
Semantyka
Wyświetla dane wyjściowe z wykonania funkcji body co najmniej 0 razy, gdy funkcja cond zwraca wartość true. Bardziej oficjalnie semantyka można wyrazić za pomocą składni Pythona:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Zachowanie nieskończonego pętli jest nieznane (#383).
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | operand |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C1-C3) |
| (I2) | cond |
funkcja | (C1) |
| (I3) | body |
funkcja | (C2) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C3) |
Ograniczenia
- (C1)
condma typ(T0, ..., TN-1) -> tensor<i1>, gdzieTi = type(operand[i]). - (C2)
bodyma typ(T0, ..., TN-1) -> (T0, ..., TN-1), gdzieTi = type(operand[i]). - (C3)
type(results...) = type(operand...).
Przykłady
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semantyka
Wykonuje elementarną operację XOR na dwóch tensorach lhs i rhs oraz zwraca tensor result. W zależności od typu elementu wykonuje te działania:
- W przypadku wartości logicznych: XOR logiczny.
- W przypadku liczb całkowitych: XOR bitowy.
Dane wejściowe
| Etykieta | Nazwa | Typ | Ograniczenia |
|---|---|---|---|
| (I1) | lhs |
tensor typu logicznego lub całkowitego | (C1) |
| (I2) | rhs |
tensor typu logicznego lub całkowitego | (C1) |
Wyniki
| Nazwa | Typ | Ograniczenia |
|---|---|---|
result |
tensor typu logicznego lub całkowitego | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result).
Przykłady
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interoperacyjność dialektów
Obecnie programy StableHLO w naturze czasami zawierają operacje, które nie są zdefiniowane przez StableHLO.
Moduł, funkcja, wywołanie i zwracanie
StableHLO używa operacji MLIR z upstream do operacji ModuleOp, FuncOp, CallOp i ReturnOp. Zrobiliśmy to, aby zapewnić lepszą współpracę z dotychczasowymi mechanizmami MLIR, ponieważ wiele przydatnych przejść jest napisanych z uwzględnieniem operacji FuncOp i ModuleOp, a wiele ścieżek kompilacji oczekuje obecności tych operacji. Do tych operacji stosowane są gwarancje pełnej zgodności. Jeśli cokolwiek w tych działaniach zmieni się w niekompatybilny sposób (np. zostanie usunięty), zostaną dodane odpowiedniki StableHLO, aby zachować zgodność.
CHLO
Opset CHLO zawiera operacje wyższego poziomu, które rozkładają się na StableHLO. Obecnie nie ma żadnych gwarancji zgodności w przypadku CHLO. Aby zapewnić zgodność, przed serializacją należy użyć przejścia chlo-legalize-to-stablehlo.
Operacje kształtu
W społeczności często stosuje się w programach dynamicznych StableHLO pewne operacje z podstawowych dialektów MLIR do wykonywania obliczeń kształtu.
Najczęściej są to operacje shape, takie jak shape_of lub num_elements, operacje tensor, takie jak dim lub from_elements, oraz wbudowany typ index.
Dynamism RFC > O2 wskazuje, że te typy wykraczają poza zakres, ale ze względu na interoperacyjność uwzględniono w nim niektóre typy index. W przypadku tych typów i wersji nie ma gwarancji zgodności. Za pomocą przejścia shape-legalize-to-stablehlo można przekształcić te operacje w w pełni obsługiwane operacje StableHLO.
Wycofane operacje
Istnieje kilka operacji StableHLO, które zostały odziedziczone z MHLO, są wycofane i zostaną usunięte z StableHLO. Szczegółowe informacje na ten temat znajdziesz w dokumencie StableHLO v1.0 Cleanup #2283 (Czyszczenie StableHLO w wersji 1.0). W przypadku tych wycofanych rozwiązań występuje problem z lokalizatorem to #2340.
Operacje te można podzielić na kilka kategorii:
- Kategoria operacji StableHLO „Nie w HLO” – początkowo były one częścią zestawu operacji StableHLO, ale później uznano, że nie pasują do niego:
broadcast,create_token,cross-replica-sum,dot,einsum,torch_index_select,unary_einsum(#3). - Nieużywane operacje – te operacje mogły być przydatne w danym momencie, ale były niedostatecznie rozwinięte lub ścieżki korzystające z tych operacji zostały przebudowane tak, aby ich już nie wymagały. Dotyczy to funkcji
map,tuple(#598),get_tuple_element,rng,complex(#560) oraz funkcjiwindow_reversal(#1181).
Niektóre z tych operacji można łatwo usunąć, ponieważ można je wyrazić za pomocą istniejących operacji (broadcast, create_token, cross-replica-sum, dot, unary_einsum) i zostaną usunięte po upływie obecnego okna zgodności (6 miesięcy). Inne są nadal analizowane do usunięcia (porównania: einsum, get_tuple_element, map, rng torch_index_select, tuple, complex, window_reversal). Oczekujemy na opinię społeczności. Te operacje zostaną usunięte lub dodane do specyfikacji z pełną obsługą. Dopóki nie poznamy tych funkcji, gwarantujemy zgodność tylko przez 6 miesięcy.
Wykonanie
Wykonywanie sekwencyjne
Program StableHLO jest wykonywany przez podanie wartości wejściowych do funkcji main i obliczenie wartości wyjściowych. Wartości wyjściowe funkcji są obliczane przez wykonanie grafu operacji z korzenia w odpowiednim elemencie return.
Kolejność wykonywania jest określana przez implementację, o ile jest zgodna z przepływem danych, czyli jeśli operacje są wykonywane przed ich użyciem. W StableHLO wszystkie operacje o efektach ubocznych zużywają 1 token i generują 1 token (wiele tokenów można zmultipleksować w jeden token za pomocą funkcji after_all), więc kolejność wykonywania efektów ubocznych jest również zgodna z dataflow. Na przykład w programie poniżej są 2 możliwe zamówienia: %0 → %1 → %2 → return i %1 → %0 → %2 → return.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Bardziej oficjalnie proces StableHLO składa się z tych elementów:
1) program StableHLO, 2) stanów operacji (jeszcze niewykonany, już wykonany) oraz 3) wartości pośrednich, nad którymi dany proces pracuje.
Proces zaczyna się od wartości wejściowych funkcji main, przechodzi przez wykres operacji aktualizowania stanów operacji i wartości pośrednich, a kończy się wartościami wyjściowymi. Dalsze formalizowanie jest jeszcze do ustalenia (#484).
równoległe wykonanie,
Programy StableHLO mogą być wykonywane równolegle, podzielone na siatkę procesów 2D num_replicas według num_partitions, z których obydwa mają typ ui32.
W siatce procesów StableHLO jednocześnie wykonywanych jest num_replicas * num_partitions procesów StableHLO. Każdy proces ma unikalny element process_id = (replica_id, partition_id), gdzie replica_id w replica_ids = range(num_replicas) i partition_id w partition_ids = range(num_partitions) mają typ ui32.
Rozmiar siatki procesów jest znany statycznie w przypadku każdego programu (w przyszłości planujemy, aby był on częścią programów StableHLO #650), a pozycja w siatce procesów jest znana statycznie w przypadku każdego procesu. Każdy proces ma dostęp do swojej pozycji w siatce procesów za pomocą operacji replica_id i partition_id.
W siatce procesów programy mogą być takie same (w stylu „Jedno program, wiele danych”), różne (w stylu „Wiele programów, wiele danych”) lub mieścić się gdzieś pośrodku. W przyszłości planujemy wprowadzić obsługę innych idiomów definiowania równoległych programów StableHLO, w tym GSPMD (#619).
W ramach siatki procesów procesy są w większości niezależne od siebie – mają oddzielne stany operacji, oddzielne wartości wejścia/pośrednie/wyjścia i większość operacji jest wykonywana oddzielnie w ramach procesów, z wyjątkiem niewielkiej liczby operacji zbiorczych opisanych poniżej.
Ponieważ większość operacji wykonuje się tylko z użyciem wartości z tego samego procesu, zwykle odwołania do tych wartości za pomocą ich nazw są jednoznaczne.
Jednak opis semantyki działań zbiorowych jest niewystarczający. Powoduje to powstanie zapisu name@process_id, który odwołuje się do wartości name w konkretnym procesie. (z tego punktu widzenia niekwalifikowana wartość name może być traktowana jako skrót od wartości name@(replica_id(), partition_id())).
Kolejność wykonywania procesów jest określana przez implementację, z wyjątkiem synchronizacji wprowadzonej przez komunikację punkt-punkt i operacje zbiorcze, jak opisano poniżej.
Komunikacja punkt-punkt
Procesy StableHLO mogą komunikować się ze sobą za pomocą kanałów StableHLO. Kanał jest reprezentowany przez dodatni identyfikator typu si64. Za pomocą różnych operacji można wysyłać wartości do kanałów i odbierać je z kanałów.
Dalsze formalizowanie, np. skąd pochodzą te identyfikatory kanałów, jak procesy programów je rozpoznają i jakie synchronizacja jest przez nie wprowadzana, jest jeszcze do ustalenia (#484).
Komunikacja strumieniowa
Każdy proces StableHLO ma dostęp do 2 interfejsów strumieniowego przesyłania danych:
- Infeed, z których można odczytać treści.
- Outfeed, na którym można zapisać dane.
W przeciwieństwie do kanałów, które służą do komunikacji między procesami, a więc mają procesy po obu końcach, dane wejściowe i dane wyjściowe mają zdefiniowany drugi koniec w ramach implementacji.
Dalszą formalizację, np. o tym, jak strumieniowa komunikacja wpływa na kolejność wykonywania i jaki rodzaj synchronizacji jest przez nie wprowadzana, to do ustalenia (#484).
Operacje zbiorcze
W StableHLO jest 6 zbiorowych operacji: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute i reduce_scatter. Wszystkie te operacje dzielą procesy w sieci procesów StableHLO na grupy procesów StableHLO i wykonują wspólne obliczenia w każdej z nich niezależnie od innych grup procesów.
W ramach każdej grupy procesów operacje zbiorcze mogą wprowadzać barierę synchronizacji. Dalsza formalizacja, np. określenie, kiedy dokładnie ma miejsce synchronizacja, jak dokładnie procesy docierają do tej bariery i co się dzieje, jeśli tego nie zrobią, jest jeszcze nierozstrzygnięta (#484).
Jeśli grupa procesów obejmuje komunikację między partycjami, czyli w grupie procesów są procesy o różnych identyfikatorach partycji, wykonanie operacji zbiorczej wymaga kanału, a operacja zbiorcza musi zawierać dodatnią wartość channel_id typu si64. Komunikacja między replikami
nie wymaga kanałów.
Obliczenia wykonywane przez zbiorcze operacje są specyficzne dla poszczególnych operacji i są opisane w poszczególnych sekcjach dotyczących operacji. Jednak strategie, według których siatka procesów jest dzielona na grupy procesów, są wspólne dla tych operacji i opisane w tej sekcji. W formalnym ujęciu StableHLO obsługuje 4 strategię:
cross_replica
W ramach każdej grupy procesów odbywa się tylko komunikacja między replikami. Ta strategia przyjmuje replica_groups, czyli listę list identyfikatorów replik, i wylicza iloczyn kartezjański replica_groups i partition_ids. replica_groupsmusi zawierać unikalne elementy i objmować wszystkie replica_ids. Bardziej formalnie, używając składni Pythona:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Na przykład w przypadku wartości replica_groups = [[0, 1], [2, 3]] i num_partitions = 2 funkcja cross_replica zwróci wartość [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].
cross_partition
W każdej grupie procesów odbywa się tylko komunikacja między partycjami. Ta strategia wykorzystuje partition_groups – listę list identyfikatorów partycji – i oblicza kartezjański iloczyn wartości partition_groups przez replica_ids.
partition_groups musi zawierać unikalne elementy i objmować wszystkie partition_ids.
Bardziej oficjalnie, używając składni Pythona:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Na przykład w przypadku wartości partition_groups = [[0, 1]] i num_replicas = 4 funkcja cross_partition zwróci wartość [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].
cross_replica_and_partition
W każdej grupie procesów mogą występować zarówno komunikaty między replikami, jak i między partycjami. Ta strategia wykorzystuje replica_groups – listę list identyfikatorów replik – i oblicza iloczy kartezjańskie każdego elementu replica_group według parametru partition_ids. replica_groups musi zawierać unikalne elementy i obejmować wszystkie
replica_ids. Bardziej oficjalnie, używając składni Pythona:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Na przykład w przypadku wartości replica_groups = [[0, 1], [2, 3]] i num_partitions = 2 funkcja cross_replica_and_partition zwróci wartość [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].
flattened_ids
Ta strategia przyjmuje flattened_id_groups – listę list „spłaszczonego” identyfikatora procesu w formie replica_id * num_partitions + partition_id – i przekształca je w identyfikatory procesów. flattened_id_groups musi zawierać unikalne elementy i pokrywać wszystkie elementy process_ids. Bardziej formalnie, używając składni Pythona:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Na przykład w przypadku flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 i num_partitions = 2, flattened_ids zwróci wartość [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].
Dokładność
Obecnie StableHLO nie gwarantuje dokładności liczbowej, ale może się to zmienić w przyszłości (#1156).
Semantyka wykonania kwantowanej operacji
Interpretacja kwantowanych operacji StableHLO może się różnić w zależności od wymagań i możliwości sprzętowych. Na przykład niektóre urządzenia mogą interpretować operacje kwantyzacji za pomocą strategii „dequantyzacja, wykonanie operacji zmiennoprzecinkowej i na koniec ponowna kwantyzacja”. Inne mogą wykonywać całe obliczenia z użyciem arytmetyki całkowitej. Dlatego interpretacja kwantyzowanych operacji StableHLO zależy wyłącznie od konkretnego wdrożenia. Interpretacja kwantyzacji hybrydowej (#1575) powinna opierać się na jej semantyce zgodnie ze specyfikacją (na stronie 1792).
Błędy
Programy StableHLO są weryfikowane za pomocą obszernego zbioru ograniczeń dotyczących poszczególnych operacji, co wyklucza wiele klas błędów przed czasem wykonywania. Nadal jednak mogą wystąpić błędy, np. przepełnienie liczb całkowitych, dostęp poza zakresem itp. O ile nie określono inaczej, wszystkie te błędy powodują zachowanie określone przez implementację, ale może się to w przyszłości zmienić (#1157).
Wyjątki dotyczące liczb zmiennoprzecinkowych
Wyjątkiem od tej reguły są wyjątki o typie zmiennoprzecinkowym w programach StableHLO, które mają dobrze zdefiniowane działanie. Operacje, które powodują wyjątki zdefiniowane przez standard IEEE-754 (nieprawidłowa operacja, dzielenie przez 0, przepełnienie, niedopełnienie lub nieprecyzyjne wyjątki) dają wyniki domyślne (zdefiniowane przez standard) i kontynuują wykonywanie bez podnoszenia odpowiedniej flagi stanu; podobnie jak obsługa wyjątku raiseNoFlag ze standardu. Wyjątki dotyczące operacji niestandardowych (np. złożonej arytmetyki i niektórych funkcji transcendentalnych) są zdefiniowane przez implementację.
niezgodności kształtów,
StableHLO obsługuje tensory o dynamicznej wielkości. Jednak kształty muszą być zgodne w czasie wykonywania, w przeciwnym razie zachowanie jest nieokreślone. StableHLO nie udostępnia bezpośrednio operacji, która może potwierdzić, że tensor ma określony kształt w czasie działania. Generowanie prawidłowego kodu jest obowiązkiem producenta.
Przykładem prawidłowego programu jest program poniżej. Jednak w czasie działania dokładne kształty obiektów %arg0 i %arg1 muszą być takie same. W przeciwnym razie działanie programu będzie niezdefiniowane:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notacja
Do opisania składni w tym dokumencie użyto zmodyfikowanego rodzaju składni EBNF (ISO/IEC 14977:1996, Wikipedia). 1) regułę zdefiniowano za pomocą metody ::=, a nie =:
2) konkatenacja jest wyrażana za pomocą juxtaposition, a nie ,.
Do opisu semantyki (np. w sekcjach „Typy”, „Stałe” i „Operacje”) używamy formuł opartych na składni Pythona rozszerzonej o obsługę zwięzłego wyrażania operacji na tablicach, jak opisano poniżej. Takie rozwiązanie sprawdza się w przypadku małych fragmentów kodu, ale w rzadkich przypadkach, gdy potrzebne są większe fragmenty kodu, używamy składni Pythona, która jest zawsze wprowadzana bezpośrednio.
Wzory
Na przykładzie ze specyfikacji dot_general omówimy, jak działają formuły. Jedno z ograniczeń tej operacji wygląda tak: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
Nazwy używane w tej formule pochodzą z 2 źródeł: 1) funkcji globalnych, np. dim, 2) definicji elementów członkowskich odpowiadającego elementu programu, np. lhs, lhs_batching_dimensions, rhs i rhs_batching_dimensions, zdefiniowanych w sekcji „Wejścia” w funkcji dot_general.
Jak już wspomnieliśmy, składnia tej formuły jest oparta na Pythonie z kilkoma rozszerzeniami ułatwiającymi zwiężenie kodu. Aby lepiej zrozumieć tę formułę, przekształcimy ją w tradycyjną składnię Pythona.
A) W tych formułach używamy wyrażenia = do reprezentowania równości, więc pierwszym krokiem do uzyskania składni Pythona jest zastąpienie = przez == w ten sposób: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).
B) Formuły te obsługują też wielokropki (...), które zamieniają wyrażenia skalarne w wyrażenia tensorowe. Krótko mówiąc, f(xs...) oznacza mniej więcej „dla każdego wektora x w tensorze xs oblicz wektor f(x), a potem zwracaj wszystkie te wyniki wektorów jako wynik tensora”. W standardowej składni Pythona nasza przykładowa formuła zamienia się na:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions].
Dzięki elipsom często można uniknąć pracy na poziomie poszczególnych skalarów. W niektórych trudnych przypadkach można jednak użyć nieformalnej składni na niższym poziomie, np. w formule start_indices[bi0, ..., :, ..., biN] ze specyfikacji gather. W trosce o zwiększenie zwięzłości nie podajemy dokładnego formalizmu, w którym można przetłumaczyć taką składnię na zwykły Python. Mamy nadzieję, że w każdym przypadku będzie ona intuicyjna.
Jeśli zauważysz, że niektóre formuły są nieczytelne, daj nam znać, a spróbujemy je ulepszyć.
Zauważysz też, że formuły używają wielokropków do rozwijania wszelkich rodzajów list, w tym tensorów, list tensorów (które np. mogą powstać na podstawie różnej liczby tensorów) itp. W tym innym obszarze nie ma dokładnego formalności (np. listy nie są nawet częścią intuicyjnego systemu StaableHLO).
C) Ostatnim wartym uwagi sposobem zapisu, którego używamy, jest domyślne przesyłanie. Chociaż przesunięcie StableHLO nie obsługuje jawnego nadawania, formuły już tak, co również służy zwiększeniu zwięzłości. Krótko mówiąc, jeśli w kontekście, w którym oczekuje się tensora, używany jest element skalarny, element skalarny jest rozprowadzany do oczekiwanego kształtu.
Aby kontynuować przykład z dot_general, zastosuj tu kolejne ograniczenie: 0 <= lhs_batching_dimensions < rank(lhs). Zgodnie z definicją w specyfikacji dot_general element lhs_batching_dimensions jest tensorem, ale zarówno 0, jak i rank(lhs) są skalarami. Gdy zastosujemy przekazywanie niejawne, formuła zmieni się na [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].
Po zastosowaniu do konkretnej operacji dot_general ta formuła oceni tensor wartości logicznych. Gdy formuły są używane jako ograniczenia, ograniczenie jest spełnione, jeśli formuła zwraca wartość true lub tensor zawierający tylko elementy true.
Nazwy
W formułach zakres leksykalny obejmuje: 1) funkcje globalne, 2) definicje członków,
3) definicje lokalne. Poniżej znajdziesz listę funkcji globalnych. Lista definicji elementów zależy od elementu programu, do którego zastosowano notację:
- W przypadku operacji definicje elementów obejmują nazwy wprowadzone w sekcjach „Wejścia” i „Wyjścia”.
- W przypadku innych elementów definicje członków obejmują części strukturalne elementu programu, nazwane na podstawie odpowiednich nieterminali EBNF. W większości przypadków nazwy tych części strukturalnych są uzyskiwane przez konwersję nazw nieterminali na format snake case (np.
IntegerLiteral=>integer_literal), ale czasami nazwy są w tym procesie skracane (np.QuantizationStorageType=>storage_type). W takim przypadku nazwy są wprowadzane w sposób jawny, podobnie jak sekcje „Wejścia” i „Wyjścia” w specyfikacjach operacji. - Definicje członków zawsze zawierają element
self, który odnosi się do odpowiedniego elementu programu.
Wartości
Podczas obliczania formuł są one używane do obsługi tych typów wartości:
1) Value (rzeczywiste wartości, np. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; zawsze znają swój typ),
2) Placeholder (przyszłe wartości, np. lhs, rhs lub result; ich rzeczywiste wartości nie są jeszcze znane, znane są tylko ich typy),
3) Type (typy zdefiniowane w sekcji „Typy”),
4) Function (funkcje globalne zdefiniowane w sekcji „Funkcje”).
W zależności od kontekstu nazwy mogą się odnosić do różnych wartości. W szczególności sekcja „Semantyka” w przypadku operacji (i odpowiednich sekcji w przypadku innych elementów programu) definiuje logikę czasu wykonywania, dzięki czemu wszystkie dane wejściowe są dostępne jako Value.
Z kolei sekcja „Ograniczenia” odnosząca się do operacji (i ich odpowiedników) definiuje logikę „czas kompilowania”, tj.coś, co jest zwykle wykonywane przed uruchomieniem. Dlatego jako Value dostępne są tylko stałe dane wejściowe, a inne dane wejściowe – tylko jako Placeholder.
| Nazwy | W sekcji „Semantyka” | W sekcji „Ograniczenia” |
|---|---|---|
| Funkcje globalne | Function |
Function |
| stałe wejścia, | Value |
Value |
| Nieciągłe dane wejściowe | Value |
Placeholder |
| Wyniki | Value |
Placeholder |
| Definicje lokalne | Zależy od definicji | Zależy od definicji |
Przeanalizujmy przykładową operację transpose:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
W przypadku tej operacji permutation jest stałą, więc jest dostępna jako Value zarówno w semantyce, jak i w ograniczeniach. Z kolei operand i result są dostępne jako Value w semantyce, ale tylko jako Placeholder w ograniczeniach.
Funkcje
Konstrukcja typów
Nie ma funkcji, których można używać do tworzenia typów. Zamiast tego używamy bezpośrednio składni typu, ponieważ jest ona zazwyczaj bardziej zwięzła. Na przykład (tensor<E>, tensor<E>) -> (tensor<E>) zamiast function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).
Funkcje w typach
element_typejest zdefiniowany na typach tensorów i typach zaokrąglonych tensorów oraz zwraca odpowiednio częśćTensorElementTypelubQuantizedTensorElementTypeodpowiadającegoTensorTypelubQuantizedTensorType.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Valueto skrót do:is_quantized(x) and quantization_dimension(x) is not None.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valueto skrót odis_quantized(x) and quantization_dimension(x) is None.is_promotable(x: Type, y: Type) -> boolsprawdza, czy typxmożna podnieść do typuy. JeślixiytoQuantizedTensorElementType, promocja jest stosowana tylko dostorage_type. Ta konkretna wersja promocji jest obecnie używana w kontekście obliczeń redukcji (więcej informacji znajdziesz w RFC).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Valueto skrót dois_quantized_tensor_element_type(x).is_type_name(x: Value | Placeholder | Type) -> Value. Dostępne dla wszystkich typów. Na przykładis_float(x)zwracatrue, jeślixjest wartością typuFloatType. Jeślixjest wartością lub miejscem zastępczym, ta funkcja jest skrótem funkcjiis_type_name(type(x)).max_value(x: Type) -> Valuezwraca maksymalną wartośćTensorElementType. Jeślixnie jestTensorElementType, zwracaNone.min_value(x: Type) -> Valuezwraca minimalną możliwą wartośćTensorElementType. Jeślixnie jestTensorElementType, zwracaNone.member_name(x: Value | Placeholder | Type) -> Any. Dostępne dla wszystkich definicji członkówmember_namewszystkich typów. Na przykładtensor_element_type(x)zwraca częśćTensorElementTypeodpowiadającego elementuTensorType. Jeślixjest wartością lub zmienną, ta funkcja jest skrótem domember_name(type(x)). Jeślixnie jest typem, który ma odpowiedni element, wartość lub zastępnik tego typu, zwracaNone.is_empty_algorithm(*args: Type)sprawdza, czy wszystkie pola algorytmu dot są ustawione naNone. Jest to konieczne, ponieważ algorytmy dot mają określone przez implementację domyślne zachowania, więc podanie domyślnej wartości byłoby nieprawidłowe.
Budowa wartości
operation_name(*xs: Value | Type) -> Value. Dostępne w przypadku wszystkich operacji. Na przykład funkcjaadd(lhs, rhs)przyjmuje 2 wartości tensoralhsirhsoraz zwraca wynik oceny operacjiaddz tymi danymi wejściowymi. W przypadku niektórych operacji, np.broadcast_in_dim, typy ich danych wyjściowych są „nośne”, czyli potrzebne do oceny operacji. W tym przypadku funkcja przyjmuje te typy jako argumenty.
Funkcje wartości
Dostępne są wszystkie operatory i funkcje Pythona. Na przykład w Pythonie dostępne są zarówno subskrypcje, jak i wycinki, które można indeksować w tensorach, kwantowych tensorach i tuplach.
to_destination_type(x: Value, destination_type: Type) -> Valuejest zdefiniowany na tensorach i zwraca przekonwertowaną wartośćxna podstawie wartościtype(x)idestination_typew następujący sposób:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Trwają wstępne dyskusje na temat połączenia operacji convert, uniform_quantize i uniform_dequantize (#1576).
Po scaleniu nie potrzebujesz powyższej funkcji i możemy zamiast niej użyć nazwy operacji dla funkcji convert.
Funkcja
is_nan(x: Value) -> Valuejest zdefiniowana na tensorach i zwraca wartośćtrue, jeśli wszystkie elementyxsą równeNaN, a w przeciwnym razie zwraca wartośćfalse. Jeślixnie jest tensorem, zwracaNone.Funkcja
is_sorted(x: Value) -> Valuejest zdefiniowana na tensorach i zwraca wartośćtrue, jeśli elementyxsą posortowane w kolejności rosnącej według rosnącej kolejności leksykograficznej ich indeksów lubfalsew innym przypadku. Jeślixnie jest tensorem, zwracaNone.Funkcja
is_unique(x: Value) -> Valuejest zdefiniowana na tensorach i zwraca wartośćtrue, jeślixnie zawiera podwójnych elementów, a w przeciwnym razie zwraca wartośćfalse. Jeślixnie jest tensorem, zwracaNone.member_name(x: Value) -> Anyjest zdefiniowany dla wszystkich definicji elementówmember_namewszystkich wartości. Na przykładreal_part(x)zwraca elementRealPartnależący do odpowiadającego mu elementuComplexConstant. Jeślixnie jest wartością, która ma odpowiedni element, zwracaNone.Funkcja
same(x: Value) -> Valuejest zdefiniowana na tensorach i zwraca wartośćtrue, jeśli wszystkie elementy tensoraxsą sobie równe, a w przeciwnym razie zwraca wartośćfalse. Jeśli tensor nie ma elementów, liczy się to jako „wszystkie są sobie równe”, tzn. funkcja zwracatrue. Jeślixnie jest tensorem, zwracaNone.split(x: Value, num_results: Value, axis: Value) -> Valuejest zdefiniowany na tensorach i zwracanum_resultsprzekrojexwzdłuż osiaxis. Jeślixnie jest tensorem anidim(x, axis) % num_results != 0, zwracaNone.is_defined_in_parent_scope(x: Value) -> Valuejest zdefiniowana na ciągach znaków i zwracatrue, jeślixto nazwa funkcji zdefiniowanej w tym samym zakresie co funkcja nadrzędna dla odpowiedniej op.Argument
is_namespaced_op_name(x: Value) -> Valuejest zdefiniowany na ciągach znaków i zwraca wartośćtrue, jeślixjest prawidłową nazwą op, czyli spełnia to wyrażenie regularne:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Obliczenia kształtu
axes(x: Value | Placeholder | Type) -> Valueto skrót do opcjirange(rank(x)).dim(x: Value | Placeholder | Type, axis: Value) -> Valueto skrót doshape(x)[axis].dims(x: Value | Placeholder | Type, axes: List) -> Listto skrót dolist(map(lambda axis: dim(x, axis), axes)).Funkcja
index_space(x: Value | Placeholder | Type) -> Valuejest zdefiniowana na tensorach i zwraca indeksysize(x)dla odpowiednichTensorTypeposortowanych w rosnącym porządku alfabetycznym, np.[0, ..., 0],[0, ..., 1], ...,shape(x) - 1. Jeślixnie jest typem tensora, skonwertowanego typu tensora, wartości lub zastępnika jednego z tych typów, zwracaNone.rank(x: Value | Placeholder | Type) -> Valueto skrót dosize(shape(x)).shape(x: Value | Placeholder | Type) -> Valuejest zdefiniowana w sekcji „Funkcje według typów” za pomocąmember_name.size(x: Value | Placeholder | Type) -> Valueto skrót doreduce(lambda x, y: x * y, shape(x)).
Obliczenia kwantyzacji
def baseline_element_type(x: Value | Placeholder | Type) -> Typeto skrót odelement_type(baseline_type(x)).Funkcja
baseline_typejest zdefiniowana dla typów tensorów i typów zaokrąglonych tensorów. Przekształca je w „wartość bazową”, czyli typ o tym samym kształcie, ale z parametrami zaokrąglenia typu elementu skonfigurowanymi na wartości domyślne. To przydatna sztuczka, która pozwala jednolicie porównać zarówno tensor, jak i kwantyzowane typy tensorów, co jest potrzebne dość często. W przypadku typów kwantyzowanych umożliwia to porównywanie typów ignorujących parametry kwantyzacji, czylishape,storage_type,expressed_type,storage_min,storage_maxiquantization_dimension(w przypadku typu kwantyzowanego na oś), ale wartościscalesizero pointsmogą się różnić.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantizejest zdefiniowany na podstawie zliczanych typów tensorów i przekształca je w typy tensorów zmiennoprzecinkowych. Polega to na konwertowaniu elementów dyskretnych, które reprezentują wartości całkowite typu magazynowania, na odpowiadające im wartości zmiennoprzecinkowe typu wyrażonego za pomocą punktu zerowego i skali powiązanej z elementem dyskretnym.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantizejest zdefiniowany na typach tensorów zmiennoprzecinkowych i przekształca je w typy skonwertowanych tensorów. Polega to na konwersji wartości zmiennoprzecinkowych wyrażonego typu na odpowiadające wartości całkowite typu magazynu za pomocą punktu zerowego i skali powiązanej z kwantowanym typem elementu.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantizesłuży do określania obliczeń elementarnych na kwantyzowaniach tensorów. Dequantuje, czyli zamienia elementy poddane kwantyzacji na ich wyrażone typy, a potem wykonuje operację, a następnie ponownie kwantyzuje, czyli zamienia wyniki z powrotem na ich typy magazynowania. Obecnie ta funkcja działa tylko w przypadku kwantyzacji natężenia. Trwa kwantyzacja według osi (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_opsłuży do określania kwantyzacji tylko wagi dla operacji hybrydowej, która przyjmuje lewą stronę w typie zmiennoprzecinkowym, a prawą – w typie skwantowanym. Dekwantuje zagregowane dane wejściowe w ich wyrażonych typach i wykonuje obliczenia w typie float. Typ elementu wektora lewego argumentu typu float i wyrażony typ zagregowanego prawego argumentu powinny być identyczne.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Obliczenia siatki
cross_partition(replica_groups: Value) -> Value. Zobacz sekcję „cross_replica” powyżej.cross_replica(replica_groups: Value) -> Value. Zobacz sekcję „cross_replica” powyżej.cross_replica_and_partition(replica_groups: Value) -> Value. Zobacz sekcję „cross_replica_and_partition” powyżej.flattened_ids(replica_groups: Value) -> Value. Zobacz sekcję „flattened_ids” powyżej.
Dynamizm
Wartości StableHLO mogą mieć dynamiczne rozmiary wymiarów, np. tensor<?xi64>.
Wartości StableHLO nie mogą jednak mieć dynamicznej liczby wymiarów (nieuporządkowany
dynamizm, np. tensor<*xi64>). Obliczenia i wyniki mogą używać dynamicznych rozmiarów wymiarów, nawet jeśli istnieją ograniczenia rozmiarów. Jeśli to możliwe, ograniczenia są weryfikowane statycznie. W przeciwnym razie są odraczane do działania w czasie działania i niezgodności powodują niezdefiniowane zachowanie. Przykłady znajdziesz poniżej.
Niezgodność kształtów w przypadku operacji jednoelementowych
Rozważ ten program:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Taki program jest nietypowy, ponieważ zwykle wiadomo, jak wygląda wynik, ale nie wiadomo, jak wyglądają dane wejściowe. To jednak prawidłowy program StableHLO. W tym programie nie można statycznie zweryfikować operacji abs, ponieważ nie znamy dokładnego kształtu operandu. Kształty z pewnością są jednak zgodne i można to sprawdzić statycznie: okazało się, że w czasie działania obiekt ? to 2 i nie będzie żadnych problemów. Jednak ? może też okazać się inną liczbą całkowitą, w którym przypadku działanie jest nieokreślone.
Pamiętaj, że jeśli rozmiar wymiaru jest dynamiczny, nie może wystąpić niezdefiniowane działanie. Rzeczywiście nie ma „oczekiwanego” rozmiaru, więc nie może być niezgodności.
Niezgodność kształtów w przypadku operacji binarnych element po elemencie
Rozważ ten program:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
W przypadku operacji binarnych elementowych kształty danych wejściowych i wyniku muszą być zgodne w czasie wykonywania. Podczas kompilowania wymiary statyczne muszą być równe. W przeciwnym razie wystarczy, że będą tylko zgodne. Jeśli cokolwiek w danych wejściowych jest wymiarem dynamicznym, może to spowodować nieokreślone działanie w czasie wykonywania, ponieważ rozmiar dynamiczny może nie pasować do odpowiadającego mu rozmiaru w innym operandzie (czyli stałego lub dynamicznego). Jeśli wszystkie dane wejściowe są statyczne, nie ma znaczenia, czy wynik jest dynamiczny: wymiary znane statycznie będą sprawdzane statycznie, a wymiary dynamiczne nie narzucają żadnych ograniczeń.
Niezgodności kształtów w operacjach, które przyjmują kształt wyjściowy jako operand
Rozważ ten program:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Wartości w operandzie kształtu w czasie wykonywania programu muszą odpowiadać kształtowi wyniku. W przeciwnym razie zachowanie jest nieokreślone. Oznacza to, że w czasie wykonywania %arg0 musi mieć wartość dense<[3, 4]> : tensor<2xi32>. Jeśli operand kształtu jest stały, można to zweryfikować statycznie. Jeśli kształt wyniku jest w pełni dynamiczny, nie może wystąpić rozbieżność.