StableHLO to zestaw operacji wysokiego poziomu (HLO) w modelach systemów uczących się. StableHLO działa jako warstwa przenoszenia między różnymi platformami ML i kompilatorami ML: platformy ML tworzące programy StableHLO są zgodne z kompilatorami ML, które korzystają z programów StableHLO.
Naszym celem jest uproszczenie i przyspieszenie programowania ML przez zapewnienie większej interoperacyjności między różnymi platformami ML (np. TensorFlow, JAX i PyTorch) i kompilatorami ML (np. XLA i IREE). W tym celu zamieściliśmy specyfikację języka programowania StableHLO.
Specyfikacja zawiera 3 główne sekcje. Sekcja Programy opisuje strukturę programów StableHLO, które składają się z funkcji StableHLO, które składają się z operacji StableHLO. W tej strukturze sekcja Operacje określa semantykę poszczególnych operacji. Sekcja Wykonanie zawiera semantykę wszystkich tych operacji wykonywanych razem w programie. W sekcji Notacja omawiamy też zapis stosowany w całej specyfikacji.
Aby wyświetlić specyfikację z poprzedniej wersji StableHLO, otwórz repozytorium w oznaczonym wydaniu. Na przykład specyfikacja StableHLO w wersji 0.19.0. Aby wyświetlić zmiany wprowadzone w każdej małej wersji StableHLO, zapoznaj się z logiem wersji w pliku VhloDialect.td.
Programy
Program ::= {Func}
Programy StableHLO składają się z dowolnej liczby funkcji StableHLO.
Poniżej znajduje się przykładowy program z funkcją @main
, która ma 3 parametry wejściowe (%image
, %weights
i %bias
) oraz 1 wyjście. Treść funkcji zawiera 6 operacji.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funkcje
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Funkcje StableHLO (nazywane też funkcjami nazwanymi) mają identyfikator, wejścia/wyjścia i ciało. W przyszłości planujemy wprowadzić dodatkowe metadane funkcji, aby zwiększyć zgodność z HLO (#425, #626, #740, #744).
Identyfikatory
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Identyfikatory StableHLO są podobne do identyfikatorów w wielu językach programowania, ale mają 2 szczególne cechy: 1) wszystkie identyfikatory mają sygnatury, które odróżniają różne rodzaje identyfikatorów, 2) identyfikatory wartości mogą być całkowicie numeryczne, aby uprościć generowanie programów StableHLO.
Typy
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Typy StableHLO są podzielone na typy wartości (nazywane też typami najwyższej klasy), które reprezentują wartości StableHLO, oraz typy bez wartości, które opisują inne elementy programu. Typy StableHLO są podobne do typów w wielu językach programowania, a ich główną osobliwością jest specyfika danej domeny, która powoduje pewne nietypowe wyniki (np. typy skalarne nie są typami wartości).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Typy Tensor reprezentują tensory, czyli tablice wielowymiarowe. Mają one kształt i typ elementu, gdzie kształt reprezentuje nieujemne lub nieznane rozmiary wymiarów w kolejności rosnącej odpowiadających im wymiarów (zwanych też osiami), których numery wahają się od 0
do R-1
. Liczba wymiarów R
to ranking. Na przykład tensor<2x3xf32>
to typ tensora o kształcie 2x3
i typie elementu f32
. Ma 2 wymiary (czyli 2 osi) – 0 i 1, których rozmiary wynoszą odpowiednio 2 i 3. Jego pozycja to 2.
Kształty mogą być częściowo lub całkowicie nieznane (dynamiczne), np. tensor<?x2xf64>
jest częściowo nieznany, a tensor<?x?xf64>
jest całkowicie nieznany. Dynamiczne rozmiary wymiarów są reprezentowane za pomocą ?
. Kształty nie mogą być niesklasyfikowane.
W przyszłości planujemy rozszerzyć typy tensorów poza rozmiary wymiarów i typy elementów, np. o układy (#629) i sparsity (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nazwa | Typ | Ograniczenia |
---|---|---|
storage_type |
typ liczba całkowita | (C1-C3), (C8) |
storage_min |
stała liczbowa | (C1), (C3), (C7) |
storage_max |
stała liczbowa | (C2), (C3), (C7) |
expressed_type |
typ zmiennoprzecinkowy, | (C4) |
quantization_dimension |
opcjonalna liczba całkowita | (C10-C12) |
scales |
zmienna liczba stałych zmiennoprzecinkowych | (C4-C6), (C9), (C10), (C13) |
zero_points |
zmienna liczba stałych całkowitych | (C7-C9) |
Typy elementów z kwantyzacją reprezentują wartości całkowite typu magazynowania w zakresie od storage_min
do storage_max
(łącznie) odpowiadające wartościom zmiennoprzecinkowym typu wyrażonego. Dla danej wartości całkowitej i
odpowiadającą wartość zmiennoprzecinkową f
można obliczyć jako f = (i - zero_point) * scale
, gdzie scale
i zero_point
to parametry kwantyzacji. Parametry storage_min
i storage_max
są opcjonalne w gramatyce, ale mają wartości domyślne odpowiednio min_value(storage_type)
i max_value(storage_type)
. Elementy typu „kwantowany” mają te ograniczenia:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Jeśli
is_empty(quantization_dimension)
, tosize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Obecnie QuantizationScale
jest stałą zmiennoprzecinkową, ale istnieje duże zainteresowanie skalami opartymi na liczbach całkowitych, reprezentowanymi przez mnożniki i przesunięcia. W najbliższej przyszłości planujemy wziąć to pod uwagę (#1404).
Trwa dyskusja na temat semantyki funkcji QuantizationZeroPoint
, w tym jej typu, wartości i tego, czy w skwantyzowanym typie tensora może istnieć tylko jeden punkt zerowy, czy może on mieć większą liczbę punktów zerowych. Na podstawie wyników tej dyskusji specyfikacja dotycząca punktów 0 może w przyszłości ulec zmianie (#1405).
Kolejna trwająca dyskusja dotyczy semantyki QuantizationStorageMin
i QuantizationStorageMax
, aby określić, czy należy nałożyć jakieś ograniczenia na te wartości i wartości skończonych tensorów (#1406).
Na koniec planujemy zgłębić temat przedstawiania nieznanych skal i punktów zerowych w sposób podobny do tego, w jaki planujemy przedstawiać nieznane rozmiary wymiarów (#1407).
Typy kwantyzowanych tensorów reprezentują tensory z kwantyzowanymi elementami. Te tensory są dokładnie takie same jak zwykłe tensory, z tą różnicą, że ich elementy mają kwantowane typy elementów zamiast zwykłych typów elementów.
W przypadku skończonych tensorów kwantyzowanie może być na poziomie tensora, co oznacza, że dla całego tensora jest jeden element scale
i zero_point
, lub na poziomie osi, co oznacza, że dla danej osi danego wymiaru quantization_dimension
jest wiele elementów scales
i zero_points
. Bardziej formalnie, w tensorze t
z kwantyzacją na osi występują dim(t, quantization_dimension)
krojenia quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
itd. Wszystkie elementy w i
tym krojeniu używają wartości scales[i]
i zero_points[i]
jako parametrów kwantowania. Typy zaokrąglonych tensorów mają te ograniczenia:
- W przypadku kwantyzacji na centrów:
- Brak dodatkowych ograniczeń.
- W przypadku kwantyzacji na oś:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Typy tokenów to tokeny, czyli nieprzezroczyste wartości tworzone i wykorzystywane przez niektóre operacje. Tokeny są używane do określania kolejności wykonywania operacji zgodnie z opisem w sekcji Wykonywanie.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Typy tupla reprezentują tuple, czyli listy niejednorodne. Kropki to starsza funkcja,
która zapewnia zgodność z HLO. W HLO tuple służą do reprezentowania zmiennych danych wejściowych i wyjściowych. W StableHLO obsługiwane są natywny wejścia i wyjścia, a jedynym zastosowaniem tupletów w StableHLO jest kompleksowe reprezentowanie HLO ABI, w którym np. T
, tuple<T>
i tuple<tuple<T>>
mogą się znacznie różnić w zależności od konkretnej implementacji. W przyszłości planujemy wprowadzić zmiany w interfejsie HLO ABI, które mogą pozwolić nam usunąć typy tuple z StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Typy elementów reprezentują elementy typu tensora. W odróżnieniu od wielu języków programowania te typy nie są typu first class w StableHLO. Oznacza to, że programy StableHLO nie mogą bezpośrednio przedstawiać wartości tych typów (w efekcie są to wartości skalarne typu T
z wartościami 0-wymiarowymi tensora typu tensor<T>
).
- Typ logiczny reprezentuje wartości logiczne
true
ifalse
. - Typy całkowite mogą być typu ze znakiem (
si
) lub bez znaku (ui
) i mieć jedną z obsługiwanych szerokości bitów (2
,4
,8
,16
,32
lub64
). Podpisane typysiN
reprezentują liczby całkowite z zakresu od-2^(N-1)
do2^(N-1)-1
włącznie, a bez znakuuiN
– wartości całkowite z zakresu od0
do2^N-1
. - Typy zmiennoprzecinkowe mogą być następujące:
f8E3M4
,f8E4M3
if8E5M2
8-bitowe liczby zmiennoprzecinkową zgodnie z konwencjami IEEE-754.- Typy
f8E4M3FN
if8E5M2
odpowiadające odpowiednio kodowaniomE4M3
iE5M2
formatu FP8 opisanego w artykule Formaty FP8 do uczenia głębokiego. - Typy
f8E4M3FNUZ
if8E5M2FNUZ
odpowiadające kodowaniomE4M3
iE5M2
formatów FP8 opisanych w artykule 8-bitowe formaty liczbowe na potrzeby głębokich sieci neuronowych. - Typ
f8E4M3B11FNUZ
odpowiadający kodowaniuE4M3
formatów FP8 opisanych w artykule Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks (Trenowanie i wykonywanie wnioskowania w sieciach głębokich neuronowych przy użyciu hybrydowego formatu 8-bitowego liczby zmiennoprzecinkowej (HFP8)). - Typ
bf16
odpowiadający formatowibfloat16
opisanemu w artykule BFloat16: sekret wysokiej wydajności w Cloud TPU. - Typy
f16
,f32
if64
odpowiadające odpowiednio formatombinary16
(„półpełnej precyzji”),binary32
(„pełnej precyzji”) ibinary64
(„podwójnej precyzji”) opisanym w standardzie IEEE 754. - Typ
tf32
odpowiada formatowi TensorFloat32 i ma ograniczoną obsługę w ramach StableHLO. - Typy
f4E2M1FN
,f6E2M3FN
,f6E3M2FN
if8E8M0FNU
MX (mikroskalowanie) opisane w specyfikacji formatów mikroskalowania OCP.
- Typy zespolone reprezentują wartości zespolone, które mają część rzeczywistą i część urojoną tego samego typu elementu. Obsługiwane złożone typy to
complex<f32>
(obie części są typuf32
) icomplex<f64>
(obie części są typuf64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Typy funkcji reprezentują zarówno funkcje nazwane, jak i anonimowe. Mają typy wejścia (lista typów po lewej stronie ->
) i typy wyjścia (lista typów po prawej stronie ->
). W wielu językach programowania typy funkcji są typu first class, ale nie w StableHLO.
StringType ::= 'string'
Typ ciągu tekstowego reprezentuje sekwencje bajtów. W przeciwieństwie do wielu języków programowania typ ciągu znaków nie jest pierwszą klasą w StableHLO i służy tylko do określania statycznych metadanych elementów programu.
Operacje
Operacje StableHLO (nazywane też operacjami) stanowią zamknięty zbiór operacji wysokiego poziomu w modelach uczenia maszynowego. Jak już wspomnieliśmy, składnia StableHLO jest mocno inspirowana MLIR, co niekoniecznie jest najbardziej ergonomiczną alternatywą, ale prawdopodobnie najlepiej pasuje do celu StableHLO, którym jest zwiększenie interoperacyjności między frameworkami i kompilatorami uczenia maszynowego.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Operacje StableHLO (nazywane też operacjami) mają nazwę, dane wejściowe/wyjściowe i podpis. Nazwa składa się z prefiksu stablehlo.
i mnemotechniki, która jednoznacznie identyfikuje jedno z obsługiwanych operacji. Pełną listę wszystkich obsługiwanych operacji znajdziesz poniżej.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Operacje wykorzystują dane wejściowe i generują dane wyjściowe. Dane wejściowe są podzielone na wartości wejściowe (obliczone podczas wykonywania), funkcje wejściowe (podawane statycznie, ponieważ w StableHLO funkcje nie są wartościami najwyższej klasy) oraz atrybuty wejściowe (również podawane statycznie). Rodzaj danych wejściowych i wyjściowych zużywanych oraz wytwarzanych przez operatora zależy od jego skrótu. Na przykład funkcja add
op pobiera 2 wartości wejściowe i zwraca 1 wartość wyjściową. Natomiast operator select_and_scatter
wymaga 3 wartości wejściowych, 2 funkcji wejściowych i 3 atrybutów wejściowych.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Funkcje wejściowe (nazywane też funkcjami anonimowymi) są bardzo podobne do funkcji nazwanych, z tym że: 1) nie mają identyfikatora (stąd nazwa „anonimowa”) i 2) nie deklarują typów danych wyjściowych (typy danych są wywnioskowane z opcji return
w ramach funkcji).
Składnia funkcji wejściowych zawiera obecnie nieużywaną część (patrz Unused
produkcja powyżej), która jest tam ze względu na zgodność z MLIR. W MLIR występuje bardziej ogólna koncepcja „regionów”, które mogą mieć wiele „bloków” operacji połączonych ze sobą za pomocą operacji przeskoku. Te bloki mają identyfikatory odpowiadające Unused
produkcji, dzięki czemu można je odróżnić od siebie.
StableHLO nie obsługuje wykonywania operacji przejść, więc odpowiadająca mu część składni MLIR nie jest używana (ale nadal jest).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Atrybuty wejściowe mają nazwę i wartość, która jest jedną z obsługiwanych stałych. Stanowią one podstawowy sposób określania metadanych statycznych elementów programu. Na przykład operator concatenate
używa atrybutu dimension
do określania wymiaru, wzdłuż którego wartości wejściowe są łączone. Podobnie operacja slice
używa wielu atrybutów, takich jak start_indices
i limit_indices
, aby określić granice używane do wycinania wartości wejściowej.
Obecnie programy StableHLO działające w środowisku naturalnym mogą czasami zawierać atrybuty, które nie są opisane w tym dokumencie. W przyszłości planujemy uwzględnić te atrybuty w opcji StableHLO lub zabronić ich wykorzystywania w programach StableHLO. Oto lista tych atrybutów:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- metadane lokalizacji (#594);
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Podpis operacji składa się z typów wszystkich wartości wejściowych (lista typów po lewej stronie ->
) oraz typów wszystkich wartości wyjściowych (lista typów po prawej stronie ->
). Ściśle mówiąc, typy wejścia są zbędne, a typy wyjścia prawie zawsze są zbędne (ponieważ w przypadku większości operacji StableHLO typy wyjścia można wywnioskować z danych wejściowych). Mimo to op
signature jest celowo częścią składni StableHLO, aby zapewnić zgodność z MLIR.
Poniżej znajduje się przykład operacji, której skrót to select_and_scatter
. Przyjmuje 3 wartości wejściowe (%operand
, %source
i %init_value
), 2 funkcje wejściowe i 3 atrybuty wejściowe (window_dimensions
, window_strides
i padding
).
Zwróć uwagę, że podpis operacji obejmuje tylko typy wartości wejściowych (ale nie typy funkcji i atrybutów wejściowych podane w tekście).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Stałe
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Stałe StableHLO mają literał i typ, które razem reprezentują wartość StableHLO. Typ jest zwykle częścią składni stałej, z wyjątkiem sytuacji, gdy jest jednoznaczny (np. stała logiczna ma jednoznacznie typ i1
, podczas gdy stała całkowita może mieć kilka możliwych typów).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Stałe logiczne reprezentują wartości logiczne true
i false
. Stałe logiczne mają typ i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Stałe liczbowe reprezentują wartości liczb całkowitych za pomocą ciągów tekstowych, które używają zapisu dziesiętnego lub szesnastkowego. Inne systemy liczbowe, np. binarny czy octal, nie są obsługiwane. Stałe liczbowe są objęte tymi ograniczeniami:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Stałe zmiennoprzecinkowe reprezentują wartości zmiennoprzecinkowe za pomocą ciągów znaków, które używają zapisu dziesiętnego lub wykładniczego. Dodatkowo przy użyciu notacji szesnastkowej można bezpośrednio określać bazowe bity w formacie zmiennoprzecinkowym odpowiedniego typu. Stałe zmiennoprzecinkowe są objęte tymi ograniczeniami:
- (C1) Jeśli używany jest zapis w formacie innym niż szesnastkowy,
is_wellformed(float_literal, float_type)
. - (C2) Jeśli używana jest notacja szesnastkowa,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Stałe zespolone reprezentują wartości zespolone za pomocą list części rzeczywistej (pojawia się jako pierwsza) i części urojonej (pojawia się jako druga). Na przykład (1.0, 0.0) : complex<f32>
oznacza 1.0 + 0.0i
, a (0.0, 1.0) : complex<f32>
oznacza 0.0 + 1.0i
. Kolejność, w jakiej te części są przechowywane w pamięci, jest zdefiniowana przez implementację. Stałe złożone mają te ograniczenia:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Stałe tensora reprezentują wartości tensora za pomocą zagnieżdżonych list określonych za pomocą notacji NumPy. Na przykład dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
reprezentuje wartość tensora z tym mapowaniem indeksów na elementy: {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
, {1, 2} => 6
. Kolejność, w jakiej te elementy są przechowywane w pamięci, jest zdefiniowana przez implementację. Stałe tensora mają następujące ograniczenia:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, gdzie:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, gdzie:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- w przeciwnym razie:
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Zwartości skonwertowanego tensora reprezentują wartości skonwertowanego tensora za pomocą tej samej notacji co stałe tensora, a ich elementy są określone jako stałe ich typu magazynu. Poddane kwantyzacji stałe tensora mają następujące ograniczenia:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Literały łańcuchowe składają się z bajtów określonych za pomocą znaków ASCII i sekwencji ucieczki. Są one niezależne od kodowania, więc interpretacja tych bajtów jest zdefiniowana przez implementację. Literały łańcuchowe mają typ string
.
Operacje
abs
Semantyka
Wykonuje elementarną operację abs na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- W przypadku liczb całkowitych ze znakiem: moduł liczby całkowitej.
- Dla liczb zmiennoprzecinkowych:
abs
z IEEE-754. - W przypadku liczb zespolonych: moduł zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(abs, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora | (C1-C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu signed integer lub zmiennoprzecinkowego albo tensor kwantyzowany na podstawie tensora | (K1–C2) |
Ograniczenia
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
zdefiniowano jako:complex_element_type(element_type(operand))
jeżeliis_complex(operand)
.baseline_element_type(operand)
w innych przypadkach.
Przykłady
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
dodaj
Semantyka
Wykonuje dodawanie elementów dwóch tensorów lhs
i rhs
i tworzy tensor result
. W zależności od typu elementu:
- W przypadku wartości logicznych: operator logiczny LUB.
- W przypadku liczb całkowitych: dodawanie liczb całkowitych.
- Dla liczb zmiennoprzecinkowych:
addition
z IEEE-754. - W przypadku liczb zespolonych: dodawanie zespolone.
- W przypadku typów skwantowanych:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor lub kwantyzowany tensor | (C1-C6) |
(I2) | rhs |
tensor lub kwantyzowany tensor | (C1-C5), (C7) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor (tensor kwantowy) | (C1-C7) |
Ograniczenia
- Jeśli operacja używa niespakowanych tensorów:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Jeśli operacja używa kwantowanych tensorów:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Jeśli
is_per_axis_quantized(lhs)
, toquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Jeśli
is_per_axis_quantized(rhs)
, toquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Przykłady
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semantyka
Zapewnia, że operacje generujące inputs
są wykonywane przed operacjami zależnymi od result
. Wykonywanie tej operacji nie powoduje żadnych zmian. Służy ona tylko do ustanowienia zależności danych z poziomu result
do inputs
.
Dane wejściowe
Etykieta | Nazwa | Typ |
---|---|---|
(I1) | inputs |
zmienna liczba token |
Wyniki
Nazwa | Typ |
---|---|
result |
token |
Przykłady
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semantyka
W każdej grupie procesów w siatce procesów StableHLO konkatenuje wartości tensorów operands
z każdego procesu wzdłuż wymiaru all_gather_dim
i tworzy tensory results
.
Operacja dzieli siatkę procesu StableHLO na process_groups
, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)
jeślichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
ifchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
ifchannel_id > 0 and use_global_device_ids = true
.
Następnie w każdym wierszu process_group
:
operands...@receiver = [operand@sender for sender in process_group]
za wszystkiereceiver
wprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
za wszystkieprocess
wprocess_group
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operands |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1), (C6) |
(I2) | all_gather_dim |
stała typu si64 |
(C1), (C6) |
(I3) | replica_groups |
2-wymiarowa stała tensora typu si64 |
(C2-C4) |
(I4) | channel_id |
stała typu si64 |
(C5) |
(I5) | use_global_device_ids |
stała typu i1 |
(K5) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C6) |
Ograniczenia
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3) Parametr
size(replica_groups)
jest zdefiniowany jako:num_replicas
, jeśli używana jest właściwośćcross_replica
.num_replicas
, jeśli używana jest właściwośćcross_replica_and_partition
.num_processes
, jeśli używana jest właściwośćflattened_ids
.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Jeśli
use_global_device_ids = true
, tochannel_id > 0
. - (C6)
type(results...) = type(operands...)
z wyjątkiem:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Przykłady
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semantyka
W ramach każdej grupy procesów w siatce procesów StableHLO stosuje funkcję redukcji computation
do wartości tensorów operands
z każdego procesu i generuje tensory results
.
Operacja dzieli siatkę procesu StableHLO na process_groups
, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)
jeślichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
ifchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
ifchannel_id > 0 and use_global_device_ids = true
.
Następnie w każdym wierszu process_group
:
results...@process[result_index] = exec(schedule)
dla niektórych drzew binarnychschedule
gdzie:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
to drzewo binarne zdefiniowane przez implementację, którego traversal w kolejności toto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operands |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C5), (C6) |
(I2) | replica_groups |
zmienna liczba stałych tensora jednowymiarowego typu si64 |
(K1–C3) |
(I3) | channel_id |
stała typu si64 |
(C4) |
(I4) | use_global_device_ids |
stała typu i1 |
(C4) |
(I5) | computation |
funkcja | (C5) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C6-C7) |
Ograniczenia
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
zdefiniowano jako:num_replicas
, jeśli używana jest właściwośćcross_replica
.num_replicas
, jeśli używana jest właściwośćcross_replica_and_partition
.num_processes
, jeśli używana jest właściwośćflattened_ids
.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Jeśli
use_global_device_ids = true
, tochannel_id > 0
. - (C5)
computation
ma typ(tensor<E>, tensor<E>) -> (tensor<E>)
, gdzieis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Przykłady
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semantyka
W ramach każdej grupy procesów w siatce procesów StableHLO dzieli wartości tensorów operands
wzdłuż split_dimension
na części, rozprasza te części między procesami, konkatenuje rozproszone części wzdłuż concat_dimension
i tworzy tensory results
.
Operacja dzieli siatkę procesu StableHLO na process_groups
, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)
jeżelichannel_id <= 0
.cross_partition(replica_groups)
jeżelichannel_id > 0
.
Następnie w każdym wierszu process_group
:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
dla wszystkichsender
wprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
gdziereceiver_index = process_group.index(receiver)
.results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operands |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1-C3), (C9) |
(I2) | split_dimension |
stała typu si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
stała typu si64 |
(C3), (C9) |
(I4) | split_count |
stała typu si64 |
(C2), (C4), (C8), (C9) |
(I5) | replica_groups |
2-wymiarowa stała tensora typu si64 |
(C5-C8) |
(I6) | channel_id |
stała typu si64 |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C9) |
Ograniczenia
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
jest zdefiniowany jako:num_replicas
, jeśli używana jest właściwośćcross_replica
.num_partitions
, jeśli używana jest właściwośćcross_partition
.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
oprócz tych, któresplit_dimension != concat_dimension
:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Przykłady
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
i
Semantyka
Wykonuje z punktu widzenia elementu ORAZ dwa tensory lhs
i rhs
, tworzy intensywność result
. W zależności od typu elementu:
- W przypadku wartości logicznych: operator logiczny OR.
- W przypadku liczb całkowitych: bitowe ORAZ.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu logicznego lub całkowitego | (C1) |
(I2) | rhs |
tensor typu logicznego lub całkowitego | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu logicznego lub całkowitego | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result)
.
Przykłady
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semantyka
Wykonuje elementową operację atan2 na tensorach lhs
i rhs
, tworząc tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
atan2
z IEEE-754. - W przypadku liczb zespolonych: complex atan2.
- W przypadku typów skwantowanych:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
(I2) | rhs |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Przykłady
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semantyka
Oblicza gradienty kilku wejść batch_norm_training
z backpropagatinggrad_output
i tworzy tensory grad_operand
, grad_scale
i grad_offset
. Bardziej formalnie ta operacja może zostać wyrażona jako dekompozycja istniejących operacji StableHLO z użyciem składni Pythona:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
W przypadku typów skwantowanych wykonuje działanie dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1-C3), (C5) |
(I2) | scale |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4), (C5) |
(I3) | mean |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4) |
(I4) | variance |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4) |
(I5) | grad_output |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C2), (C3) |
(I6) | epsilon |
stała typu f32 |
|
(I7) | feature_index |
stała typu si64 |
(C1), (C5) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
grad_operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C2), (C3) |
grad_scale |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4) |
grad_offset |
Jednowymiarowy tensor typu kwantyzowanego typu zmiennoprzecinkowego lub na tensor | (C2), (C4) |
Ograniczenia
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
igrad_offset
mają te same wartościbaseline_element_type
. - (C3)
operand
,grad_output
igrad_operand
mają ten sam kształt. - (C4) Znaczniki
scale
,mean
,variance
,grad_scale
igrad_offset
mają ten sam kształt. - (C5)
size(scale) = dim(operand, feature_index)
.
Przykłady
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semantyka
Normalizuje tensor operand
we wszystkich wymiarach z wyjątkiem wymiaru feature_index
i tworzy tensor result
. Bardziej formalnie tę operację można wyrazić jako dekompozycję istniejących operacji StableHLO za pomocą składni Pythona w ten sposób:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
W przypadku typów kwantowych wybiera dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub kwantyzowany tensor na poziomie procesora | (C1-C7) |
(I2) | scale |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C3) |
(I3) | offset |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C4) |
(I4) | mean |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (K5) |
(I5) | variance |
1-wymiarowy tensor typu zmiennoprzecinkowego lub z kwantyzacją na poziomie tensora | (C2), (C6) |
(I6) | epsilon |
stała typu f32 |
|
(I7) | feature_index |
stała typu si64 |
(C1), (C3-C6) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C2), (C7) |
Ograniczenia
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
iresult
mają te same ustawieniabaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semantyka
Oblicza średnią i wariancję we wszystkich wymiarach oprócz wymiaru feature_index
oraz normalizuje tensor operand
, tworząc tensory output
, batch_mean
i batch_var
. Bardziej formalnie ta operacja może być wyrażona jako dekompozycja istniejących operacji StableHLO z użyciem składni Pythona:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
W przypadku typów skwantowanych wykonuje działanie dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
(I2) | scale |
Jednowymiarowy tensor kwantyzowany (zmiennoprzecinkowy lub na tensor) | (C2), (C3) |
(I3) | offset |
1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora | (C2), (C4) |
(I4) | epsilon |
stała typu f32 |
(C1), (C3-C6) |
(I5) | feature_index |
stała typu si64 |
(C1), (C3-C6) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
output |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C7) |
batch_mean |
1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora | (C2), (C5) |
batch_var |
1-wymiarowy tensor zmiennoprzecinkowy lub kwantyzowany na poziomie tensora | (C2), (C6) |
Ograniczenia
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
ioutput
mają ten sambaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Przykłady
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semantyka
Wykonuje operację bitcast na tensorze operand
i tworzy tensor result
, w którym bity całego tensora operand
są ponownie interpretowane przy użyciu typu tensora result
.
Bardziej formalnie, jeśli E = element_type(operand)
, E' = element_type(result)
i R = rank(operand)
:
- Jeśli
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Jeśli
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Jeśli
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
Funkcja bits
zwraca reprezentację danej wartości w pamięci, a jej działanie jest definiowane przez implementację, ponieważ dokładna reprezentacja tensorów jest zdefiniowana przez implementację, a dokładna reprezentacja typów elementów również jest zdefiniowana w implementacji.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub kwantyzowany tensor | (C1-C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub kwantyzowany tensor | (K1–C2) |
Ograniczenia
- (C1) Przy założeniu, że
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
iR = rank(operand)
:- Jeśli
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Jeśli
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
dla wszystkich0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Jeśli
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
za wszystkie0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Jeśli
- (C2) Jeśli
is_complex(operand) or is_complex(result)
, tois_complex(operand) and is_complex(result)
.
Przykłady
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semantyka
Rozwija wymiary lub rangę wejściowego tensora przez powielanie danych w tensorze operand
i tworzenie tensora result
. Formalnie:
result[result_index] = operand[operand_index]
dla wszystkich d
w ramach axes(operand)
:
operand_index[d] = 0
jeżelidim(operand, d) = 1
.operand_index[d] = result_index[broadcast_dimensions[d]]
w innych przypadkach.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub kwantyzowany tensor | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Jednowymiarowa stała tensora typu si64 |
(C2-C6) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor (tensor kwantowy) | (C1), (C3), (C5-C6) |
Ograniczenia
- (C1)
element_type(result)
otrzymuje:element_type(operand)
, jeśli!is_per_axis_quantized(operand)
.element_type(operand)
, z tym żequantization_dimension(operand)
,scales(operand)
izero_points(operand)
mogą się różnić odquantization_dimension(result)
,scales(result)
izero_points(result)
odpowiednio.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) W przypadku wszystkich
d
waxes(operand)
:dim(operand, d) = 1
lubdim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Jeśli
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Jeśli
dim(operand, quantization_dimension(operand)) = 1
, toscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Przykłady
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
etui
Semantyka
Generuje dane wyjściowe w wyniku wykonania dokładnie 1 funkcji z funkcji branches
w zależności od wartości index
. Bardziej formalnie: result = selected_branch()
gdzie:
selected_branch = branches[index]
jeżeli0 <= index < size(branches)
.selected_branch = branches[-1]
w innych przypadkach.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | index |
Tensor 0-wymiarowy typu si32 |
|
(I2) | branches |
zmienna liczba funkcji | (C1-C4) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C4) |
Ograniczenia
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Przykłady
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
Cbrt
Semantyka
Wykonuje elementarną operację pierwiastka sześciennego na tensorze operand
i tworzy tensor result
. W zależności od typu elementu wykonuje te działania:
- W przypadku jednostek zmiennoprzecinkowych:
rootn(x, 3)
w standardzie IEEE-754. - Liczby zespolone: pierwiastek sześcienny zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(cbrt, operand, type(result))
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semantyka
Przeprowadza element po elemencie zaokrąglenie w górę tensora operand
i generuje tensor result
.
Realizuje operację roundToIntegralTowardPositive
według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(ceil, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
Cholesky
Semantyka
Oblicza rozkład Choleskiego dla zbioru macierzy.
W bardziej formalnym ujęciu dla wszystkich i
w index_space(result)
result[i0, ..., iR-3, :, :]
jest rozkładem Cholesky'ego a[i0, ..., iR-3, :, :]
w postaci dolnej macierzy trójkątnej (jeśli lower
jest true
) lub górnej macierzy trójkątnej (jeśli lower
jest false
).
Wartości wyjściowe w trójkącie przeciwległym, tj. odpowiednio ścisły górny trójkąt lub odpowiednio ścisły trójkąt dolny, są definiowane na podstawie implementacji.
Jeśli istnieje element i
, w którym macierz wejściowy nie jest macierzową o dodatnie dodatnim, działanie jest niezdefiniowane.
W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | a |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1-C3) |
(I2) | lower |
Stałe tensora 0-wymiarowego typu i1 |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Przykłady
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
ograniczać (zakres)
Semantyka
Łączy każdy element tensora operand
między wartością minimalną a maksymalną, tworząc tensor result
. Bardziej formalnie: result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
, gdzie min_element = rank(min) = 0 ? min[] : min[result_index]
, max_element = rank(max) = 0 ? max[] : max[result_index]
. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560).
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | min |
tensor lub tensor zagregowany z tensorów | (C1), (C3) |
(I2) | operand |
tensor lub tensor zagregowany z tensorów | (C1-C4) |
(I3) | max |
tensor lub tensor zagregowany z tensorów | (C2), (C3) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C4) |
Ograniczenia
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semantyka
W każdej grupie procesów w siatce procesów StableHLO wyślij wartość tensora operand
z procesu źródłowego do procesów docelowych i utwórz tensor result
.
Ta operacja dzieli siatkę procesów StableHLO na siatkę procesów process_groups
, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)
jeżelichannel_id <= 0
.cross_partition(replica_groups)
jeżelichannel_id > 0
.
Następnie result@process
jest obliczana jako:
operand@process_groups[i, 0]
, jeśli istniejei
, który określa, że proces znajduje się w regionieprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
w innym przypadku.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub tensor zagregowany z tensorów | (C3) |
(I2) | replica_groups |
zmienna liczba stałych tensora jednowymiarowego typu si64 |
(C1), (C2) |
(I3) | channel_id |
stała typu si64 |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C3) |
Ograniczenia
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, gdzieN
jest zdefiniowany jako:num_replicas
, jeśli używana jest właściwośćcross_replica
.num_partitions
, jeśli używana jest właściwośćcross_partition
.
- (C3)
type(result) = type(operand)
.
Przykłady
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semantyka
W każdej grupie procesów w siatce procesów StableHLO wysyła wartość tensora operand
z procesu źródłowego do procesu docelowego i tworzy tensor result
.
Operacja dzieli siatkę procesu StableHLO na process_groups
, która jest zdefiniowana w ten sposób:
cross_replica(source_target_pairs)
, jeślichannel_id <= 0
.cross_partition(source_target_pairs)
jeżelichannel_id > 0
.
Później wartość result@process
jest określana przez:
operand@process_groups[i, 0]
, jeśli istniejei
takie, żeprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
kwantowy tensor lub tensor kwantowy | (K5) |
(I2) | source_target_pairs |
2-wymiarowa stała tensora typu si64 |
(C1-C4) |
(I3) | channel_id |
stała typu si64 |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, gdzieN
jest zdefiniowany jako:num_replicas
, jeśli używana jest właściwośćcross_replica
.num_partitions
, jeśli używana jest właściwośćcross_partition
.
- (C5)
type(result) = type(operand)
.
Przykłady
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
porównaj
Semantyka
Porównuje elementy tensorów lhs
i rhs
zgodnie z definicjami comparison_direction
i compare_type
, tworząc tensor result
.
Wartości comparison_direction
i compare_type
mają następującą semantykę:
W przypadku typów elementów wartości logicznych i liczb całkowitych:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
W przypadku typów elementów zmiennoprzecinkowych z compare_type = FLOAT
operator op implementuje te operacje IEEE-754:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
W przypadku typów elementów zmiennoprzecinkowych z wartością compare_type = TOTALORDER
operator używa kombinacji operacji totalOrder
i compareQuietEqual
z normy IEEE-754.
W przypadku złożonych typów elementów przeprowadzane jest porównanie leksykograficzne par (real, imag)
za pomocą podanych wartości comparison_direction
i compare_type
.
Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych, gdy comparison_direction
= GE
, GT
, LE
lub LT
(#560).
W przypadku typów poddanych kwantyzacji wykonuje działanie dequantize_compare(lhs, rhs,
comparison_direction)
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor lub tensor zagregowany z tensorów | (C1-C3) |
(I2) | rhs |
tensor lub tensor zagregowany z tensorów | (C1-C2) |
(I3) | comparison_direction |
enum EQ , NE , GE , GT , LE i LT |
|
(I4) | compare_type |
enum FLOAT , TOTALORDER , SIGNED i UNSIGNED |
(C3) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu logicznego | (K2) |
Ograniczenia
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
jest zdefiniowany jako:SIGNED
jeżeliis_signed_integer(element_type(lhs))
.UNSIGNED
jeżeliis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
lubTOTALORDER
, jeśliis_float(element_type(lhs))
.FLOAT
jeżeliis_complex(element_type(lhs))
.
Przykłady
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
złożone
Semantyka
Przeprowadza konwersję element po elemencie na wartość zespoloną z pary wartości rzeczywistych i urojonych, lhs
i rhs
, oraz generuje tensor result
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu f32 lub f64 |
(C1-C3) |
(I2) | rhs |
tensor typu f32 lub f64 |
(C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zespolonego | (C2), (C3) |
Ograniczenia
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
ma typcomplex<E>
, gdzieE = element_type(lhs)
.
Przykłady
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
wieloskładnikowa
Semantyka
Zawiera operację złożoną z innych operacji StableHLO, przyjmując inputs
i composite_attributes
oraz zwraca results
. Semantyka operacji jest implementowana przez atrybut decomposition
. Opcję composite
można zastąpić jej rozkładem bez zmiany semantyki programu. Jeśli wstawienie dekompozycji w kod źródłowy nie zapewnia tej samej semantyki op, użyj custom_call
.
Pole version
(domyślnie 0
) służy do informowania o zmianie semantyki elementu złożonego.
Dane wejściowe
Etykieta | Nazwa | Typ |
---|---|---|
(I1) | inputs |
zmienna liczba wartości |
(I2) | name |
stała typu string |
(I3) | composite_attributes |
słownik atrybutów |
(I4) | decomposition |
stała typu string |
(I5) | version |
stała typu si32 |
Wyniki
Nazwa | Typ |
---|---|
results |
zmienna liczba wartości |
Ograniczenia
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Przykłady
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
konkatenacja
Semantyka
Łączy elementy inputs
wzdłuż wymiaru dimension
w takim samym porządku jak podane argumenty i tworzy tensor result
. Bardziej formalnie:
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, gdzie:
id = d0 + ... + dk-1 + kd
.d
jest równydimension
, ad0
, ... tod
rozmiaryinputs
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (K1–C6) |
(I2) | dimension |
stała typu si64 |
(C2), (C4), (C6) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C5–C6) |
Ograniczenia
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
z wyjątkiemdim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
oprócz:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Przykłady
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
stała
Semantyka
Tworzy tensor output
na podstawie stałej value
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | value |
stała | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
output |
tensor lub kwantyzowany tensor | (C1) |
Ograniczenia
- (C1)
type(value) = type(output)
.
Przykłady
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
dokonają konwersji
Semantyka
Przeprowadza konwersję elementów z jednego typu na inny w tensorze operand
i tworzy tensor result
.
W przypadku konwersji boolean-to-any-supported-type wartość false
jest zamieniana na 0, a wartość true
na 1. W przypadku konwersji typu any-supported-type-to-boolean wartość 0 jest konwertowana na false
, a wartości inne niż zero – na true
. Poniżej znajdziesz informacje o tym, jak to działa w przypadku typów złożonych.
W przypadku konwersji zawierających liczby całkowite na liczbę całkowitą, liczbę zmiennoprzecinkową lub zmiennoprzecinkową na liczbę zmiennoprzecinkową wartość źródłową może być dokładnie reprezentowana w typie miejsca docelowego. W przeciwnym razie zachowanie jest nieokreślone (#180).
W przypadku konwersji zawierających floating-point-to-integer część ułamkowa jest skracana. Jeśli skrócona wartość nie może być przedstawiona w przypadku typu miejsca docelowego, zachowanie jest nieokreślone (#180).
Konwersja z typu zespolonego na typ zespolony działa tak samo jak konwersja z typu zmiennoprzecinkowego na typ zmiennoprzecinkowy w przypadku konwersji części rzeczywistej i części urojonej.
W przypadku konwersji z typu złożonego na dowolny inny typ i z dowolnego innego typu na złożony wartość wyobrażona źródła jest ignorowana lub docelowa wartość wyobrażona jest ustawiana na 0. Konwersja części rzeczywistej następuje zgodnie z konwersją na liczby zmiennoprzecinkowe.
Zasadniczo ta operacja może wyrażać dekwantyzację (przekształcanie tensorów regularnych w tensory regularne), kwantyzację (konwersja z tensorów regularnych na tensory kwantyzowane) i rekwantyzację (przekształcanie między kwantyzowanymi procesorami), ale obecnie mamy dla nich specjalne operacje – uniform_dequantize
dla pierwszego przypadku użycia i uniform_quantize
w drugim i trzecim przypadku użycia. W przyszłości te 2 operacje mogą zostać połączone w convert
(#1576).
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor | (C1) |
Ograniczenia
- (C1)
shape(operand) = shape(result)
.
Przykłady
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
splotu
Semantyka
Oblicza iloczyny skalarne między oknami lhs
i wycinkami rhs
oraz generuje result
. Na diagramie poniżej widać, jak elementy w elementach result
są obliczane na podstawie elementów lhs
i rhs
.
Bardziej formalnie rozważ następujące zmiany w danych wejściowych w postaci elementu lhs
, aby można było tworzyć przedziały czasu lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Ta zmiana kadrowania wykorzystuje te funkcje pomocnicze:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
gdziej[d] = i[permutation[d]]
.
Jeśli feature_group_count = 1
i batch_group_count = 1
, to dla wszystkich output_spatial_index
w index_space(dim(result, output_spatial_dimensions...))
:result[result_shape(:, output_spatial_index, :)] = dot_product
gdzie:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Ta funkcja wydaje się nieużywana, więc planujemy ją usunąć w przyszłości (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Jeśli feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Jeśli batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
W przypadku typów kwantowych wybiera dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
W przypadku typów hybrydowych kwantyzowana wartość wylicza hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor lub tensor zagregowany z tensorów | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
(I2) | rhs |
tensor lub kwantyzowany tensor | (C1), (C14–C16), (C25), (C27–C29), (C31–C34) |
(I3) | window_strides |
Jednowymiarowa stała tensora typu si64 |
(C2-C3), (C25) |
(I4) | padding |
2-wymiarowa stała tensora typu si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Jednowymiarowa stała tensora typu si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
1-wymiarowa stała tensora typu si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
1-wymiarowa stała tensora typu i1 |
(C9) |
(I8) | input_batch_dimension |
stała typu si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
stała typu si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
stała typu si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
stała typu si64 |
(C15–C16), (C18), (C25), (C29) |
(I13) | kernel_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
stała typu si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
stała typu si64 |
(C20), (C25), (C30) |
(I16) | output_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
stała typu si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
stała typu si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
zmienna liczba typów enumeracji DEFAULT , HIGH i HIGHEST |
(C24) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor (tensor kwantowy) | (C25-C28), (C30), (C32-34) |
Ograniczenia
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Zgodnie z
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Podany
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Zgodnie z
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
jest zdefiniowany jako:dim(lhs, input_batch_dimension) / batch_group_count
jeżeliresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
jeżeliresult_dim = output_feature_dimension
.num_windows
w przeciwnym razie, gdzie:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Jeśli operacja używa tensorów niekwantowych:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Jeśli operacja używa kwantowanych tensorów:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Jeśli
is_per_axis_quantized(rhs)
, toquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Jeśli
is_per_axis_quantized(result)
, toquantization_dimension(result) = output_feature_dimension
. - Jeśli
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Jeśli
is_per_tensor_quantized(rhs)
, tois_per_tensor_quantized(result)
. - Jeśli
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Przykłady
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
cosinus
Semantyka
Wykonuje elementarną operację cosinusa na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
cos
z IEEE-754. - W przypadku liczb zespolonych: cosinus zespolony.
- W przypadku typów kwantowych:
dequantize_op_quantize(cosine, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semantyka
Wykonuje element po elemencie zliczanie liczby początkowych zer w tensorze operand
i tworzy tensor result
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu liczba całkowita | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(operand) = type(result)
.
Przykłady
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semantyka
Zawiera zdefiniowaną przez implementację operację call_target_name
, która przyjmuje argumenty inputs
i called_computations
oraz zwraca results
. Atrybuty has_side_effect
, backend_config
i api_version
mogą służyć do udostępniania dodatkowych metadanych zdefiniowanych przez implementację.
Obecnie ta operacja zawiera dość nieuporządkowaną kolekcję metadanych, która odzwierciedla organiczną ewolucję jej odpowiednika w kompilatorze XLA. W przyszłości planujemy ujednolicić te metadane (#741).
Dane wejściowe
Etykieta | Nazwa | Typ |
---|---|---|
(I1) | inputs |
zmienna liczba wartości |
(I2) | call_target_name |
stała typu string |
(I3) | has_side_effect |
stała typu i1 |
(I4) | backend_config |
stała typu string lub słownik atrybutów |
(I5) | api_version |
stała typu si32 |
(I6) | called_computations |
liczba zmiennoprzecinkowa typu string |
Wyniki
Nazwa | Typ |
---|---|
results |
liczba zmiennoprzecinkowa |
Przykłady
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
dzielenie
Semantyka
Wykonuje element po elemencie dzielenie tensorów dzielnika lhs
i dzielenia rhs
oraz zwraca tensor result
. W zależności od typu elementu:
- W przypadku liczb całkowitych: dzielenie liczb całkowitych, które zwraca iloraz algebraiczny z pominięciem części ułamkowej.
- Dla liczb zmiennoprzecinkowych:
division
z IEEE-754. - W przypadku liczb zespolonych: dzielenie zespolone.
- W przypadku typów skwantowanych:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
(I2) | rhs |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Przykłady
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semantyka
Oblicza iloczyn skalarny między wycinkami lhs
i rhs
, uzyskując tensor result
.
Bardziej formalnie: result[result_index] = dot_product
, gdzie:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
gdziesize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
isize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
W przypadku typów skwantowanych wykonuje operację dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
W przypadku hybrydowych typów kwantyzacji wykonuje działanie hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
kontroluje kompromis między szybkością a dokładnością obliczeń na backendzie akceleratora. Może to być jedna z tych wartości (obecnie semantyka tych wartości jest niewystarczająco sprecyzowana, ale planujemy rozwiązać ten problem w bładze zgłoszenia 755):
DEFAULT
: najszybsze obliczenia, ale najmniej dokładne przybliżenie do pierwotnego wyniku.HIGH
: wolniejsze obliczenia, ale dokładniejsze przybliżenie do pierwotnej wartości.HIGHEST
: najwolniejsze obliczenia, ale najbardziej dokładne przybliżenie do pierwotnej wartości.
DotAlgorithm
definiuje główne właściwości algorytmu używanego do implementacji operacji kropki, która określa też dokładność. Jeśli pola atrybutów algorytmu są ustawione, precision_config
musi mieć wartość DEFAULT
. DotAlgorithms
nie mają wartości domyślnej, ponieważ parametry domyślne są definiowane przez implementację. Dlatego wszystkie pola algorytmu kropki mogą być ustawione na None
, aby określić pusty algorytm kropki, który zamiast tego użyje wartości precision_config
.
Pola DotAlgorithm
:
lhs_precision_type
irhs_precision_type
, dokładność, do której zaokrąglane są wartości po lewej i po prawej stronie operacji. Typy dokładności są niezależne od typów magazynowania danych wejściowych i wyjściowych.accumulation_type
dokładność użyta do skumulowania.lhs_component_count
,rhs_component_count
inum_primitive_operations
stosuje się, gdy używamy algorytmu, który rozkłada lewą lub prawą stronę na kilka komponentów i wykonuje na tych wartościach wiele „prostych” operacji dot – zwykle w celu emulacji większej precyzji (np. Korzystanie z typu danych bfloat16 do obliczeń o większej precyzji: bf16_6x tf32_3x itp.). W przypadku algorytmów bez dekompozycji te wartości powinny wynosić1
.allow_imprecise_accumulation
, aby określić, czy akumulacja z mniejszą dokładnością jest dozwolona w przypadku niektórych kroków (np.CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Przykładowe atrybuty DotAlgorithm
:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
To od implementacji zależy, które kombinacje są obsługiwane. Ogólnie nie ma gwarancji, że każdy algorytm jest obsługiwany przez każdy typ akceleratora przez użytkownika StableHLO. Jeśli dany algorytm nie jest obsługiwany, należy zgłosić błąd zamiast korzystać z alternatywnego rozwiązania. Weryfikacja StableHLO zapewni najlepszą weryfikację, zapobiegając działaniu algorytmów, które nie są obsługiwane na żadnym sprzęcie.
Niektóre obsługiwane wartości algorytmu znajdziesz w sekcji xla_data.proto > Algorithm
. zgłoszenie #2483 zawiera plan stworzenia scentralizowanego dokumentu na temat obsługiwanych algorytmów na zapleczu.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor lub tensor zagregowany z tensorów | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
tensor lub kwantyzowany tensor | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
1-wymiarowa stała tensora typu si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
1-wymiarowa stała tensora typu si64 |
(C1), (C4), (C7), (C9) |
(I5) | lhs_contracting_dimensions |
1-wymiarowa stała tensora typu si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
zmienna liczba typów enumeracji DEFAULT , HIGH i HIGHEST |
(C11), (C21) |
(I8) | lhs_precision_type |
FloatType lub TensorFloat32 | (C21) |
(I9) | rhs_precision_type |
FloatType lub TensorFloat32 | (C21) |
(I10) | accumulation_type |
FloatType lub TensorFloat32 | (C21) |
(I11) | lhs_component_count |
stała typu si32 |
(C21), (C22) |
(I12) | rhs_component_count |
stała typu si32 |
(C21), (C23) |
(I13) | num_primitive_operations |
stała typu si32 |
(C21), (C24) |
(I14) | allow_imprecise_accumulation |
stała typu bool |
(C21) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub kwantyzowany tensor | (C12), (C14), (C18–C20) |
Ograniczenia
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Jeśli operacja używa niespakowanych tensorów:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Jeśli operacja używa kwantowanych tensorów:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Jeśli
is_per_axis_quantized(rhs)
, toquantization_dimension(rhs)
nie jest wrhs_contracting_dimensions
. - Jeśli
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Jeśli
is_per_tensor_quantized(rhs)
, tois_per_tensor_quantized(result)
. - Jeśli
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Jeśli
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Przykłady
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z operacją broadcast_in_dim, ale kształt wyniku jest określany dynamicznie za pomocą parametru output_dimensions
.
Operacja akceptuje też opcjonalne atrybuty known_expanding_dimensions
i known_nonexpanding_dimensions
, które służą do wyrażania statycznej wiedzy o zachowaniu rozszerzania wymiarów.
Jeśli nie są określone, przyjmuje się, że wszystkie wymiary mogą się rozszerzać.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor (tensor kwantowy) | (C1–C2), (C5–C6), (C9) |
(I2) | output_dimensions |
Jednowymiarowy tensor typu liczby całkowitej | (C7) |
(I3) | broadcast_dimensions |
1-wymiarowy tensor stały typu całkowitego | (C2-C6) |
(I4) | known_expanding_dimensions |
Jednowymiarowy tensor stały typu liczby całkowitej | (C8-C9) |
(I5) | known_nonexpanding_dimensions |
1-wymiarowy tensor stały typu całkowitego | (C8–C9) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub kwantyzowany tensor | (C1), (C3), (C5-C7) |
Ograniczenia
- (C1)
element_type(result)
otrzymuje:element_type(operand)
, jeśli!is_per_axis_quantized(operand)
.element_type(operand)
, z tym żequantization_dimension(operand)
,scales(operand)
izero_points(operand)
mogą się różnić odquantization_dimension(result)
,scales(result)
izero_points(result)
odpowiednio.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) W przypadku wszystkich
d
waxes(operand)
:dim(operand, d) = 1
lubdim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Jeśli
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Jeśli
dim(operand, quantization_dimension(operand)) = 1
, toscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
.
Przykłady
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z operacją convolution, ale wypełnienie jest podawane dynamicznie za pomocą padding
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor lub tensor zagregowany z tensorów | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
(I2) | rhs |
tensor lub kwantyzowany tensor | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
2-wymiarowy tensor typu całkowitego | (C4) |
(I4) | window_strides |
Jednowymiarowa stała tensora typu si64 |
(C2-C3) |
(I5) | lhs_dilation |
Jednowymiarowa stała tensora typu si64 |
(C5-C6) |
(I6) | rhs_dilation |
1-wymiarowa stała tensora typu si64 |
(C7-C8) |
(I7) | window_reversal |
1-wymiarowa stała tensora typu i1 |
(C9) |
(I8) | input_batch_dimension |
stała typu si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
stała typu si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C12), (C13) |
(I11) | kernel_input_feature_dimension |
stała typu si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
stała typu si64 |
(C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C17-C18) |
(I14) | output_batch_dimension |
stała typu si64 |
(C20) |
(I15) | output_feature_dimension |
stała typu si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
1-wymiarowa stała tensora typu si64 |
(C19-C20) |
(I17) | feature_group_count |
stała typu si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
stała typu si64 |
(C10), (C15), (C22), (C23) |
(I19) | precision_config |
zmienna liczba typów enumeracji DEFAULT , HIGH i HIGHEST |
(C24) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub kwantyzowany tensor | (C25-C27), (C29), (C31-C33) |
Ograniczenia
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Zgodnie z
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Podany
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Zgodnie z
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
jest zdefiniowany jako:dim(lhs, input_batch_dimension) / batch_group_count
jeżeliresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
jeżeliresult_dim = output_feature_dimension
.num_windows
w przeciwnym razie, gdzie:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Jeśli operacja używa tensorów niekwantowych:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Jeśli operacja używa kwantowanych tensorów:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Jeśli
is_per_axis_quantized(rhs)
, toquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Jeśli
is_per_axis_quantized(result)
, toquantization_dimension(result) = output_feature_dimension
. - Jeśli
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Jeśli
is_per_tensor_quantized(rhs)
, tois_per_tensor_quantized(result)
. - Jeśli
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Przykłady
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z gather
op, z tym że slice_sizes
jest dynamicznie określany jako wartość.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tensor typu liczba całkowita | (C2), (C3), (C13) |
(I3) | slice_sizes |
1-wymiarowy tensor typu całkowitego | (C8), (C11–C13) |
(I4) | offset_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C4–C5), (C13) |
(I5) | collapsed_slice_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
1-wymiarowa stała tensora typu si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
stała typu si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
stała typu i1 |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C5), (C13–C14) |
Ograniczenia
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
gdzie:batch_dim_sizes = shape(start_indices)
, z tym że nie uwzględnia wymiarustart_indices
odpowiadającego wymiarowiindex_vector_dim
.offset_dim_sizes = shape(slice_sizes)
, z tym że nie uwzględnia wymiarówslice_sizes
odpowiadających wymiarowicollapsed_slice_dims
.- Funkcja
combine
umieszczabatch_dim_sizes
na osi odpowiadającejbatch_dims
, a funkcjaoffset_dim_sizes
– na osi odpowiadającejoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Przykłady
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z operacją iota, ale kształt wyniku jest dynamicznie określany za pomocą parametru output_shape
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | output_shape |
1-wymiarowy tensor typu całkowitego | (C1), (C2) |
(I2) | iota_dimension |
si64 |
(C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C2) |
Ograniczenia
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Przykłady
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semantyka
Ta operacja jest funkcjonalnie identyczna z pad, ale z wartościami edge_padding_low
, edge_padding_high
i interior_padding
określonymi dynamicznie.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor 0-wymiarowy lub kwantowy tensor na tensor | (C1) |
(I3) | edge_padding_low |
Jednowymiarowy tensor typu liczby całkowitej | (C1), (C4) |
(I4) | edge_padding_high |
1-wymiarowy tensor typu całkowitego | (C1), (C4) |
(I5) | interior_padding |
Jednowymiarowy tensor typu liczby całkowitej | (C2-C4) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C3-C6) |
Ograniczenia
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Przykłady
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semantyka
Ta operacja jest pod względem funkcjonalnym identyczna z operacją reshape, ale kształt wyniku jest dynamicznie określany za pomocą argumentu output_shape
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub kwantyzowany tensor | (K1–C3) |
(I2) | output_shape |
1-wymiarowy tensor typu całkowitego | (C4) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub kwantyzowany tensor | (C1–C4) |
Ograniczenia
- (C1)
element_type(result)
otrzymuje:element_type(operand)
, jeśli!is_per_axis_quantized(operand)
.element_type(operand)
– oprócz tychquantization_dimension(operand)
iquantization_dimension(result)
mogą się różnić.
- (C2)
size(operand) = size(result)
. - (C3) Jeśli
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
.
Przykłady
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semantyka
Wyodrębnia wycinek z operand
, używając dynamicznie obliczanych indeksów początkowych i tworzy tensor result
. start_indices
zawiera indeksy początkowe przekroju dla każdego wymiaru, który może zostać dostosowany, a slice_sizes
zawiera rozmiary przekroju dla każdego wymiaru. Bardziej oficjalnie:
result[result_index] = operand[operand_index]
, gdzie:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1), (C2), (C4) |
(I2) | start_indices |
zmienna liczba 0-wymiarowych tensorów typu całkowitego | (C2), (C3) |
(I3) | slice_sizes |
Jednowymiarowa stała tensora typu si64 |
(C2), (C4), (C5) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1), (C5) |
Ograniczenia
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Przykłady
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semantyka
Tworzy tensor result
, który jest równy tensorowi operand
, ale wycinek zaczynający się od start_indices
jest aktualizowany o wartości w update
.
Bardziej formalnie result[result_index]
jest zdefiniowana jako:
update[update_index]
jeśli0 <= update_index < shape(update)
gdzie:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- W przeciwnym razie:
operand[result_index]
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1-C4), (C6) |
(I2) | update |
tensor lub tensor zagregowany z tensorów | (C2), (C3), (C6) |
(I3) | start_indices |
zmienna liczba 0-wymiarowych tensorów typu całkowitego | (C4), (C5) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Przykłady
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
wykładniczo
Semantyka
Wykonuje elementarną operację wykładniczą na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
exp
z IEEE-754. - Liczby zespolone: wykładnik zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(exponential, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semantyka
Wykonuje elementarną operację wykładniczą minus 1 na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
expm1
z IEEE-754. - W przypadku liczb zespolonych: zespolona wykładnicza wartość minus 1.
- W przypadku typów skwantowanych:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
FFT
Semantyka
Wykonuje bezpośrednie i odwrotne transformacje Fouriera w przypadku rzeczywistych i zespolonych wejść/wyjść.
fft_type
może mieć jedną z tych wartości:
FFT
: FFT kompleksowe w przód.IFFT
: odwrotna transformacja FFT z kompleksowej na złożoną.RFFT
: przesuwanie widoku w kierunku rzeczywistym do złożonego.IRFFT
: odwrotna transformacja Fouriera z realnego na zespolony (czyli przyjmuje zespolony, zwraca rzeczywisty).
Bardziej formalnie, jeśli dana funkcja fft
przyjmuje jako dane wejściowe 1-wymiarowe tensory złożonych typów, to zwraca 1-wymiarowe tensory tego samego typu co dane wyjściowe i wylicza dyskretną transformację Fouriera:
W przypadku fft_type = FFT
wartość result
jest zdefiniowana jako końcowy wynik serii obliczeń L, gdzie L = size(fft_length)
. Na przykład w przypadku L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Ponadto funkcja ifft
, która ma ten sam podpis typu i oblicza odwrotność fft
:
W przypadku fft_type = IFFT
wartość result
jest zdefiniowana jako odwrotna wartość obliczeń dla fft_type = FFT
. Na przykład w przypadku L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Funkcja rfft
przyjmuje 1-wymiarowe tensory typu zmiennoprzecinkowego i tworzy 1-wymiarowe tensory typu zespolonego o tej samej semantyce zmiennoprzecinkowej. Działa ona w ten sposób:
rfft(real_operand) = truncated_result
gdziecomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(gdy dyskretna transformacja Fouriera jest obliczana dla rzeczywistych operandów, pierwsze
N/2 + 1
elementy wyniku jednoznacznie określają pozostałą część wyniku,
dlatego wynik rfft
jest obcinany, aby uniknąć obliczania zbędących elementów).
W przypadku funkcji fft_type = RFFT
wartość result
jest definiowana jako końcowy wynik serii obliczeń L, gdzie L = size(fft_length)
. Na przykład w przypadku L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Na koniec, jeśli dana funkcja irfft
ma tę samą deklarację typu i oblicza odwrotność funkcji rfft
:
W przypadku fft_type = IRFFT
wartość result
jest zdefiniowana jako odwrotna wartość obliczeń dla fft_type = RFFT
. Na przykład w przypadku L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
enum FFT , IFFT , RFFT i IRFFT |
(C2), (C5) |
(I3) | fft_length |
1-wymiarowa stała tensora typu si64 |
(C1), (C3), (C4) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego | (C2), (C4), (C5) |
Ograniczenia
- (C1)
size(fft_length) <= rank(operand)
. - (C2) Związek między typami elementów
operand
iresult
jest różny:- Jeśli
fft_type = FFT
,element_type(operand)
ielement_type(result)
mają ten sam typ złożony. - Jeśli
fft_type = IFFT
,element_type(operand)
ielement_type(result)
mają ten sam typ złożony. - Jeśli
fft_type = RFFT
,element_type(operand)
jest typem zmiennoprzecinkowym, aelement_type(result)
jest złożonym typem o tej samej semantyce zmiennoprzecinkowej. - Jeśli
fft_type = IRFFT
,element_type(operand)
jest typem złożonym, aelement_type(result)
jest typem zmiennoprzecinkowym o tej samej semantyce zmiennoprzecinkowej.
- Jeśli
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Jeśli wśród elementów
operand
iresult
znajduje się tensorreal
typu zmiennoprzecinkowego, toshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
z wyjątkiem:- Jeśli
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Jeśli
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Jeśli
Przykłady
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
piętro
Semantyka
Przeprowadza element po elemencie zaokrąglenie w dół tensora operand
i generuje tensor result
.
Realizuje operację roundToIntegralTowardNegative
według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(floor, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
zbierać
Semantyka
Gromadzi wycinki z tensora operand
z przesunięć określonych w start_indices
i tworzy tensor result
.
Ten diagram pokazuje na konkretnym przykładzie, jak elementy w result
są mapowane na elementy w operand
. Diagram wybiera kilka przykładowych result
indeksów i szczegółowo wyjaśnia, którym indeksom operand
odpowiadają indeksy.
Bardziej formalnie: result[result_index] = operand[operand_index]
, gdzie:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
jest zdefiniowany jako:start_indices[bi0, ..., :, ..., biN]
, gdziebi
to poszczególne elementy wbatch_index
, a element:
jest wstawiany w indeksieindex_vector_dim
, jeśliindex_vector_dim
<rank(start_indices)
.[start_indices[batch_index]]
w innych przypadkach.
- W przypadku
d_operand
waxes(operand)
:full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
ifd_operand = start_index_map[d_start]
.- W przeciwnym razie:
full_start_index[d_operand] = 0
.
- W przypadku
d_operand
waxes(operand)
:full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
jeślid_operand = operand_batching_dims[i_batching]
id_start = start_indices_batching_dims[i_batching]
.full_batching_index[d_operand] = 0
w innych przypadkach.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, gdzieoi
to poszczególne elementy w tablicyoffset_index
, a element0
jest wstawiany na pozycjach z tabliccollapsed_slice_dims
ioperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
.
Jeśli indices_are_sorted
ma wartość true
, implementacja może zakładać, że dane start_indices
są posortowane według argumentu start_index_map
. W przeciwnym razie działanie jest niezdefiniowane. Bardziej formalnie, w przypadku wszystkich i1 < i2
z indices(result)
:full_start_index(i1) <= full_start_index(i2)
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
tensor typu liczba całkowita | (C2–C3), (C14), (C17), (C22) |
(I3) | offset_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
1-wymiarowa stała tensora typu si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
1-wymiarowa stała tensora typu si64 |
(C13-C17) |
(I7) | start_index_map |
1-wymiarowa stała tensora typu si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
stała typu si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
1-wymiarowa stała tensora typu si64 |
(C9), (C12), (C20–C22) |
(I10) | indices_are_sorted |
stała typu i1 |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C5), (C22-C23) |
Ograniczenia
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
gdzie:batch_dim_sizes = shape(start_indices)
, z tym że nie uwzględnia wymiarustart_indices
odpowiadającego wymiarowiindex_vector_dim
.offset_dim_sizes = slice_sizes
, z tą różnicą, że rozmiary wymiarów w poluslice_sizes
odpowiadające wartościomcollapsed_slice_dims
ioperand_batching_dims
nie są uwzględniane.- Funkcja
combine
umieszczabatch_dim_sizes
na osiach odpowiadających wartościbatch_dims
ioffset_dim_sizes
na osiach odpowiadających wartościoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Przykłady
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semantyka
Zwraca rozmiar danego dimension
w ramach operand
. Bardziej formalnie:
result = dim(operand, dimension)
. Semantyka dotyczy tylko komponentu kształtu danego typu. Typ elementu może być dowolny.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub kwantyzowany tensor | (C1) |
(I2) | dimension |
stała typu si64 |
(C1) |
Wyniki
Nazwa | Typ |
---|---|
result |
Tensor 0-wymiarowy typu si32 |
Ograniczenia
- (C1)
0 <= dimension < rank(operand)
.
Przykłady
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semantyka
Wyodrębnia element w pozycji index
krotki operand
i tworzy result
. Więcej formalnie: result = operand[index]
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tablice | (C1), (C2) |
(I2) | index |
stała typu si32 |
(C1), (C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
dowolny obsługiwany typ | (C2) |
Ograniczenia
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Przykłady
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
jeśli
Semantyka
Wynikiem jest wynik wykonania dokładnie jednej funkcji z true_branch
lub false_branch
w zależności od wartości elementu pred
. W bardziej formalnym ujęciu: result =
pred ? true_branch() : false_branch()
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | pred |
0-wymiarowy tensor typu i1 |
|
(I2) | true_branch |
funkcja | (C1-C3) |
(I3) | false_branch |
funkcja | (C1), (C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C3) |
Ograniczenia
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Przykłady
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
Semantyka
Wyodrębnia część urojona z elementów tensora operand
i tworzy tensor result
. W formalnej postaci dla każdego elementu x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego | (C1), (C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego | (C1), (C2) |
Ograniczenia
- (C1)
shape(result) = shape(operand)
. - (C2) Parametr
element_type(result)
jest zdefiniowany jako:complex_element_type(element_type(operand))
jeżeliis_complex(operand)
.- W przeciwnym razie:
element_type(operand)
.
Przykłady
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
infeed
Semantyka
Odczytuje dane z infeedu i generuje results
.
Semantyka elementu infeed_config
jest zdefiniowana przez implementację.
results
składa się z wartości ładunku, które występują na początku, oraz tokena, który występuje na końcu. W przyszłości planujemy podzielić ładunek i token na 2 osobne dane wyjściowe, aby zwiększyć przejrzystość (#670).
Dane wejściowe
Etykieta | Nazwa | Typ |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
stała typu string |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C1-C3) |
Ograniczenia
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
lubis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Przykłady
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semantyka
Wypełnia tensor output
wartościami w rosnącej kolejności, zaczynając od 0 w wymiarze iota_dimension
. Bardziej formalnie,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
output |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
0 <= iota_dimension < rank(output)
.
Przykłady
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semantyka
Sprawdza, czy wartość w x
jest skończona (tzn.nie jest skończona (tzn. nie jest liczbą +Inf, -Inf ani NaN) i generuje tensor y
. Realizuje operację isFinite
z specyfikacji IEEE-754. W przypadku typów kwantowych wynik to zawsze true
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | x |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
y |
tensor typu logicznego | (C1) |
Ograniczenia
- (C1)
shape(x) = shape(y)
.
Przykłady
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semantyka
Wykonuje operację logarytmiczną z uwzględnieniem elementów na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
log
z IEEE-754. - W przypadku liczb zespolonych: logarytm zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(log, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semantyka
Wykonuje logarytm z elementami oraz 1 operację na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
logp1
z IEEE-754. - W przypadku liczb zespolonych: logarytm zespolony plus 1.
- W przypadku typów skwantowanych:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logistyczna
Semantyka
Wykonuje operacje logistyczne związane z elementami na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
division(1, addition(1, exp(-x)))
z IEEE-754. - W przypadku liczb zespolonych: złożona logistyczna.
- W przypadku typów skwantowanych:
dequantize_op_quantize(logistic, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mapa
Semantyka
Stosuje funkcję mapy computation
do inputs
na dimensions
i tworzy tensor result
.
Więcej formalnie: result[result_index] = computation(inputs...[result_index])
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1–C4) |
(I2) | dimensions |
1-wymiarowa stała tensora typu si64 |
(K3) |
(I3) | computation |
funkcja | (K4) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1), (C4) |
Ograniczenia
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
ma typ(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, gdzieEi = element_type(inputs[i])
iE' = element_type(result)
.
Przykłady
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
maksimum
Semantyka
Wykonuje operację elementarnego maksimum na tensorach lhs
i rhs
oraz zwraca tensor result
. W zależności od typu elementu wykonuje te działania:
- W przypadku wartości logicznych: logiczny LUB.
- W przypadku liczb całkowitych: maksymalna liczba całkowita.
- Dla liczb zmiennoprzecinkowych:
maximum
z IEEE-754. - W przypadku liczb zespolonych: maksymalna wartość leksykograficzna pary
(real, imaginary)
. Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560). - W przypadku typów skwantowanych:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
kwantowy tensor lub tensor kwantowy | (C1) |
(I2) | rhs |
tensor lub tensor zagregowany z tensorów | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Przykłady
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
minimum
Semantyka
Wykonuje elementarną operację min na tensorach lhs
i rhs
oraz zwraca tensor result
. W zależności od typu elementu wykonuje te działania:
- W przypadku wartości logicznych: operator logiczny OR.
- W przypadku liczb całkowitych: minimalna liczba całkowita.
- Dla liczb zmiennoprzecinkowych:
minimum
z IEEE-754. - W przypadku liczb zespolonych: leksykograficzne minimum dla pary
(real, imaginary)
. Narzucanie kolejności liczbom zespolonym wiąże się z nieoczywistą semantyką, dlatego w przyszłości planujemy usunąć obsługę liczb zespolonych w przypadku tej operacji (#560). - W przypadku typów kwantowych:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
kwantowy tensor lub tensor kwantowy | (C1) |
(I2) | rhs |
tensor lub tensor zagregowany z tensorów | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Przykłady
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
pomnóż
Semantyka
Wykonuje element po elemencie iloczyn dwóch tensorów lhs
i rhs
oraz generuje tensor result
. W zależności od typu elementu:
- W przypadku wartości logicznych: logiczne I.
- W przypadku liczb całkowitych: mnożenie liczb całkowitych.
- W przypadku jednostek zmiennoprzecinkowych:
multiplication
w standardzie IEEE-754. - Liczby zespolone: mnożenie zespolone.
- W przypadku typów skwantowanych:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
kwantowy tensor lub tensor kwantowy | (C1) |
(I2) | rhs |
tensor lub tensor zagregowany z tensorów | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negacja
Semantyka
Wykonuje negację tensora operand
z uwzględnieniem elementów i generuje tensor result
. W zależności od typu elementu:
- W przypadku liczb całkowitych ze znakiem: negacja liczby całkowitej.
- W przypadku liczb bez znaku: bitowy zapis jako liczba ze znakiem, negacja liczby, bitowy zapis z powrotem jako liczba bez znaku.
- Dla liczb zmiennoprzecinkowych:
negate
z IEEE-754. - Liczby zespolone: negacja zespolona.
- W przypadku typów kwantowych:
dequantize_op_quantize(negate, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
nie
Semantyka
Przeprowadza elementarną operację NOT na tensorze operand
i tworzy tensor result
.
W zależności od typu elementu:
- W przypadku wartości logicznych: NOT logiczna.
- W przypadku liczb całkowitych: bitowa operacja NOT.
Argumenty
Nazwa | Typ | Ograniczenia |
---|---|---|
operand |
tensor typu logicznego lub całkowitego | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu logicznego lub całkowitego | (C1) |
Ograniczenia
- (C1)
type(operand) = type(result)
.
Przykłady
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semantyka
Zapewnienie, że operacje, które generują operand
, są wykonywane przed operacjami, które zależą od result
, oraz zapobieganie przemieszczaniu operacji przez barierę przez przekształcenia kompilatora. W przeciwnym razie operacja jest
tożsamością, czyli result = operand
.
Argumenty
Nazwa | Typ | Ograniczenia |
---|---|---|
operand |
zmienna liczba tensorów, tensorów skwantowanych na podstawie tensora lub tokenów | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
zmienna liczba tensorów, tensorów skwantowanych na podstawie tensora lub tokenów | (C1) |
Ograniczenia
- (C1)
type(operand...) = type(result...)
.
Przykłady
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
lub
Semantyka
Wykonuje elementarną operację OR na dwóch tensorach lhs
i rhs
, tworząc tensor result
. W zależności od typu elementu:
- W przypadku wartości logicznych: operator logiczny LUB.
- W przypadku liczb całkowitych: operator bitowy LUB.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu liczba całkowita lub logiczna | (C1) |
(I2) | rhs |
tensor typu liczba całkowita lub logiczna | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu liczba całkowita lub logiczna | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result)
.
Przykłady
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
outfeed
Semantyka
Zapisuje inputs
w outfiedzie i generuje token result
.
Semantyka elementu outfeed_config
jest zdefiniowana przez implementację.
Dane wejściowe
Etykieta | Nazwa | Typ |
---|---|---|
(I1) | inputs |
zmienna liczba tensorów lub zagęszczonych tensorów |
(I2) | token |
token |
(I3) | outfeed_config |
stała typu string |
Wyniki
Nazwa | Typ |
---|---|
result |
token |
Przykłady
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
pad
Semantyka
Rozszerza operand
przez wypełnienie przestrzeni wokół tensora oraz między elementami tensora za pomocą podanego padding_value
.
Parametry edge_padding_low
i edge_padding_high
określają ilość wypełnień dodawanych na niskich wartościach (obok indeksu 0) i na wysokich wartościach (obok najwyższego indeksu) w każdym wymiarze. Ilość wypełnienia może być ujemna, a wartość bezwzględna ujemnego wypełnienia wskazuje liczbę elementów do usunięcia z wybranego wymiaru.
interior_padding
określa dopełnienie między dwoma dowolnymi elementami w każdym wymiarze, które może nie być ujemne. Wypełnienie wewnętrzne występuje przed wypełnieniem krawędzi, dzięki czemu wypełnienie krawędzi o ujemnej wartości spowoduje usunięcie elementów z operanda z wypełnieniem wewnętrznym.
Bardziej formalnie result[result_index]
jest zdefiniowana jako:
operand[operand_index]
jeśliresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.- W przeciwnym razie:
padding_value
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor 0-wymiarowy lub kwantowy tensor na tensor | (C1) |
(I3) | edge_padding_low |
1-wymiarowa stała tensora typu si64 |
(C1), (C4) |
(I4) | edge_padding_high |
1-wymiarowa stała tensora typu si64 |
(C1), (C4) |
(I5) | interior_padding |
1-wymiarowa stała tensora typu si64 |
(C2-C4) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C3-C6) |
Ograniczenia
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Przykłady
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semantyka
Tworzy partition_id
bieżącego procesu.
Wyniki
Nazwa | Typ |
---|---|
result |
0-wymiarowy tensor typu ui32 |
Przykłady
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Semantyka
Wykonuje za pomocą elementu liczby bitów ustawione w tensorze operand
i tworzy tensor result
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu liczba całkowita | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(operand) = type(result)
.
Przykłady
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
moc
Semantyka
Wykonuje elementową potęgowanie tensora lhs
przez tensor rhs
i tworzy tensor result
. W zależności od typu elementu:
- W przypadku liczb całkowitych: wykładnik całkowity.
- Dla liczb zmiennoprzecinkowych:
pow
z IEEE-754. - W przypadku liczb zespolonych: wykładnik zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
(I2) | rhs |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora według tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semantyka
Wyodrębnia rzeczywistą część z operand
element po elemencie i tworzy tensor result
. W formalnej postaci dla każdego elementu x
:
real(x) = is_complex(x) ? real_part(x) : x
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego | (C1), (C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego | (C1), (C2) |
Ograniczenia
- (C1)
shape(result) = shape(operand)
. - (C2) Parametr
element_type(result)
jest zdefiniowany jako:complex_element_type(element_type(operand))
jeżeliis_complex(operand)
.- W przeciwnym razie:
element_type(operand)
.
Przykłady
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semantyka
Pobiera dane z kanału z parametrem channel_id
i tworzy results
.
Jeśli is_host_transfer
to true
, operacja przenosi dane z hosta. W przeciwnym razie dane są przenoszone z innego urządzenia. Co to oznacza, zależy od implementacji. Ta flaga powiela informacje podane w flagach channel_type
, dlatego w przyszłości planujemy zachować tylko jedną z nich (#666).
results
składa się z wartości ładunku, które występują na początku, oraz tokena, który występuje na końcu. W przyszłości planujemy podzielić ładunek i token na 2 osobne dane wyjściowe, aby zwiększyć przejrzystość (#670).
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
stała typu si64 |
|
(I3) | channel_type |
enum DEVICE_TO_DEVICE i HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
stała typu i1 |
(C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C2–C4) |
Ograniczenia
- (C1)
channel_type
jest zdefiniowany jako:HOST_TO_DEVICE
jeżeliis_host_transfer = true
,DEVICE_TO_DEVICE
w innych przypadkach.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
lubis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Przykłady
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
zmniejszyć
Semantyka
Stosuje funkcję redukcji body
do inputs
i init_values
wzdłuż dimensions
i zwraca tensory results
.
Kolejność redukcji jest zdefiniowana w implementacji, co oznacza, że body
i init_values
muszą utworzyć monoid, aby zagwarantować, że operacja da takie same wyniki przy wszystkich danych wejściowych we wszystkich implementacjach. Jednak w przypadku wielu popularnych rabatów to założenie nie jest spełnione. Przykładowo dodawanie liczb zmiennoprzecinkowych w przypadku wartości body
i zera w przypadku wartości init_values
nie tworzy monoidu, ponieważ dodawanie liczb zmiennoprzecinkowych nie jest skojarzone.
Bardziej formalnie: results...[j0, ..., jR-1] = reduce(input_slices_converted)
, gdzie:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, gdzie:
są wstawiane w miejscudimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
dla niektórych drzew binarnychschedule
gdzie:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
to pełne drzewo binarne zdefiniowane przez implementację, którego przeglądanie w kolejności zgodnej z ich poziomem w drzewie obejmuje:- wartości
input_slices_converted...[index]
dla wszystkichindex
w tablicyindex_space(input_slices_converted)
w rosnącym porządku leksykograficznymindex
. - Wplecione z określoną przez implementację ilością znaków
init_values_converted
w określonych przez implementację pozycjach.
- wartości
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1-C4), (C6), (C7) |
(I2) | init_values |
zmienna liczba 0-wymiarowych tensorów lub tensorów skwantowanych na tensor | (C2), (C3) |
(I3) | dimensions |
1-wymiarowa stała tensora typu si64 |
(C4), (C5), (C7) |
(I4) | body |
funkcja | (C6) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C3), (C7), (C8) |
Ograniczenia
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
ma typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
gdzieis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, z tym że nie uwzględnia wymiaru rozmiary w elementachinputs...
odpowiadających elementomdimensions
. - (C8)
element_type(results[i]) = Ei
dla wszystkichi
w[0,N)
.
Przykłady
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semantyka
Przeprowadza konwersję elementów operand
na inny typ zmiennoprzecinkowy, który używa exponent_bits
i mantissa_bits
, a następnie z powrotem na pierwotny typ zmiennoprzecinkowy. Tworzy też tensor output
.
Bardziej formalnie:
- Bity mantissy pierwotnej wartości są aktualizowane w celu zaokrąglenia pierwotnej wartości do najbliższej wartości, którą można przedstawić za pomocą funkcji
mantissa_bits
przy użyciu semantykiroundToIntegralTiesToEven
. - Następnie, jeśli
mantissa_bits
jest mniejsza od pierwotnej wartości, bity modliszki są obcinane domantissa_bits
. - Następnie, jeśli bity wykładnika wyniku pośredniego nie mieszczą się w zakresie określonym przez
exponent_bits
, wynik pośredni jest przepełniony do nieskończoności z użyciem znaku pierwotnego lub jest podpięty do zera z użyciem znaku pierwotnego. - W przypadku typów skwantowanych wykonuje operację
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
(I2) | exponent_bits |
stała typu si32 |
(C2) |
(I3) | mantissa_bits |
stała typu si32 |
(K3) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
output |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Przykłady
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semantyka
W ramach każdej grupy procesów w siatce procesów StableHLO wykonuje redukcję (za pomocą funkcji computations
) wartości tensora operand
z każdego procesu, dzieli wynik redukcji wzdłuż osi scatter_dimension
na części, a potem rozprasza te części między procesami, aby wygenerować result
.
Operacja dzieli siatkę procesu StableHLO na process_groups
, która jest zdefiniowana w ten sposób:
cross_replica(replica_groups)
jeślichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
ifchannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
ifchannel_id > 0 and use_global_device_ids = true
.
Następnie w każdym wierszu process_group
:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
dla wszystkichsender
wprocess_group
, gdziereceiver_index = process_group.index(receiver)
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
stała typu si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
2-wymiarowa stała tensora typu si64 |
(C3-C5) |
(I4) | channel_id |
stała typu si64 |
(C6) |
(I5) | use_global_device_ids |
stała typu i1 |
(C6) |
(I6) | computation |
funkcja | (C7) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C8-C9) |
Ograniczenia
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
jest zdefiniowana jako:num_replicas
, jeśli używana jest właściwośćcross_replica
.num_replicas
, jeśli używana jest właściwośćcross_replica_and_partition
.num_processes
, jeśli używana jest właściwośćflattened_ids
.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Jeśli
use_global_device_ids = true
, tochannel_id > 0
. - (C7)
computation
ma typ(tensor<E>, tensor<E>) -> (tensor<E>)
, gdzieis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
z wyjątkiem:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Przykłady
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semantyka
Stosuje funkcję redukcji body
do okien inputs
i init_values
oraz zwraca wartość results
.
Ten diagram pokazuje, jak elementy w elementach results...
są obliczane na podstawie elementów inputs...
na konkretnym przykładzie.
Bardziej oficjalnie:
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(patrz reduce), gdzie:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | inputs |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C1–C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(I2) | init_values |
liczba zmiennoprzecinkowa tensorów 0-wymiarowych lub kwantyzowanych tensorów na tensor | (C1), (C13) |
(I3) | window_dimensions |
Jednowymiarowa stała tensora typu si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
1-wymiarowa stała tensora typu si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
1-wymiarowa stała tensora typu si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
1-wymiarowa stała tensora typu si64 |
(C10), (C11), (C15) |
(I7) | padding |
2-wymiarowa stała tensora typu si64 |
(C12), (C15) |
(I8) | body |
funkcja | (C13) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C1), (C14–C16) |
Ograniczenia
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
ma typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
gdzieis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
gdzie:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
dla wszystkichi
w[0,N)
.
Przykłady
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
reszta
Semantyka
Wykonuje element po elemencie resztę dzielnika lhs
i dzielonki rhs
tensorów oraz zwraca tensor result
.
Bardziej oficjalnie znak wyniku jest wyliczany z dywidendy, a wartość bezwzględna wyniku jest zawsze mniejsza od wartości bezwzględnej dzielnika.
Reszta jest obliczana według wzoru: lhs - d * rhs
, gdzie d
jest określona przez:
- W przypadku liczb całkowitych:
stablehlo.divide(lhs, rhs)
. - W przypadku liczb zmiennoprzecinkowych:
division(lhs, rhs)
z IEEE-754 z atrybutem zaokrągleniaroundTowardZero
. - W przypadku liczb zespolonych: do ustalenia (#997).
- W przypadku typów skwantowanych:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
W przypadku elementów zmiennoprzecinkowych ta operacja jest niezgodna z operacją remainder
ze specyfikacji IEEE-754, w której d
jest wartością całkowitą zbliżoną do dokładnej wartości parametru lhs/rhs
z opisem równomiernym.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
(I2) | rhs |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora na tensor | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semantyka
Tworzy replica_id
bieżącego procesu.
Wyniki
Nazwa | Typ |
---|---|
result |
0-wymiarowy tensor typu ui32 |
Przykłady
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
zmienić kształt
Semantyka
Przekształca tensor operand
w tensor result
. W zasadzie jest to zachowanie tego samego reprezentacji kanonicznej, ale potencjalnie zmiany kształtu, np. z tensor<2x3xf32>
na tensor<3x2xf32>
lub tensor<6xf32>
.
Dokładniej rzecz ujmując: result[result_index] = operand[operand_index]
, gdzie result_index
i operand_index
mają to samo miejsce w kolejności leksykograficznej index_space(result)
i index_space(operand)
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub kwantyzowany tensor | (C1-C3) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor (tensor kwantowy) | (C1-C3) |
Ograniczenia
- (C1)
element_type(result)
otrzymuje:element_type(operand)
, jeśli!is_per_axis_quantized(operand)
.element_type(operand)
– oprócz tychquantization_dimension(operand)
iquantization_dimension(result)
mogą się różnić.
- (C2)
size(operand) = size(result)
. - (C3) Jeśli
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Przykłady
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
odwróć
Semantyka
Odwraca kolejność elementów w operand
wzdłuż podanego wymiaru dimensions
i tworzy tensor result
. Bardziej formalnie:
result[result_index] = operand[operand_index]
gdzie:
operand_index[d] = dim(result, d) - result_index[d] - 1
jeżelid
wdimensions
.operand_index[d] = result_index[d]
w innych przypadkach.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1), (C3) |
(I2) | dimensions |
1-wymiarowa stała tensora typu si64 |
(C2), (C3) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1), (C3) |
Ograniczenia
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Przykłady
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semantyka
Generuje liczby losowe za pomocą algorytmu rng_distribution
i tworzy tensor result
o kształcie shape
.
Jeśli rng_distribution = UNIFORM
, liczby losowe są generowane zgodnie z rozkładem stałym w przedziałzie [a, b)
. Jeśli a >= b
, zachowanie jest nieokreślone.
Jeśli rng_distribution = NORMAL
, liczby losowe są generowane zgodnie z rozkładem normalnym, ze średnią = a
i odchyleniem standardowym = b
.
Jeśli b < 0
, zachowanie jest nieokreślone.
Dokładny sposób generowania liczb losowych jest zdefiniowany w implementacji. Mogą na przykład być deterministyczne, ale nie muszą.
W rozmowach z wielu interesariuszami okazało się, że ta opcja jest w istocie wycofana, dlatego w przyszłości planujemy jej usunięcie (#597).
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | a |
0-wymiarowy tensor typu liczby całkowitej, wartości logicznej lub zmiennoprzecinkowej | (C1), (C2) |
(I2) | b |
0-wymiarowy tensor typu liczby całkowitej, wartości logicznej lub zmiennoprzecinkowej | (C1), (C2) |
(I3) | shape |
1-wymiarowa stała tensora typu si64 |
(C3) |
(I4) | rng_distribution |
wyliczenie UNIFORM i NORMAL |
(C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu całkowita, logiczna lub zmiennoprzecinkowa | (C1-C3) |
Ograniczenia
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Jeśli
rng_distribution = NORMAL
, tois_float(a)
. - (C3)
shape(result) = shape
.
Przykłady
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semantyka
Zwraca pole output
wypełnione jednolitymi losowymi bitami i zaktualizowany stan wyjściowy output_state
za pomocą algorytmu generatora liczb pseudorandom rng_algorithm
z przypisanym stanem początkowym initial_state
. Wyjście jest zdefiniowane jako funkcja deterministyczna initial_state
, ale nie jest deterministyczna w różnych implementacjach.
rng_algorithm
jest jedną z tych:
DEFAULT
: algorytm zdefiniowany przez implementację.THREE_FRY
: wariant algorytmu Threefry zdefiniowany przez implementację.*PHILOX
: wariant algorytmu Philox zdefiniowany przez implementację.*
* Patrz: Salmon et al. SC 2011. Losowe liczby równoległe: to bardzo proste.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | rng_algorithm |
wyliczenie DEFAULT , THREE_FRY i PHILOX |
(C2) |
(I2) | initial_state |
Jednowymiarowy tensor typu ui64 |
(C1), (C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
output_state |
1-wymiarowy tensor typu ui64 |
(C1) |
output |
tensor typu liczba całkowita lub zmiennoprzecinkowa |
Ograniczenia
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
zdefiniowano jako:- zdefiniowaną przez implementację, jeśli
rng_algorithm = DEFAULT
. 2
jeżelirng_algorithm = THREE_FRY
.2
lub3
, jeślirng_algorithm = PHILOX
.
- zdefiniowaną przez implementację, jeśli
Przykłady
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semantyka
Wykonuje zaokrąglanie z poziomu elementu do najbliższej liczby całkowitej, zrywając remisy od zera na tensorze operand
i tworzy tensor result
. Realizuje operację roundToIntegralTiesToAway
zgodnie ze specyfikacją IEEE-754. W przypadku typów skompresowanych wykonuje działanie dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semantyka
Zaokrągla elementy tensora operand
do najbliższej liczby całkowitej, rozstrzygając remisy na korzyść parzystej liczby całkowitej, i tworzy tensor result
. Realizuje operację roundToIntegralTiesToEven
według specyfikacji IEEE-754. W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub tensor kwantyzowany na podstawie tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semantyka
Przeprowadza elementarną operację odwrotnego pierwiastka kwadratowego na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
rSqrt
z IEEE-754. - Liczby zespolone: odwrotny pierwiastek kwadratowy z liczby zespolonej.
- W przypadku typów kwantowych:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
rozproszenie
Semantyka
Wyświetla tensory results
, które są równe tensorom inputs
, ale kilka wycinków określonych przez scatter_indices
zostało zaktualizowanych wartościami updates
przy użyciu metody update_computation
.
Ten diagram pokazuje na konkretnym przykładzie, jak elementy w updates...
są mapowane na elementy w results...
. Diagram przedstawia kilka przykładowych indeksów updates...
i szczegółowo wyjaśnia, z którymi indeksami results...
są one powiązane.
Więcej formalnie dla wszystkich update_index
w index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
to:scatter_indices[si0, ..., :, ..., siN]
, gdziesi
to poszczególne elementy wupdate_scatter_index
, a element:
jest wstawiany w indeksieindex_vector_dim
, jeśliindex_vector_dim
<rank(scatter_indices)
.[scatter_indices[update_scatter_index]]
w innych przypadkach.
- Za
d_input
waxes(inputs[0])
:full_start_index[d_input] = start_index[d_start]
jeślid_input = scatter_dims_to_operand_dims[d_start]
.full_start_index[d_input] = 0
w innych przypadkach.
- W przypadku
d_input
waxes(inputs[0])
:full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
jeślid_input = input_batching_dims[i_batching]
id_start = scatter_indices_batching_dims[i_batching]
.- W przeciwnym razie:
full_batching_index[d_input] = 0
.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, gdziewi
to poszczególne elementy w tablicyupdate_window_index
, a element0
jest wstawiany na pozycjach z tablicinserted_window_dims
iinput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
W związku z tym results = exec(schedule, inputs)
, gdzie:
schedule
to określona przez implementację permutacja funkcjiindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
, gdzie:- Jeśli
result_index
mieści się w zakresieshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
to kopiaresults
z wartościąresults...[result_index]
ustawioną naupdated_values...
.- W innym przypadku
updated_results = results
.
- Jeśli
exec([], results) = results
.
Jeśli indices_are_sorted
to true
, implementacja może założyć, że scatter_indices
są posortowane zgodnie z scatter_dims_to_operand_dims
. W przeciwnym razie zachowanie jest nieokreślone. Bardziej formalnie dotyczy wszystkich i1 < i2
od
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Jeśli unique_indices
to true
, implementacja może założyć, że wszystkie indeksy result_index
, które są rozproszone, są unikalne. Jeśli unique_indices
to
true
, ale indeksy, do których jest rozproszony, nie są unikalne, działanie jest niezdefiniowane.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
tensor typu liczba całkowita | (C4), (C15), (C19), (C22) |
(I3) | updates |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C3-C6), (C8) |
(I4) | update_window_dims |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
1-wymiarowa stała tensora typu si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
1-wymiarowa stała tensora typu si64 |
(C19-C21) |
(I9) | index_vector_dim |
stała typu si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
stała typu i1 |
|
(I11) | unique_indices |
stała typu i1 |
|
(I12) | update_computation |
funkcja | (C23) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C24-C25) |
Ograniczenia
- (C1)
same(shape(inputs...))
. - (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
, gdzie:update_scatter_dim_sizes = shape(scatter_indices)
, z tym że nie uwzględnia wymiaruscatter_indices
odpowiadającego wymiarowiindex_vector_dim
.update_window_dim_sizes <= shape(inputs[0])
, z tym że nie uwzględnia wymiarówinputs[0]
odpowiadających wymiarominserted_window_dims
iinput_batching_dims
.- Funkcja
combine
umieszcza poleupdate_scatter_dim_sizes
na osiach odpowiadających podziałowiupdate_scatter_dims
iupdate_window_dim_sizes
na osiach odpowiadających wartościupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
ma typ(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, gdzieis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
dla wszystkichi
w[0,N)
.
Przykłady
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
wybierz
Semantyka
Tworzy tensor result
, w którym każdy element jest wybierany z tensora on_true
lub on_false
na podstawie wartości odpowiadającego elementu tensora pred
.
W formie bardziej oficjalnej: result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, gdzie pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. W przypadku typów kwantowych wybiera dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | pred |
tensor typu i1 |
(C1) |
(I2) | on_true |
tensor lub tensor zagregowany z tensorów | (C1-C2) |
(I3) | on_false |
tensor lub tensor zagregowany z tensorów | (K2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C2) |
Ograniczenia
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Przykłady
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semantyka
Rozprasza wartości z tensora source
za pomocą funkcji scatter
na podstawie wyniku reduce_window
z tensora input
za pomocą funkcji select
i tworzy tensor result
.
Na diagramie poniżej widać, jak elementy w elementach result
są obliczane na podstawie elementów operand
i source
.
Bardziej formalnie:
selected_values = reduce_window_without_init(...)
z tymi danymi wejściowymi:inputs = [operand].
window_dimensions
,window_strides
ipadding
, które są używane w takiej postaci, w jakiej zostały przesłane.base_dilations = windows_dilations = 1
.body
jest zdefiniowana jako:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
gdzie
E = element_type(operand)
ireduce_window_without_init
działają dokładnie tak jakreduce_window
, z tą różnicą, żeschedule
wartości podstawowejreduce
(patrz redukcja) nie zawiera wartości init. Obecnie nie określono, co się stanie, jeśli odpowiednie okno nie zawiera wartości (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
, gdzie:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
jeśliselected_values[source_index]
ma elementoperand
zoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub tensor zagregowany z tensorów | (C1-C4), (C6), (C8-C11) |
(I2) | source |
kwantowy tensor lub tensor kwantowy | (C1), (C2) |
(I3) | init_value |
0-wymiarowy tensor lub tensor kwantyzowany na podstawie tensora | (C3) |
(I4) | window_dimensions |
1-wymiarowa stała tensora typu si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
1-wymiarowa stała tensora typu si64 |
(C2), (C6), (C7) |
(I6) | padding |
2-wymiarowa stała tensora typu si64 |
(C2), (C8) |
(I7) | select |
funkcja | (C9) |
(I8) | scatter |
funkcja | (C10) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
kwantowy tensor lub tensor kwantowy | (C11-C12) |
Ograniczenia
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
, gdzie:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
ma typ(tensor<E>, tensor<E>) -> tensor<i1>
, gdzieE = element_type(operand)
. - (C10)
scatter
ma typ(tensor<E>, tensor<E>) -> tensor<E>
, gdzieis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Przykłady
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
wyślij
Semantyka
Wysyła kod inputs
do kanału channel_id
i tworzy token result
.
Jeśli is_host_transfer
to true
, operacja przenosi dane do hosta. W przeciwnym razie dane zostaną przeniesione na inne urządzenie. Co to oznacza, zależy od implementacji. Ta flaga powiela informacje podane w flagach channel_type
, dlatego w przyszłości planujemy zachować tylko jedną z nich (#666).
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | inputs |
zmienna liczba tensorów lub zagęszczonych tensorów | |
(I2) | token |
token |
|
(I3) | channel_id |
stała typu si64 |
|
(I4) | channel_type |
enum DEVICE_TO_DEVICE i DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
stała typu i1 |
(C1) |
Wyniki
Nazwa | Typ |
---|---|
result |
token |
Ograniczenia
- (C1)
channel_type
jest zdefiniowany jako:DEVICE_TO_HOST
jeżeliis_host_transfer = true
,DEVICE_TO_DEVICE
w innych przypadkach.
Przykłady
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semantyka
Przeprowadza elementową operację przesunięcia w lewo na tensorze lhs
o liczbę bitów rhs
i tworzy tensor result
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu liczba całkowita | (C1) |
(I2) | rhs |
tensor typu liczby całkowitej | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result)
.
Przykłady
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semantyka
Przesuwa elementowo w prawo o określoną liczbę bitów (rhs
) elementarną operację arytmetyczną na tensorze lhs
i tworzy tensor result
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu liczba całkowita | (C1) |
(I2) | rhs |
tensor typu liczby całkowitej | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result)
.
Przykłady
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semantyka
Wykonuje logiczne operację przesunięcia w prawo na tensorze lhs
o rhs
bitów i tworzy tensor result
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu liczba całkowita | (C1) |
(I2) | rhs |
tensor typu liczby całkowitej | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu liczba całkowita | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result)
.
Przykłady
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
podpisywanie
Semantyka
Zwraca znak operand
element po elemencie i tworzy tensor result
.
Bardziej formalnie, w przypadku każdego elementu x
semantyka może być wyrażona za pomocą składni Pythona w ten sposób:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(sign, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu signed integer, zmiennoprzecinkowego, zespolonego lub kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
sinus
Semantyka
Wykonuje elementarną operację sinusa na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
sin
z IEEE-754. - W przypadku liczb zespolonych: sinus zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(sine, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
wycinek
Semantyka
Wyodrębnia wycinek z operand
, używając statycznie obliczonych indeksów początkowych i tworzy tensor result
. start_indices
zawiera indeksy początkowe przekroju dla każdego wymiaru, limit_indices
zawiera indeksy końcowe (wykluczające) przekroju dla każdego wymiaru, a strides
zawiera kroki dla każdego wymiaru.
Bardziej formalnie: result[result_index] = operand[operand_index]
, gdzie operand_index = start_indices + result_index * strides
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
kwantowy tensor lub tensor kwantowy | (C1-C3), (C5) |
(I2) | start_indices |
1-wymiarowa stała tensora typu si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
1-wymiarowa stała tensora typu si64 |
(C2), (C3), (C5) |
(I4) | strides |
1-wymiarowa stała tensora typu si64 |
(C2), (C4) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub tensor zagregowany z tensorów | (C1), (C5) |
Ograniczenia
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Przykłady
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
sortuj
Semantyka
Sortuje jednowymiarowe wycinki obiektu inputs
wzdłuż wymiaru dimension
według wartości comparator
i tworzy results
.
W przeciwieństwie do podobnych danych wejściowych w innych operacjach funkcja dimension
umożliwia stosowanie wartości ujemnych. Wybrane wartości mają następującą interpretację: W przyszłości możemy zablokować tę funkcję ze względu na spójność (problem #1377).
Jeśli is_stable
ma wartość prawda, sortowanie jest stabilne, co oznacza, że zachowana jest względna kolejność elementów uznanych przez porównywarkę za równe. W przypadku pojedynczego wejścia dwa elementy e1
i e2
są uznawane za równe przez porównywacz, jeśli i tylko jeśli comparator(e1, e2) = comparator(e2, e1) = false
. Zobacz formalny opis poniżej, aby dowiedzieć się, jak to uogólniać do wielu danych wejściowych.
Więcej formalnie dla wszystkich result_index
w index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, gdzieriN
to pojedyncze elementy w komórceresult_index
, a:
jest wstawiony w miejscuadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- gdzie
sort
sortuje jednowymiarowy wycinek w kolejności niemalejącej, oczekując, że funkcjacomparator_together
zwrócitrue
, jeśli argument po lewej stronie jest mniejszy niż argument drugiej strony po prawej stronie. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | inputs |
zmienna liczba tensorów lub tensorów zagregowanych z poszczególnych tensorów | (C1-C5) |
(I2) | dimension |
stała typu si64 |
(C4) |
(I3) | is_stable |
stała typu i1 |
|
(I4) | comparator |
funkcja | (C5) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
liczba zmiennoprzecinkowa tensorów lub kwantyzowane tensory na tensor | (C2), (C3) |
Ograniczenia
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, gdzieR = rank(inputs[0])
. - (C5)
comparator
ma typ(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, gdzieEi = element_type(inputs[i])
.
Przykłady
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semantyka
Wykonuje operację pierwiastka kwadratowego uwzględniającą elementy na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
squareRoot
z IEEE-754. - W przypadku liczb zespolonych: pierwiastek kwadratowy z liczby zespolonej.
- W przypadku typów skwantowanych:
dequantize_op_quantize(sqrt, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
odejmij
Semantyka
Wykonuje element po elemencie odejmowanie dwóch tensorów lhs
i rhs
oraz generuje tensor result
. W zależności od typu elementu:
- W przypadku liczb całkowitych: odejmowanie liczb całkowitych.
- Dla liczb zmiennoprzecinkowych:
subtraction
z IEEE-754. - W przypadku liczb zespolonych: odejmowanie liczb zespolonych.
- W przypadku typów skwantowanych:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
(I2) | rhs |
tensor typu całkowitej, zmiennoprzecinkowego, zespolonego lub kwantyzowany tensora według tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu całkowitego, zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Przykłady
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semantyka
Wykonuje elementarną operację pochodnej cząstkowej na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
tan
z IEEE-754. - W przypadku liczb zespolonych: tangens zespolony.
- W przypadku typów kwantowych:
dequantize_op_quantize(tan, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semantyka
Wykonuje elementarną operację tangensu hiperbolicznego na tensorze operand
i tworzy tensor result
. W zależności od typu elementu:
- Dla liczb zmiennoprzecinkowych:
tanh
z IEEE-754. - W przypadku liczb zespolonych: tangens hiperboliczny zespolony.
- W przypadku typów skwantowanych:
dequantize_op_quantize(tanh, operand, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_type(operand) = baseline_type(result)
.
Przykłady
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transponować
Semantyka
Zamienia wymiary tensora operand
za pomocą permutation
i tworzy tensor result
. Bardziej formalnie: result[result_index] = operand[operand_index]
gdzie result_index[d] = operand_index[permutation[d]]
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor lub kwantyzowany tensor | (C1-C4) |
(I2) | permutation |
1-wymiarowa stała tensora typu si64 |
(C2-C4) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor lub kwantyzowany tensor | (C1), (C3-C4) |
Ograniczenia
- (C1)
element_type(result)
otrzymuje:element_type(operand)
, jeśli!is_per_axis_quantized(operand)
.element_type(operand)
z tym, żequantization_dimension(operand)
iquantization_dimension(result)
mogą się różnić.
- (C2)
permutation
jest permutacjąrange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Jeśli
is_per_axis_quantized(result)
, toquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Przykłady
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semantyka
Rozwiązuje partie układów równań liniowych z macierzowymi współczynnikami o dolnej lub górnej trójkątności.
Bardziej formalnie, przy założeniu a
i b
, result[i0, ..., iR-3, :, :]
jest rozwiązaniem op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
, gdy left_side
jest równe true
lub x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
, gdy left_side
jest równe false
, przy czym zmienna x
jest rozwiązaniem op(a)
, a jej wartość jest określana przez transpose_a
, który może być równy:
NO_TRANSPOSE
: wykonaj operację, używająca
w postaci domyślnej.TRANSPOSE
: wykonaj operację na przekształceniu macierzowyma
.ADJOINT
: wykonaj operację na sprzężonym przekształceniu macierzowyma
.
Dane wejściowe są odczytywane tylko z dolnego trójkąta elementu a
, jeśli lower
to true
, lub z górnego trójkąta elementu a
, jeśli nie. Dane wyjściowe są zwracane w tym samym trójkącie, a wartości w drugim trójkącie są zdefiniowane przez implementację.
Jeśli unit_diagonal
ma wartość true, implementacja może założyć, że elementy diagonalne funkcji a
są równe 1. W przeciwnym razie działanie jest nieokreślone.
W przypadku typów skwantowanych wykonuje działanie dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | a |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1-C3) |
(I2) | b |
tensor typu zmiennoprzecinkowego lub zespolonego albo tensor kwantyzowany na podstawie tensora | (C1-C4) |
(I3) | left_side |
stała typu i1 |
(K3) |
(I4) | lower |
stała typu i1 |
|
(I5) | unit_diagonal |
stała typu i1 |
|
(I6) | transpose_a |
enum NO_TRANSPOSE , TRANSPOSE i ADJOINT |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego lub zespolonego lub kwantyzowany tensora według tensora | (C1) |
Ograniczenia
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) Związek między
shape(a)
ashape(b)
jest zdefiniowany w ten sposób:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Przykłady
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tablice
Semantyka
Tworzy kwaternion result
z wartości val
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | val |
zmienna liczba wartości | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
krotka | (C1) |
Ograniczenia
- (C1)
result
ma typtuple<E0, ..., EN-1>
, gdzieEi = type(val[i])
.
Przykłady
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semantyka
Przekształca element po elemencie kwantyzowany tensor operand
na tensor zmiennoprzecinkowy result
zgodnie z parametrami kwantyzacji zdefiniowanymi przez typ operand
.
Więcej formalnie: result = dequantize(operand)
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor kwantowy | (C1), (C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu zmiennoprzecinkowego | (C1), (C2) |
Ograniczenia
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Przykłady
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semantyka
Wykonuje konwersję tensora zmiennoprzecinkowego lub kwantyzowanego tensora operand
na kwantyzowany tensor result
zgodnie z parametrami kwantyzacji zdefiniowanych przez typ result
.
Bardziej formalnie,
- Jeśli
is_float(operand)
:result = quantize(operand, type(result))
.
- Jeśli
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
tensor typu zmiennoprzecinkowego lub kwantyzowanego | (C1), (C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor kwantyzowany | (C1), (C2) |
Ograniczenia
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Przykłady
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
podczas gdy
Semantyka
Wyświetla dane wyjściowe z wykonania funkcji body
co najmniej 0 razy, gdy funkcja cond
zwraca wartość true
. Bardziej oficjalnie semantyka można wyrazić za pomocą składni Pythona:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Zachowanie nieskończonego pętli jest nieznane (#383).
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | operand |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C1-C3) |
(I2) | cond |
funkcja | (C1) |
(I3) | body |
funkcja | (C2) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
results |
zmienna liczba tensorów, zaokrąglonych tensorów lub tokenów | (C3) |
Ograniczenia
- (C1)
cond
ma typ(T0, ..., TN-1) -> tensor<i1>
, gdzieTi = type(operand[i])
. - (C2)
body
ma typ(T0, ..., TN-1) -> (T0, ..., TN-1)
, gdzieTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Przykłady
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semantyka
Wykonuje elementarną operację XOR na dwóch tensorach lhs
i rhs
oraz zwraca tensor result
. W zależności od typu elementu wykonuje te działania:
- W przypadku wartości logicznych: XOR logiczny.
- W przypadku liczb całkowitych: XOR bitowy.
Dane wejściowe
Etykieta | Nazwa | Typ | Ograniczenia |
---|---|---|---|
(I1) | lhs |
tensor typu logicznego lub całkowitego | (C1) |
(I2) | rhs |
tensor typu logicznego lub całkowitego | (C1) |
Wyniki
Nazwa | Typ | Ograniczenia |
---|---|---|
result |
tensor typu logicznego lub całkowitego | (C1) |
Ograniczenia
- (C1)
type(lhs) = type(rhs) = type(result)
.
Przykłady
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interoperacyjność dialektów
Obecnie programy StableHLO w naturze czasami zawierają operacje, które nie są zdefiniowane przez StableHLO.
Moduł, funkcja, wywołanie i zwracanie
StableHLO używa operacji MLIR z upstream do operacji ModuleOp, FuncOp, CallOp i ReturnOp. Zrobiliśmy to, aby zapewnić lepszą współpracę z dotychczasowymi mechanizmami MLIR, ponieważ wiele przydatnych przejść jest napisanych z uwzględnieniem operacji FuncOp i ModuleOp, a wiele ścieżek kompilacji oczekuje obecności tych operacji. Do tych operacji stosowane są gwarancje pełnej zgodności. Jeśli cokolwiek w tych działaniach zmieni się w niekompatybilny sposób (np. zostanie usunięty), zostaną dodane odpowiedniki StableHLO, aby zachować zgodność.
CHLO
Opset CHLO zawiera operacje wyższego poziomu, które rozkładają się na StableHLO. Obecnie nie ma żadnych gwarancji zgodności w przypadku CHLO. Aby zapewnić zgodność, przed serializacją należy użyć przejścia chlo-legalize-to-stablehlo.
Operacje kształtu
W społeczności często stosuje się w programach dynamicznych StableHLO pewne operacje z podstawowych dialektów MLIR do wykonywania obliczeń kształtu.
Najczęściej są to operacje shape
, takie jak shape_of
lub num_elements
, operacje tensor
, takie jak dim
lub from_elements
, oraz wbudowany typ index
.
Dynamism RFC > O2 wskazuje, że te typy wykraczają poza zakres, ale ze względu na interoperacyjność uwzględniono w nim niektóre typy index
. W przypadku tych typów i wersji nie ma gwarancji zgodności. Za pomocą przejścia shape-legalize-to-stablehlo można przekształcić te operacje w w pełni obsługiwane operacje StableHLO.
Wycofane operacje
Istnieje kilka operacji StableHLO, które zostały odziedziczone z MHLO, są wycofane i zostaną usunięte z StableHLO. Szczegółowe informacje na ten temat znajdziesz w dokumencie StableHLO v1.0 Cleanup #2283 (Czyszczenie StableHLO w wersji 1.0). W przypadku tych wycofanych rozwiązań występuje problem z lokalizatorem to #2340.
Operacje te można podzielić na kilka kategorii:
- Kategoria operacji StableHLO „Nie w HLO” – początkowo były one częścią zestawu operacji StableHLO, ale później uznano, że nie pasują do niego:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(#3). - Nieużywane operacje – te operacje mogły być przydatne w danym momencie, ale były niedostatecznie rozwinięte lub ścieżki korzystające z tych operacji zostały przebudowane tak, aby ich już nie wymagały. Dotyczy to funkcji
map
,tuple
(#598),get_tuple_element
,rng
,complex
(#560) oraz funkcjiwindow_reversal
(#1181).
Niektóre z tych operacji można łatwo usunąć, ponieważ można je wyrazić za pomocą istniejących operacji (broadcast
, create_token
, cross-replica-sum
, dot
, unary_einsum
) i zostaną usunięte po upływie obecnego okna zgodności (6 miesięcy). Inne są nadal analizowane do usunięcia (porównania: einsum
, get_tuple_element
, map
, rng
torch_index_select
, tuple
, complex
, window_reversal
). Oczekujemy na opinię społeczności. Te operacje zostaną usunięte lub dodane do specyfikacji z pełną obsługą. Dopóki nie poznamy tych funkcji, gwarantujemy zgodność tylko przez 6 miesięcy.
Wykonanie
Wykonywanie sekwencyjne
Program StableHLO jest wykonywany przez podanie wartości wejściowych do funkcji main
i obliczenie wartości wyjściowych. Wartości wyjściowe funkcji są obliczane przez wykonanie grafu operacji z korzenia w odpowiednim elemencie return
.
Kolejność wykonywania jest określana przez implementację, o ile jest zgodna z przepływem danych, czyli jeśli operacje są wykonywane przed ich użyciem. W StableHLO wszystkie operacje o efektach ubocznych zużywają 1 token i generują 1 token (wiele tokenów można zmultipleksować w jeden token za pomocą funkcji after_all
), więc kolejność wykonywania efektów ubocznych jest również zgodna z dataflow. Na przykład w programie poniżej są 2 możliwe zamówienia: %0
→ %1
→ %2
→ return
i %1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Bardziej oficjalnie proces StableHLO składa się z tych elementów:
1) program StableHLO, 2) stanów operacji (jeszcze niewykonany, już wykonany) oraz 3) wartości pośrednich, nad którymi dany proces pracuje.
Proces zaczyna się od wartości wejściowych funkcji main
, przechodzi przez wykres operacji aktualizowania stanów operacji i wartości pośrednich, a kończy się wartościami wyjściowymi. Dalsze formalizowanie jest jeszcze do ustalenia (#484).
równoległe wykonanie,
Programy StableHLO mogą być wykonywane równolegle, podzielone na siatkę procesów 2D num_replicas
według num_partitions
, z których obydwa mają typ ui32
.
W siatce procesów StableHLO jednocześnie wykonywanych jest num_replicas * num_partitions
procesów StableHLO. Każdy proces ma unikalny element process_id = (replica_id, partition_id)
, gdzie replica_id
w replica_ids = range(num_replicas)
i partition_id
w partition_ids = range(num_partitions)
mają typ ui32
.
Rozmiar siatki procesów jest znany statycznie w przypadku każdego programu (w przyszłości planujemy, aby był on częścią programów StableHLO #650), a pozycja w siatce procesów jest znana statycznie w przypadku każdego procesu. Każdy proces ma dostęp do swojej pozycji w siatce procesów za pomocą operacji replica_id
i partition_id
.
W siatce procesów programy mogą być takie same (w stylu „Jedno program, wiele danych”), różne (w stylu „Wiele programów, wiele danych”) lub mieścić się gdzieś pośrodku. W przyszłości planujemy wprowadzić obsługę innych idiomów definiowania równoległych programów StableHLO, w tym GSPMD (#619).
W ramach siatki procesów procesy są w większości niezależne od siebie – mają oddzielne stany operacji, oddzielne wartości wejścia/pośrednie/wyjścia i większość operacji jest wykonywana oddzielnie w ramach procesów, z wyjątkiem niewielkiej liczby operacji zbiorczych opisanych poniżej.
Ponieważ większość operacji wykonuje się tylko z użyciem wartości z tego samego procesu, zwykle odwołania do tych wartości za pomocą ich nazw są jednoznaczne.
Jednak opis semantyki działań zbiorowych jest niewystarczający. Powoduje to powstanie zapisu name@process_id
, który odwołuje się do wartości name
w konkretnym procesie. (z tego punktu widzenia niekwalifikowana wartość name
może być traktowana jako skrót od wartości name@(replica_id(), partition_id())
).
Kolejność wykonywania procesów jest określana przez implementację, z wyjątkiem synchronizacji wprowadzonej przez komunikację punkt-punkt i operacje zbiorcze, jak opisano poniżej.
Komunikacja punkt-punkt
Procesy StableHLO mogą komunikować się ze sobą za pomocą kanałów StableHLO. Kanał jest reprezentowany przez dodatni identyfikator typu si64
. Za pomocą różnych operacji można wysyłać wartości do kanałów i odbierać je z kanałów.
Dalsze formalizowanie, np. skąd pochodzą te identyfikatory kanałów, jak procesy programów je rozpoznają i jakie synchronizacja jest przez nie wprowadzana, jest jeszcze do ustalenia (#484).
Komunikacja strumieniowa
Każdy proces StableHLO ma dostęp do 2 interfejsów strumieniowego przesyłania danych:
- Infeed, z których można odczytać treści.
- Outfeed, na którym można zapisać dane.
W przeciwieństwie do kanałów, które służą do komunikacji między procesami, a więc mają procesy po obu końcach, dane wejściowe i dane wyjściowe mają zdefiniowany drugi koniec w ramach implementacji.
Dalszą formalizację, np. o tym, jak strumieniowa komunikacja wpływa na kolejność wykonywania i jaki rodzaj synchronizacji jest przez nie wprowadzana, to do ustalenia (#484).
Operacje zbiorcze
W StableHLO jest 6 zbiorowych operacji: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
i reduce_scatter
. Wszystkie te operacje dzielą procesy w sieci procesów StableHLO na grupy procesów StableHLO i wykonują wspólne obliczenia w każdej z nich niezależnie od innych grup procesów.
W ramach każdej grupy procesów operacje zbiorcze mogą wprowadzać barierę synchronizacji. Dalsza formalizacja, np. określenie, kiedy dokładnie ma miejsce synchronizacja, jak dokładnie procesy docierają do tej bariery i co się dzieje, jeśli tego nie zrobią, jest jeszcze nierozstrzygnięta (#484).
Jeśli grupa procesów obejmuje komunikację między partycjami, czyli w grupie procesów są procesy o różnych identyfikatorach partycji, wykonanie operacji zbiorczej wymaga kanału, a operacja zbiorcza musi zawierać dodatnią wartość channel_id
typu si64
. Komunikacja między replikami
nie wymaga kanałów.
Obliczenia wykonywane przez zbiorcze operacje są specyficzne dla poszczególnych operacji i są opisane w poszczególnych sekcjach dotyczących operacji. Jednak strategie, według których siatka procesów jest dzielona na grupy procesów, są wspólne dla tych operacji i opisane w tej sekcji. W formalnym ujęciu StableHLO obsługuje 4 strategię:
cross_replica
W ramach każdej grupy procesów odbywa się tylko komunikacja między replikami. Ta strategia przyjmuje replica_groups
, czyli listę list identyfikatorów replik, i wylicza iloczyn kartezjański replica_groups
i partition_ids
. replica_groups
musi zawierać unikalne elementy i objmować wszystkie replica_ids
. Bardziej formalnie, używając składni Pythona:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Na przykład w przypadku wartości replica_groups = [[0, 1], [2, 3]]
i num_partitions = 2
funkcja cross_replica
zwróci wartość [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
W każdej grupie procesów odbywa się tylko komunikacja między partycjami. Ta strategia wykorzystuje partition_groups
– listę list identyfikatorów partycji – i oblicza kartezjański iloczyn wartości partition_groups
przez replica_ids
.
partition_groups
musi zawierać unikalne elementy i objmować wszystkie partition_ids
.
Bardziej oficjalnie, używając składni Pythona:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Na przykład w przypadku wartości partition_groups = [[0, 1]]
i num_replicas = 4
funkcja cross_partition
zwróci wartość [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
W każdej grupie procesów mogą występować zarówno komunikaty między replikami, jak i między partycjami. Ta strategia wykorzystuje replica_groups
– listę list identyfikatorów replik – i oblicza iloczy kartezjańskie każdego elementu replica_group
według parametru partition_ids
. replica_groups
musi zawierać unikalne elementy i obejmować wszystkie
replica_ids
. Bardziej oficjalnie, używając składni Pythona:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Na przykład w przypadku wartości replica_groups = [[0, 1], [2, 3]]
i num_partitions = 2
funkcja cross_replica_and_partition
zwróci wartość [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Ta strategia przyjmuje flattened_id_groups
– listę list „spłaszczonego” identyfikatora procesu w formie replica_id * num_partitions + partition_id
– i przekształca je w identyfikatory procesów. flattened_id_groups
musi zawierać unikalne elementy i pokrywać wszystkie elementy process_ids
. Bardziej formalnie, używając składni Pythona:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Na przykład w przypadku flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
i num_partitions = 2
, flattened_ids
zwróci wartość [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Dokładność
Obecnie StableHLO nie gwarantuje dokładności liczbowej, ale może się to zmienić w przyszłości (#1156).
Semantyka wykonania kwantowanej operacji
Interpretacja kwantowanych operacji StableHLO może się różnić w zależności od wymagań i możliwości sprzętowych. Na przykład niektóre urządzenia mogą interpretować operacje kwantyzacji za pomocą strategii „dequantyzacja, wykonanie operacji zmiennoprzecinkowej i na koniec ponowna kwantyzacja”. Inne mogą wykonywać całe obliczenia z użyciem arytmetyki całkowitej. Dlatego interpretacja kwantyzowanych operacji StableHLO zależy wyłącznie od konkretnego wdrożenia. Interpretacja kwantyzacji hybrydowej (#1575) powinna opierać się na jej semantyce zgodnie ze specyfikacją (na stronie 1792).
Błędy
Programy StableHLO są weryfikowane za pomocą obszernego zbioru ograniczeń dotyczących poszczególnych operacji, co wyklucza wiele klas błędów przed czasem wykonywania. Nadal jednak mogą wystąpić błędy, np. przepełnienie liczb całkowitych, dostęp poza zakresem itp. O ile nie określono inaczej, wszystkie te błędy powodują zachowanie określone przez implementację, ale może się to w przyszłości zmienić (#1157).
Wyjątki dotyczące liczb zmiennoprzecinkowych
Wyjątkiem od tej reguły są wyjątki o typie zmiennoprzecinkowym w programach StableHLO, które mają dobrze zdefiniowane działanie. Operacje, które powodują wyjątki zdefiniowane przez standard IEEE-754 (nieprawidłowa operacja, dzielenie przez 0, przepełnienie, niedopełnienie lub nieprecyzyjne wyjątki) dają wyniki domyślne (zdefiniowane przez standard) i kontynuują wykonywanie bez podnoszenia odpowiedniej flagi stanu; podobnie jak obsługa wyjątku raiseNoFlag
ze standardu. Wyjątki dotyczące operacji niestandardowych (np. złożonej arytmetyki i niektórych funkcji transcendentalnych) są zdefiniowane przez implementację.
niezgodności kształtów,
StableHLO obsługuje tensory o dynamicznej wielkości. Jednak kształty muszą być zgodne w czasie wykonywania, w przeciwnym razie zachowanie jest nieokreślone. StableHLO nie udostępnia bezpośrednio operacji, która może potwierdzić, że tensor ma określony kształt w czasie działania. Generowanie prawidłowego kodu jest obowiązkiem producenta.
Przykładem prawidłowego programu jest program poniżej. Jednak w czasie działania dokładne kształty obiektów %arg0
i %arg1
muszą być takie same. W przeciwnym razie działanie programu będzie niezdefiniowane:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notacja
Do opisania składni w tym dokumencie użyto zmodyfikowanego rodzaju składni EBNF (ISO/IEC 14977:1996, Wikipedia). 1) regułę zdefiniowano za pomocą metody ::=
, a nie =
:
2) konkatenacja jest wyrażana za pomocą juxtaposition, a nie ,
.
Do opisu semantyki (np. w sekcjach „Typy”, „Stałe” i „Operacje”) używamy formuł opartych na składni Pythona rozszerzonej o obsługę zwięzłego wyrażania operacji na tablicach, jak opisano poniżej. Takie rozwiązanie sprawdza się w przypadku małych fragmentów kodu, ale w rzadkich przypadkach, gdy potrzebne są większe fragmenty kodu, używamy składni Pythona, która jest zawsze wprowadzana bezpośrednio.
Wzory
Na przykładzie ze specyfikacji dot_general
omówimy, jak działają formuły. Jedno z ograniczeń tej operacji wygląda tak: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Nazwy używane w tej formule pochodzą z 2 źródeł: 1) funkcji globalnych, np. dim
, 2) definicji elementów członkowskich odpowiadającego elementu programu, np. lhs
, lhs_batching_dimensions
, rhs
i rhs_batching_dimensions
, zdefiniowanych w sekcji „Wejścia” w funkcji dot_general
.
Jak już wspomnieliśmy, składnia tej formuły jest oparta na Pythonie z kilkoma rozszerzeniami ułatwiającymi zwiężenie kodu. Aby lepiej zrozumieć tę formułę, przekształcimy ją w tradycyjną składnię Pythona.
A) W tych formułach używamy wyrażenia =
do reprezentowania równości, więc pierwszym krokiem do uzyskania składni Pythona jest zastąpienie =
przez ==
w ten sposób: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Formuły te obsługują też wielokropki (...
), które zamieniają wyrażenia skalarne w wyrażenia tensorowe. Krótko mówiąc, f(xs...)
oznacza mniej więcej „dla każdego wektora x
w tensorze xs
oblicz wektor f(x)
, a potem zwracaj wszystkie te wyniki wektorów jako wynik tensora”. W standardowej składni Pythona nasza przykładowa formuła zamienia się na:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Dzięki elipsom często można uniknąć pracy na poziomie poszczególnych skalarów. W niektórych trudnych przypadkach można jednak użyć nieformalnej składni na niższym poziomie, np. w formule start_indices[bi0, ..., :, ..., biN]
ze specyfikacji gather
. W trosce o zwiększenie zwięzłości nie podajemy dokładnego formalizmu, w którym można przetłumaczyć taką składnię na zwykły Python. Mamy nadzieję, że w każdym przypadku będzie ona intuicyjna.
Jeśli zauważysz, że niektóre formuły są nieczytelne, daj nam znać, a spróbujemy je ulepszyć.
Zauważysz też, że formuły używają wielokropków do rozwijania wszelkich rodzajów list, w tym tensorów, list tensorów (które np. mogą powstać na podstawie różnej liczby tensorów) itp. W tym innym obszarze nie ma dokładnego formalności (np. listy nie są nawet częścią intuicyjnego systemu StaableHLO).
C) Ostatnim wartym uwagi sposobem zapisu, którego używamy, jest domyślne przesyłanie. Chociaż przesunięcie StableHLO nie obsługuje jawnego nadawania, formuły już tak, co również służy zwiększeniu zwięzłości. Krótko mówiąc, jeśli w kontekście, w którym oczekuje się tensora, używany jest element skalarny, element skalarny jest rozprowadzany do oczekiwanego kształtu.
Aby kontynuować przykład z dot_general
, zastosuj tu kolejne ograniczenie: 0 <= lhs_batching_dimensions < rank(lhs)
. Zgodnie z definicją w specyfikacji dot_general
element lhs_batching_dimensions
jest tensorem, ale zarówno 0
, jak i rank(lhs)
są skalarami. Gdy zastosujemy przekazywanie niejawne, formuła zmieni się na [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Po zastosowaniu do konkretnej operacji dot_general
ta formuła oceni tensor wartości logicznych. Gdy formuły są używane jako ograniczenia, ograniczenie jest spełnione, jeśli formuła zwraca wartość true
lub tensor zawierający tylko elementy true
.
Nazwy
W formułach zakres leksykalny obejmuje: 1) funkcje globalne, 2) definicje członków,
3) definicje lokalne. Poniżej znajdziesz listę funkcji globalnych. Lista definicji elementów zależy od elementu programu, do którego zastosowano notację:
- W przypadku operacji definicje elementów obejmują nazwy wprowadzone w sekcjach „Wejścia” i „Wyjścia”.
- W przypadku innych elementów definicje członków obejmują części strukturalne elementu programu, nazwane na podstawie odpowiednich nieterminali EBNF. W większości przypadków nazwy tych części strukturalnych są uzyskiwane przez konwersję nazw nieterminali na format snake case (np.
IntegerLiteral
=>integer_literal
), ale czasami nazwy są w tym procesie skracane (np.QuantizationStorageType
=>storage_type
). W takim przypadku nazwy są wprowadzane w sposób jawny, podobnie jak sekcje „Wejścia” i „Wyjścia” w specyfikacjach operacji. - Definicje członków zawsze zawierają element
self
, który odnosi się do odpowiedniego elementu programu.
Wartości
Podczas obliczania formuł są one używane do obsługi tych typów wartości:
1) Value
(rzeczywiste wartości, np. dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
; zawsze znają swój typ),
2) Placeholder
(przyszłe wartości, np. lhs
, rhs
lub result
; ich rzeczywiste wartości nie są jeszcze znane, znane są tylko ich typy),
3) Type
(typy zdefiniowane w sekcji „Typy”),
4) Function
(funkcje globalne zdefiniowane w sekcji „Funkcje”).
W zależności od kontekstu nazwy mogą się odnosić do różnych wartości. W szczególności sekcja „Semantyka” w przypadku operacji (i odpowiednich sekcji w przypadku innych elementów programu) definiuje logikę czasu wykonywania, dzięki czemu wszystkie dane wejściowe są dostępne jako Value
.
Z kolei sekcja „Ograniczenia” odnosząca się do operacji (i ich odpowiedników) definiuje logikę „czas kompilowania”, tj.coś, co jest zwykle wykonywane przed uruchomieniem. Dlatego jako Value
dostępne są tylko stałe dane wejściowe, a inne dane wejściowe – tylko jako Placeholder
.
Nazwy | W sekcji „Semantyka” | W sekcji „Ograniczenia” |
---|---|---|
Funkcje globalne | Function |
Function |
stałe wejścia, | Value |
Value |
Nieciągłe dane wejściowe | Value |
Placeholder |
Wyniki | Value |
Placeholder |
Definicje lokalne | Zależy od definicji | Zależy od definicji |
Przeanalizujmy przykładową operację transpose
:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
W przypadku tej operacji permutation
jest stałą, więc jest dostępna jako Value
zarówno w semantyce, jak i w ograniczeniach. Z kolei operand
i result
są dostępne jako Value
w semantyce, ale tylko jako Placeholder
w ograniczeniach.
Funkcje
Konstrukcja typów
Nie ma funkcji, których można używać do tworzenia typów. Zamiast tego używamy bezpośrednio składni typu, ponieważ jest ona zazwyczaj bardziej zwięzła. Na przykład (tensor<E>, tensor<E>) -> (tensor<E>)
zamiast function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funkcje w typach
element_type
jest zdefiniowany na typach tensorów i typach zaokrąglonych tensorów oraz zwraca odpowiednio częśćTensorElementType
lubQuantizedTensorElementType
odpowiadającegoTensorType
lubQuantizedTensorType
.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
to skrót do:is_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
to skrót odis_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
sprawdza, czy typx
można podnieść do typuy
. Jeślix
iy
toQuantizedTensorElementType
, promocja jest stosowana tylko dostorage_type
. Ta konkretna wersja promocji jest obecnie używana w kontekście obliczeń redukcji (więcej informacji znajdziesz w RFC).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
to skrót dois_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Dostępne dla wszystkich typów. Na przykładis_float(x)
zwracatrue
, jeślix
jest wartością typuFloatType
. Jeślix
jest wartością lub miejscem zastępczym, ta funkcja jest skrótem funkcjiis_type_name(type(x))
.max_value(x: Type) -> Value
zwraca maksymalną wartośćTensorElementType
. Jeślix
nie jestTensorElementType
, zwracaNone
.min_value(x: Type) -> Value
zwraca minimalną możliwą wartośćTensorElementType
. Jeślix
nie jestTensorElementType
, zwracaNone
.member_name(x: Value | Placeholder | Type) -> Any
. Dostępne dla wszystkich definicji członkówmember_name
wszystkich typów. Na przykładtensor_element_type(x)
zwraca częśćTensorElementType
odpowiadającego elementuTensorType
. Jeślix
jest wartością lub zmienną, ta funkcja jest skrótem domember_name(type(x))
. Jeślix
nie jest typem, który ma odpowiedni element, wartość lub zastępnik tego typu, zwracaNone
.is_empty_algorithm(*args: Type)
sprawdza, czy wszystkie pola algorytmu dot są ustawione naNone
. Jest to konieczne, ponieważ algorytmy dot mają określone przez implementację domyślne zachowania, więc podanie domyślnej wartości byłoby nieprawidłowe.
Budowa wartości
operation_name(*xs: Value | Type) -> Value
. Dostępne w przypadku wszystkich operacji. Na przykład funkcjaadd(lhs, rhs)
przyjmuje 2 wartości tensoralhs
irhs
oraz zwraca wynik oceny operacjiadd
z tymi danymi wejściowymi. W przypadku niektórych operacji, np.broadcast_in_dim
, typy ich danych wyjściowych są „nośne”, czyli potrzebne do oceny operacji. W tym przypadku funkcja przyjmuje te typy jako argumenty.
Funkcje wartości
Dostępne są wszystkie operatory i funkcje Pythona. Na przykład w Pythonie dostępne są zarówno subskrypcje, jak i wycinki, które można indeksować w tensorach, kwantowych tensorach i tuplach.
to_destination_type(x: Value, destination_type: Type) -> Value
jest zdefiniowany na tensorach i zwraca przekonwertowaną wartośćx
na podstawie wartościtype(x)
idestination_type
w następujący sposób:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Trwają wstępne dyskusje na temat połączenia operacji convert
, uniform_quantize
i uniform_dequantize
(#1576).
Po scaleniu nie potrzebujesz powyższej funkcji i możemy zamiast niej użyć nazwy operacji dla funkcji convert
.
Funkcja
is_nan(x: Value) -> Value
jest zdefiniowana na tensorach i zwraca wartośćtrue
, jeśli wszystkie elementyx
są równeNaN
, a w przeciwnym razie zwraca wartośćfalse
. Jeślix
nie jest tensorem, zwracaNone
.Funkcja
is_sorted(x: Value) -> Value
jest zdefiniowana na tensorach i zwraca wartośćtrue
, jeśli elementyx
są posortowane w kolejności rosnącej według rosnącej kolejności leksykograficznej ich indeksów lubfalse
w innym przypadku. Jeślix
nie jest tensorem, zwracaNone
.Funkcja
is_unique(x: Value) -> Value
jest zdefiniowana na tensorach i zwraca wartośćtrue
, jeślix
nie zawiera podwójnych elementów, a w przeciwnym razie zwraca wartośćfalse
. Jeślix
nie jest tensorem, zwracaNone
.member_name(x: Value) -> Any
jest zdefiniowany dla wszystkich definicji elementówmember_name
wszystkich wartości. Na przykładreal_part(x)
zwraca elementRealPart
należący do odpowiadającego mu elementuComplexConstant
. Jeślix
nie jest wartością, która ma odpowiedni element, zwracaNone
.Funkcja
same(x: Value) -> Value
jest zdefiniowana na tensorach i zwraca wartośćtrue
, jeśli wszystkie elementy tensorax
są sobie równe, a w przeciwnym razie zwraca wartośćfalse
. Jeśli tensor nie ma elementów, liczy się to jako „wszystkie są sobie równe”, tzn. funkcja zwracatrue
. Jeślix
nie jest tensorem, zwracaNone
.split(x: Value, num_results: Value, axis: Value) -> Value
jest zdefiniowany na tensorach i zwracanum_results
przekrojex
wzdłuż osiaxis
. Jeślix
nie jest tensorem anidim(x, axis) % num_results != 0
, zwracaNone
.is_defined_in_parent_scope(x: Value) -> Value
jest zdefiniowana na ciągach znaków i zwracatrue
, jeślix
to nazwa funkcji zdefiniowanej w tym samym zakresie co funkcja nadrzędna dla odpowiedniej op.Argument
is_namespaced_op_name(x: Value) -> Value
jest zdefiniowany na ciągach znaków i zwraca wartośćtrue
, jeślix
jest prawidłową nazwą op, czyli spełnia to wyrażenie regularne:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Obliczenia kształtu
axes(x: Value | Placeholder | Type) -> Value
to skrót do opcjirange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
to skrót doshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
to skrót dolist(map(lambda axis: dim(x, axis), axes))
.Funkcja
index_space(x: Value | Placeholder | Type) -> Value
jest zdefiniowana na tensorach i zwraca indeksysize(x)
dla odpowiednichTensorType
posortowanych w rosnącym porządku alfabetycznym, np.[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Jeślix
nie jest typem tensora, skonwertowanego typu tensora, wartości lub zastępnika jednego z tych typów, zwracaNone
.rank(x: Value | Placeholder | Type) -> Value
to skrót dosize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
jest zdefiniowana w sekcji „Funkcje według typów” za pomocąmember_name
.size(x: Value | Placeholder | Type) -> Value
to skrót doreduce(lambda x, y: x * y, shape(x))
.
Obliczenia kwantyzacji
def baseline_element_type(x: Value | Placeholder | Type) -> Type
to skrót odelement_type(baseline_type(x))
.Funkcja
baseline_type
jest zdefiniowana dla typów tensorów i typów zaokrąglonych tensorów. Przekształca je w „wartość bazową”, czyli typ o tym samym kształcie, ale z parametrami zaokrąglenia typu elementu skonfigurowanymi na wartości domyślne. To przydatna sztuczka, która pozwala jednolicie porównać zarówno tensor, jak i kwantyzowane typy tensorów, co jest potrzebne dość często. W przypadku typów kwantyzowanych umożliwia to porównywanie typów ignorujących parametry kwantyzacji, czylishape
,storage_type
,expressed_type
,storage_min
,storage_max
iquantization_dimension
(w przypadku typu kwantyzowanego na oś), ale wartościscales
izero points
mogą się różnić.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
jest zdefiniowany na podstawie zliczanych typów tensorów i przekształca je w typy tensorów zmiennoprzecinkowych. Polega to na konwertowaniu elementów dyskretnych, które reprezentują wartości całkowite typu magazynowania, na odpowiadające im wartości zmiennoprzecinkowe typu wyrażonego za pomocą punktu zerowego i skali powiązanej z elementem dyskretnym.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
jest zdefiniowany na typach tensorów zmiennoprzecinkowych i przekształca je w typy skonwertowanych tensorów. Polega to na konwersji wartości zmiennoprzecinkowych wyrażonego typu na odpowiadające wartości całkowite typu magazynu za pomocą punktu zerowego i skali powiązanej z kwantowanym typem elementu.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
służy do określania obliczeń elementarnych na kwantyzowaniach tensorów. Dequantuje, czyli zamienia elementy poddane kwantyzacji na ich wyrażone typy, a potem wykonuje operację, a następnie ponownie kwantyzuje, czyli zamienia wyniki z powrotem na ich typy magazynowania. Obecnie ta funkcja działa tylko w przypadku kwantyzacji natężenia. Trwa kwantyzacja według osi (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
służy do określania kwantyzacji tylko wagi dla operacji hybrydowej, która przyjmuje lewą stronę w typie zmiennoprzecinkowym, a prawą – w typie skwantowanym. Dekwantuje zagregowane dane wejściowe w ich wyrażonych typach i wykonuje obliczenia w typie float. Typ elementu wektora lewego argumentu typu float i wyrażony typ zagregowanego prawego argumentu powinny być identyczne.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Obliczenia siatki
cross_partition(replica_groups: Value) -> Value
. Zobacz sekcję „cross_replica” powyżej.cross_replica(replica_groups: Value) -> Value
. Zobacz sekcję „cross_replica” powyżej.cross_replica_and_partition(replica_groups: Value) -> Value
. Zobacz sekcję „cross_replica_and_partition” powyżej.flattened_ids(replica_groups: Value) -> Value
. Zobacz sekcję „flattened_ids” powyżej.
Dynamizm
Wartości StableHLO mogą mieć dynamiczne rozmiary wymiarów, np. tensor<?xi64>
.
Wartości StableHLO nie mogą jednak mieć dynamicznej liczby wymiarów (nieuporządkowany
dynamizm, np. tensor<*xi64>
). Obliczenia i wyniki mogą używać dynamicznych rozmiarów wymiarów, nawet jeśli istnieją ograniczenia rozmiarów. Jeśli to możliwe, ograniczenia są weryfikowane statycznie. W przeciwnym razie są odraczane do działania w czasie działania i niezgodności powodują niezdefiniowane zachowanie. Przykłady znajdziesz poniżej.
Niezgodność kształtów w przypadku operacji jednoelementowych
Rozważ ten program:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Taki program jest nietypowy, ponieważ zwykle wiadomo, jak wygląda wynik, ale nie wiadomo, jak wyglądają dane wejściowe. To jednak prawidłowy program StableHLO. W tym programie nie można statycznie zweryfikować operacji abs
, ponieważ nie znamy dokładnego kształtu operandu. Kształty z pewnością są jednak zgodne i można to sprawdzić statycznie: okazało się, że w czasie działania obiekt ?
to 2
i nie będzie żadnych problemów. Jednak ?
może też okazać się inną liczbą całkowitą, w którym przypadku działanie jest nieokreślone.
Pamiętaj, że jeśli rozmiar wymiaru jest dynamiczny, nie może wystąpić niezdefiniowane działanie. Rzeczywiście nie ma „oczekiwanego” rozmiaru, więc nie może być niezgodności.
Niezgodność kształtów w przypadku operacji binarnych element po elemencie
Rozważ ten program:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
W przypadku operacji binarnych elementowych kształty danych wejściowych i wyniku muszą być zgodne w czasie wykonywania. Podczas kompilowania wymiary statyczne muszą być równe. W przeciwnym razie wystarczy, że będą tylko zgodne. Jeśli cokolwiek w danych wejściowych jest wymiarem dynamicznym, może to spowodować nieokreślone działanie w czasie wykonywania, ponieważ rozmiar dynamiczny może nie pasować do odpowiadającego mu rozmiaru w innym operandzie (czyli stałego lub dynamicznego). Jeśli wszystkie dane wejściowe są statyczne, nie ma znaczenia, czy wynik jest dynamiczny: wymiary znane statycznie będą sprawdzane statycznie, a wymiary dynamiczne nie narzucają żadnych ograniczeń.
Niezgodności kształtów w operacjach, które przyjmują kształt wyjściowy jako operand
Rozważ ten program:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Wartości w operandzie kształtu w czasie wykonywania programu muszą odpowiadać kształtowi wyniku. W przeciwnym razie zachowanie jest nieokreślone. Oznacza to, że w czasie wykonywania %arg0
musi mieć wartość dense<[3, 4]> : tensor<2xi32>
. Jeśli operand kształtu jest stały, można to zweryfikować statycznie. Jeśli kształt wyniku jest w pełni dynamiczny, nie może wystąpić rozbieżność.