StableHLO — это набор операций для операций высокого уровня (HLO) в моделях машинного обучения (ML). StableHLO работает как уровень переносимости между различными платформами машинного обучения и компиляторами машинного обучения: платформы машинного обучения, создающие программы StableHLO, совместимы с компиляторами машинного обучения, которые используют программы StableHLO.
Наша цель — упростить и ускорить разработку машинного обучения за счет большей совместимости между различными платформами машинного обучения (такими как TensorFlow, JAX и PyTorch) и компиляторами машинного обучения (такими как XLA и IREE). С этой целью в этом документе представлена спецификация языка программирования StableHLO.
Данная спецификация содержит три основных раздела. Во-первых, раздел «Программы» описывает структуру программ StableHLO, которые состоят из функций StableHLO, которые сами состоят из операций StableHLO. В этой структуре раздел Ops определяет семантику отдельных операций. Раздел «Выполнение» предоставляет семантику для всех этих операций, выполняемых вместе в программе. Наконец, в разделе «Обозначения» обсуждаются обозначения, используемые в спецификации.
Чтобы просмотреть спецификацию предыдущего выпуска StableHLO, откройте репозиторий интересующего выпуска с тегом . Например, StableHLO v0.19.0 Spec . Чтобы просмотреть изменения, произошедшие при каждом второстепенном обновлении версии StableHLO, обратитесь к журналу версий в VhloDialect.td .
Программы
Program ::= {Func}
Программы StableHLO состоят из произвольного количества функций StableHLO. Ниже приведен пример программы с функцией @main
, которая имеет 3 входа ( %image
, %weights
и %bias
) и 1 выход. Тело функции имеет 6 операций.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Функции
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Функции StableHLO (которые также называются именованными функциями ) имеют идентификатор, входы/выходы и тело. В будущем мы планируем ввести дополнительные метаданные для функций для достижения лучшей совместимости с HLO ( #425 , #626 , #740 , #744 ).
Идентификаторы
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Идентификаторы StableHLO похожи на идентификаторы во многих языках программирования, но имеют две особенности: 1) все идентификаторы имеют символы, которые различают разные типы идентификаторов, 2) идентификаторы значений могут быть полностью числовыми, чтобы упростить создание программ StableHLO.
Типы
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Типы StableHLO подразделяются на типы значений (которые также называются типами первого класса ), которые представляют значения StableHLO, и типы, не являющиеся значениями , которые описывают другие элементы программы. Типы StableHLO похожи на типы во многих языках программирования, при этом основной особенностью является предметно-ориентированный характер StableHLO, что приводит к некоторым необычным результатам (например, скалярные типы не являются типами значений).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Тензорные типы представляют собой тензоры, т.е. многомерные массивы. У них есть форма и тип элемента , где форма представляет собой неотрицательные или неизвестные размеры размеров в порядке возрастания соответствующих размеров (которые также называются осями ), пронумерованных от 0
до R-1
. Число измерений R
называется рангом . Например, tensor<2x3xf32>
— это тип тензора с формой 2x3
и типом элемента f32
. Он имеет два измерения (или, другими словами, две оси) — 0-е измерение и 1-е измерение — размеры которых равны 2 и 3. Его ранг равен 2.
Формы могут быть частично или полностью неизвестными (динамическими), например, tensor<?x2xf64>
частично неизвестен, а tensor<?x?xf64>
полностью неизвестен. Размеры динамических размеров обозначаются знаком ?
. Фигуры не могут быть лишены ранжирования.
В будущем мы планируем изучить расширение типов тензоров за пределы размеров размеров и типов элементов, например, включив макеты ( #629 ) и разреженность ( #1078 ).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Имя | Тип | Ограничения |
---|---|---|
storage_type | целочисленный тип | (С1-С3), (С8) |
storage_min | целочисленная константа | (С1), (С3), (С7) |
storage_max | целочисленная константа | (С2), (С3), (С7) |
expressed_type | тип с плавающей запятой | (С4) |
quantization_dimension | необязательная целочисленная константа | (С10-С12) |
scales | вариативное число констант с плавающей запятой | (С4-С6), (С9), (С10), (С13) |
zero_points | вариативное число целочисленных констант | (С7-С9) |
Типы квантованных элементов представляют собой целочисленные значения типа хранения в диапазоне от storage_min
до storage_max
(включительно), которые соответствуют значениям с плавающей запятой выраженного типа . Для данного целочисленного значения i
соответствующее значение f
с плавающей запятой может быть вычислено как f = (i - zero_point) * scale
, где scale
и zero_point
называются параметрами квантования . storage_min
и storage_max
не являются обязательными в грамматике, но имеют значения по умолчанию min_value(storage_type)
и max_value(storage_type)
соответственно. Типы квантованных элементов имеют следующие ограничения:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Если
is_empty(quantization_dimension)
, тоsize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
На данный момент QuantizationScale
представляет собой константу с плавающей запятой, но существует большой интерес к целочисленным шкалам, представленным множителями и сдвигами. Мы планируем изучить это в ближайшем будущем ( #1404 ).
Продолжается обсуждение семантики QuantizationZeroPoint
, включая тип, значения и может ли быть только одна или потенциально несколько нулевых точек в типе квантованного тензора. По результатам этого обсуждения спецификация нулевых точек может измениться в будущем ( #1405 ).
Другое продолжающееся обсуждение касается семантики QuantizationStorageMin
и QuantizationStorageMax
, чтобы определить, следует ли налагать какие-либо ограничения на эти значения и на значения квантованных тензоров ( #1406 ).
Наконец, мы планируем изучить представление неизвестных масштабов и нулевых точек аналогично тому, как мы планируем изучить представление неизвестных размеров измерений ( #1407 ).
Типы квантованных тензоров представляют собой тензоры с квантованными элементами. Эти тензоры точно такие же, как и обычные тензоры, за исключением того, что их элементы имеют типы квантованных элементов вместо обычных типов элементов.
В квантованных тензорах квантование может быть потензорным , то есть иметь один scale
и zero_point
для всего тензора, или может быть поосевым , то есть иметь несколько scales
и zero_points
, одну пару на срез определенного измерения quantization_dimension
. Более формально, в тензоре t
с поосевым квантованием существуют dim(t, quantization_dimension)
срезы quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
и т. д. Все элементы в i
-м срезе используют scales[i]
и zero_points[i]
в качестве параметров квантования. Типы квантованных тензоров имеют следующие ограничения:
- Для потензорного квантования:
- Никаких дополнительных ограничений.
- Для поосевого квантования:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Типы токенов представляют собой токены, т.е. непрозрачные значения, создаваемые и потребляемые некоторыми операциями. Токены используются для установления порядка выполнения операций, как описано в разделе «Выполнение» .
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Типы кортежей представляют собой кортежи, т. е. гетерогенные списки. Кортежи — это устаревшая функция, которая существует только для совместимости с HLO. В HLO кортежи используются для представления переменных входных и выходных данных. В StableHLO изначально поддерживаются переменные входные и выходные данные, и единственное использование кортежей в StableHLO — это всестороннее представление HLO ABI, где, например, T
, tuple<T>
и tuple<tuple<T>>
могут существенно отличаться в зависимости от конкретной реализации. . В будущем мы планируем внести изменения в HLO ABI, которые могут позволить нам удалить типы кортежей из StableHLO ( #598 ).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Типы элементов представляют собой элементы тензорных типов. В отличие от многих языков программирования, эти типы не являются первоклассными в StableHLO. Это означает, что программы StableHLO не могут напрямую представлять значения этих типов (в результате идиоматично представлять скалярные значения типа T
с помощью 0-мерных тензорных значений типа tensor<T>
).
- Тип Boolean представляет логические значения
true
иfalse
. - Целочисленные типы могут быть знаковыми (
si
) или беззнаковыми (ui
) и иметь одну из поддерживаемых разрядностей (2
,4
,8
,16
,32
или64
). Знаковые типыsiN
представляют целые значения от-2^(N-1)
до2^(N-1)-1
включительно, а беззнаковые типыuiN
представляют целочисленные значения от0
до2^N-1
включительно. - Типы с плавающей запятой могут быть одним из следующих:
-
f8E3M4
,f8E4M3
иf8E5M2
8-битные числа с плавающей запятой в соответствии с соглашениями IEEE-754. - Типы
f8E4M3FN
иf8E5M2
соответствующие соответственно кодировкамE4M3
иE5M2
формата FP8, описанным в разделе «Форматы FP8 для глубокого обучения» . - Типы
f8E4M3FNUZ
иf8E5M2FNUZ
соответствующие кодировкамE4M3
иE5M2
форматов FP8, описанных в 8-битных числовых форматах для глубоких нейронных сетей . - Тип
f8E4M3B11FNUZ
, соответствующий кодировкеE4M3
форматов FP8, описанных в разделе «Обучение и вывод гибридных 8-битных чисел с плавающей запятой (HFP8) для глубоких нейронных сетей» . - Тип
bf16
, соответствующий форматуbfloat16
, описанному в BFloat16: Секрет высокой производительности на Cloud TPU . - Типы
f16
,f32
иf64
соответствующие соответственноbinary16
(«половинная точность»),binary32
(«одинарная точность»)binary64
(«двойная точность»), описанным в стандарте IEEE 754 . - Тип
tf32
соответствует формату TensorFloat32 и имеет ограниченную поддержку в StableHLO. - Типы
f4E2M1FN
,f6E2M3FN
,f6E3M2FN
иf8E8M0FNU
MX (микромасштабирование) описаны в Спецификации форматов микромасштабирования OCP .
-
- Сложные типы представляют собой комплексные значения, которые имеют действительную и мнимую части одного и того же типа элемента . Поддерживаемые сложные типы:
complex<f32>
(обе части имеют типf32
) иcomplex<f64>
(обе части имеют типf64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Типы функций представляют как именованные, так и анонимные функции. У них есть типы ввода (список типов в левой части ->
) и типы вывода (список типов в правой части ->
). Во многих языках программирования типы функций являются первоклассными, но не в StableHLO.
StringType ::= 'string'
Тип String представляет собой последовательность байтов. В отличие от многих языков программирования, строковый тип не является первым классом в StableHLO и используется только для указания статических метаданных для элементов программы.
Операции
Операции StableHLO (которые также называются ops ) представляют собой закрытый набор операций высокого уровня в моделях машинного обучения. Как обсуждалось выше, синтаксис StableHLO во многом основан на MLIR, который не обязательно является наиболее эргономичной альтернативой, но, возможно, лучше всего подходит для цели StableHLO по созданию большей совместимости между платформами ML и компиляторами ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Операции StableHLO (которые также называются ops ) имеют имя, входы/выходы и подпись. Название состоит из stablehlo.
префикс и мнемоника , которая однозначно идентифицирует одну из поддерживаемых операций. Ниже приведен полный список всех поддерживаемых операций.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Операционные операторы потребляют входные данные и производят выходные данные . Входные данные подразделяются на входные значения (вычисляемые во время выполнения), входные функции (предоставляемые статически, поскольку в StableHLO функции не являются значениями первого класса) и входные атрибуты (также предоставляемые статически). Вид входных и выходных данных, потребляемых и производимых операцией, зависит от ее мнемоники. Например, операция add
потребляет 2 входных значения и производит 1 выходное значение. Для сравнения, операция select_and_scatter
использует 3 входных значения, 2 входные функции и 3 входных атрибута.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Функции ввода (которые также называются анонимными функциями ) очень похожи на именованные функции, за исключением того, что: 1) они не имеют идентификатора (отсюда и название «анонимные»), 2) они не объявляют типы вывода (типы вывода выводится из операции return
внутри функции).
Синтаксис функций ввода включает в себя неиспользуемую в настоящее время часть (см. раздел « Unused
» выше), которая предназначена для совместимости с MLIR. В MLIR существует более общая концепция «регионов», которая может состоять из нескольких «блоков» операций, соединенных вместе посредством операций перехода. Эти блоки имеют идентификаторы, соответствующие Unused
продукции, чтобы их можно было отличить друг от друга. В StableHLO нет прыжковых операций, поэтому соответствующая часть синтаксиса MLIR не используется (но все еще существует).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Входные атрибуты имеют имя и значение, которое является одной из поддерживаемых констант. Они являются основным способом указания статических метаданных для элементов программы. Например, операция concatenate
использует dimension
атрибута, чтобы указать измерение, по которому объединяются его входные значения. Аналогично, операция slice
использует несколько атрибутов, таких как start_indices
и limit_indices
для указания границ, которые используются для среза входного значения.
На данный момент существующие программы StableHLO иногда содержат атрибуты, не описанные в этом документе. В будущем мы планируем либо включить эти атрибуты в опсет StableHLO, либо запретить их появление в программах StableHLO. А пока вот список этих атрибутов:
-
layout
( #629 ). -
mhlo.frontend_attributes
( #628 ). -
mhlo.sharding
( #619 ). -
output_operand_aliases
( #740 ). - Метаданные местоположения ( #594 ).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Сигнатура операции состоит из типов всех входных значений (список типов в левой части ->
) и типов всех выходных значений (список типов в правой части ->
). Строго говоря, входные типы избыточны, и выходные типы также почти всегда избыточны (поскольку для большинства операций StableHLO типы выходных данных могут быть выведены из входных данных). Тем не менее, подпись op намеренно является частью синтаксиса StableHLO для совместимости с MLIR.
Ниже приведен пример операции, мнемоника которой — select_and_scatter
. Он использует 3 входных значения ( %operand
, %source
и %init_value
), 2 входные функции и 3 входных атрибута ( window_dimensions
, window_strides
и padding
). Обратите внимание, что подпись операции включает только типы ее входных значений (но не типы входных функций и атрибутов, которые предоставляются в строке).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Константы
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Константы StableHLO имеют литерал и тип, которые вместе представляют значение StableHLO. Обычно тип является частью синтаксиса константы, за исключением случаев, когда он однозначен (например, логическая константа однозначно имеет тип i1
, тогда как целочисленная константа может иметь несколько возможных типов).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Булевы константы представляют логические значения true
и false
. Булевы константы имеют тип i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Целочисленные константы представляют целочисленные значения посредством строк, в которых используется десятичная или шестнадцатеричная запись. Другие системы счисления, например двоичная или восьмеричная, не поддерживаются. Целочисленные константы имеют следующие ограничения:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Константы с плавающей запятой представляют значения с плавающей запятой в виде строк, в которых используется десятичная или экспоненциальная запись. Кроме того, шестнадцатеричная запись может использоваться для непосредственного указания базовых битов в формате с плавающей запятой соответствующего типа. Константы с плавающей запятой имеют следующие ограничения:
- (C1) Если используется нешестнадцатеричная система записи,
is_wellformed(float_literal, float_type)
. - (C2) Если используется шестнадцатеричная запись,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Комплексные константы представляют комплексные значения с использованием списков вещественной части (идет первой) и мнимой части (идет второй). Например, (1.0, 0.0) : complex<f32>
представляет 1.0 + 0.0i
, а (0.0, 1.0) : complex<f32>
представляет 0.0 + 1.0i
. Порядок, в котором эти части затем сохраняются в памяти, определяется реализацией. Комплексные константы имеют следующие ограничения:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Тензорные константы представляют значения тензора с использованием вложенных списков, заданных с помощью нотации NumPy. Например, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
представляет значение тензора со следующим сопоставлением индексов с элементами: {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
, {1, 2} => 6
. Порядок, в котором эти элементы затем сохраняются в памяти, определяется реализацией. Тензорные константы имеют следующие ограничения:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, где:-
has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
. -
has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
-
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, где:-
has_shape(element_literal: Syntax, []) = true
. -
has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
. - в противном случае
false
.
-
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Квантованные тензорные константы представляют квантованные тензорные значения с использованием тех же обозначений, что и тензорные константы, с элементами, заданными как константы их типа хранения. Квантованные тензорные константы имеют следующие ограничения:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Строковые литералы состоят из байтов, заданных с помощью символов ASCII и escape-последовательностей. Они не зависят от кодировки, поэтому интерпретация этих байтов определяется реализацией. Строковые литералы имеют тип string
.
Операции
пресс
Семантика
Выполняет поэлементную операцию abs над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для целых чисел со знаком: целочисленный модуль.
- Для поплавков:
abs
из IEEE-754. - Для комплексных чисел: комплексный модуль.
- Для квантованных типов:
dequantize_op_quantize(abs, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целого числа со знаком, с плавающей запятой или комплексного типа или по-тензорный квантованный тензор | (С1-С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целого числа со знаком или типа с плавающей запятой или потензорный квантованный тензор | (С1-С2) |
Ограничения
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
определяется как:-
complex_element_type(element_type(operand))
ifis_complex(operand)
. -
baseline_element_type(operand)
в противном случае.
-
Примеры
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
добавлять
Семантика
Выполняет поэлементное сложение двух тензоров lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое ИЛИ.
- Для целых чисел: сложение целых чисел.
- Для поплавков:
addition
из IEEE-754. - Для комплексных чисел: комплексное сложение.
- Для квантованных типов:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор или квантованный тензор | (С1-С6) |
(И2) | rhs | тензор или квантованный тензор | (С1-С5), (С7) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1-С7) |
Ограничения
- Если в операции используются неквантованные тензоры:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Если в операции используются квантованные тензоры:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Если
is_per_axis_quantized(lhs)
, тоquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Если
is_per_axis_quantized(rhs)
, тоquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Примеры
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
после всего
Семантика
Гарантирует, что операции, производящие inputs
, выполняются до выполнения любых операций, зависящих от result
. Выполнение этой операции ничего не делает, она существует только для того, чтобы установить зависимости данных от result
до inputs
.
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | inputs | вариативное число token |
Выходы
Имя | Тип |
---|---|
result | token |
Примеры
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
все_собрать
Семантика
Внутри каждой группы процессов в сетке процессов StableHLO объединяет значения тензоров operands
каждого процесса по all_gather_dim
и создает тензоры results
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
еслиchannel_id <= 0 and use_global_device_ids = false
. -
cross_replica_and_partition(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = false
. -
flattened_ids(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = true
.
Затем внутри каждой process_group
:
-
operands...@receiver = [operand@sender for sender in process_group]
для всехreceiver
вprocess_group
. -
results...@process = concatenate(operands...@process, all_gather_dim)
для всехprocess
вprocess_group
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operands | вариативное число тензоров или потензорные квантованные тензоры | (С1), (С6) |
(И2) | all_gather_dim | константа типа si64 | (С1), (С6) |
(И3) | replica_groups | 2-мерная тензорная константа типа si64 | (С2-С4) |
(И4) | channel_id | константа типа si64 | (С5) |
(И5) | use_global_device_ids | константа типа i1 | (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С6) |
Ограничения
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_replicas
если используетсяcross_replica_and_partition
. -
num_processes
если используетсяflattened_ids
.
-
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Если
use_global_device_ids = true
, тоchannel_id > 0
. - (C6)
type(results...) = type(operands...)
за исключением:-
dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
-
Примеры
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Семантика
В каждой группе процессов в сетке процессов StableHLO применяет computation
функции сокращения к значениям тензоров operands
каждого процесса и создает тензоры results
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
еслиchannel_id <= 0 and use_global_device_ids = false
. -
cross_replica_and_partition(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = false
. -
flattened_ids(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = true
.
Затем внутри каждой process_group
:
-
results...@process[result_index] = exec(schedule)
для некоторогоschedule
двоичного дерева, где:-
exec(node)
=computation(exec(node.left), exec(node.right))
. -
exec(leaf)
=leaf.value
.
-
-
schedule
— это двоичное дерево, определяемое реализацией, порядок обхода которого равенto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operands | вариативное число тензоров или потензорные квантованные тензоры | (С5), (С6) |
(И2) | replica_groups | вариатическое число одномерных тензорных констант типа si64 | (С1-С3) |
(И3) | channel_id | константа типа si64 | (С4) |
(И4) | use_global_device_ids | константа типа i1 | (С4) |
(И5) | computation | функция | (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С6-С7) |
Ограничения
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_replicas
если используетсяcross_replica_and_partition
. -
num_processes
если используетсяflattened_ids
.
-
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Если
use_global_device_ids = true
, тоchannel_id > 0
. - (C5)
computation
имеет тип(tensor<E>, tensor<E>) -> (tensor<E>)
гдеis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Примеры
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Семантика
Внутри каждой группы процессов в сетке процессов StableHLO значения тензоров operands
разделяются по split_dimension
на части, распределяются разделенные части между процессами, объединяются разбросанные части по concat_dimension
и выдаются тензоры results
. Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
, еслиchannel_id <= 0
. -
cross_partition(replica_groups)
еслиchannel_id > 0
.
Затем внутри каждой process_group
:
-
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
для всехsender
вprocess_group
. -
scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
, гдеreceiver_index = process_group.index(receiver)
. -
results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operands | вариативное число тензоров или потензорные квантованные тензоры | (С1-С3), (С9) |
(И2) | split_dimension | константа типа si64 | (С1), (С2), (С9) |
(И3) | concat_dimension | константа типа si64 | (С3), (С9) |
(И4) | split_count | константа типа si64 | (С2), (С4), (С8), (С9) |
(И5) | replica_groups | 2-мерная тензорная константа типа si64 | (С5-С8) |
(И6) | channel_id | константа типа si64 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С9) |
Ограничения
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_partitions
если используетсяcross_partition
.
-
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
кроме случая, когдаsplit_dimension != concat_dimension
:-
dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
. -
dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
-
Примеры
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
и
Семантика
Выполняет поэлементное «И» двух тензоров lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое И.
- Для целых чисел: побитовое И.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор логического или целочисленного типа | (С1) |
(И2) | rhs | тензор логического или целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор логического или целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
Атан2
Семантика
Выполняет поэлементную операцию atan2 над тензорами lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
atan2
из IEEE-754. - Для комплексных чисел: комплекс atan2.
- Для квантованных типов:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
пакет_норм_град
Семантика
Вычисляет градиенты нескольких входных данных batch_norm_training
с обратным распространением ошибки от grad_output
и создает тензоры grad_operand
, grad_scale
и grad_offset
. Более формально эту операцию можно выразить как декомпозицию существующих операций StableHLO с использованием синтаксиса Python следующим образом:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Для квантованных типов выполняет dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1-С3), (С5) |
(И2) | scale | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4), (С5) |
(И3) | mean | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
(И4) | variance | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
(И5) | grad_output | тензор типа с плавающей запятой или потензорный квантованный тензор | (С2), (С3) |
(И6) | epsilon | константа типа f32 | |
(I7) | feature_index | константа типа si64 | (С1), (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
grad_operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С2), (С3) |
grad_scale | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
grad_offset | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
Ограничения
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
иgrad_offset
имеют одинаковыйbaseline_element_type
. - (C3)
operand
,grad_output
иgrad_operand
имеют одинаковую форму. - (C4)
scale
,mean
,variance
,grad_scale
иgrad_offset
имеют одинаковую форму. - (C5)
size(scale) = dim(operand, feature_index)
.
Примеры
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
пакетная норма_inference
Семантика
Нормализует тензор operand
по всем измерениям, кроме измерения feature_index
, и создает result
тензор. Более формально эту операцию можно выразить как декомпозицию существующих операций StableHLO с использованием синтаксиса Python следующим образом:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Для квантованных типов выполняет dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1-С7) |
(И2) | scale | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С3) |
(И3) | offset | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
(И4) | mean | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С5) |
(И5) | variance | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С6) |
(И6) | epsilon | константа типа f32 | |
(I7) | feature_index | константа типа si64 | (С1), (С3-С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С2), (С7) |
Ограничения
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
иresult
имеют один и тот жеbaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
Batch_norm_training
Семантика
Вычисляет среднее значение и дисперсию по всем измерениям, за исключением измерения feature_index
, и нормализует тензор operand
, создавая output
, тензоры batch_mean
и batch_var
. Более формально эту операцию можно выразить как декомпозицию существующих операций StableHLO с использованием синтаксиса Python следующим образом:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Для квантованных типов выполняет dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
(И2) | scale | Одномерный тензор с плавающей запятой или потензорный квантователь | (С2), (С3) |
(И3) | offset | Одномерный тензор с плавающей запятой или потензорный квантователь | (С2), (С4) |
(И4) | epsilon | константа типа f32 | (С1), (С3-С6) |
(И5) | feature_index | константа типа si64 | (С1), (С3-С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output | тензор типа с плавающей запятой или потензорный квантованный тензор | (С7) |
batch_mean | Одномерный тензор с плавающей запятой или потензорный квантователь | (С2), (С5) |
batch_var | Одномерный тензор с плавающей запятой или потензорный квантователь | (С2), (С6) |
Ограничения
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
иoutput
имеют один и тот жеbaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Примеры
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Семантика
Выполняет операцию побитового преобразования тензора operand
и создает тензор result
, в котором биты всего тензора operand
переинтерпретируются с использованием типа тензора result
.
Более формально, учитывая E = element_type(operand)
, E' = element_type(result)
и R = rank(operand)
:
- Если
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Если
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Если
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
возвращает представление данного значения в памяти, и его поведение определяется реализацией, поскольку точное представление тензоров определяется реализацией, а точное представление типов элементов также определяется реализацией.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1-С2) |
Ограничения
- (C1) Учитывая
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
иR = rank(operand)
:- Если
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Если
num_bits(E') < num_bits(E)
: -
rank(result) = R + 1
. -
dim(result, i) = dim(operand, i)
для всех0 <= i < R
. -
dim(result, R) * num_bits(E') = num_bits(E)
. - Если
num_bits(E') > num_bits(E)
: -
rank(result) = R - 1
. -
dim(result, i) = dim(operand, i)
для всех0 <= i < R
. -
dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Если
- (C2) Если
is_complex(operand) or is_complex(result)
, тоis_complex(operand) and is_complex(result)
.
Примеры
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
Broadcast_in_dim
Семантика
Расширяет размеры и/или ранг входного тензора путем дублирования данных в тензоре operand
и создает result
тензор. Более формально, result[result_index] = operand[operand_index]
где для всех d
в axes(operand)
:
-
operand_index[d] = 0
еслиdim(operand, d) = 1
. -
operand_index[d] = result_index[broadcast_dimensions[d]]
в противном случае.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С2), (С5-С6) |
(И2) | broadcast_dimensions | 1-мерная тензорная константа типа si64 | (С2-С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1), (С3), (С5-С6) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
,scales(operand)
иzero_points(operand)
могут отличаться отquantization_dimension(result)
,scales(result)
иzero_points(result)
соответственно, в противном случае.
-
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Для всех
d
вaxes(operand)
:-
dim(operand, d) = 1
или -
dim(operand, d) = dim(result, broadcast_dimensions[d])
.
-
- (C6) Если
is_per_axis_quantized(result)
:-
quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
. - Если
dim(operand, quantization_dimension(operand)) = 1
, тоscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
-
Примеры
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
случай
Семантика
Производит результат выполнения ровно одной функции из branches
в зависимости от значения index
. Более формально, result = selected_branch()
где:
-
selected_branch = branches[index]
if0 <= index < size(branches)
. -
selected_branch = branches[-1]
в противном случае.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | index | 0-мерный тензор типа si32 | |
(И2) | branches | вариативное число функций | (С1-С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С4) |
Ограничения
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Примеры
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CBRT
Семантика
Выполняет поэлементную операцию кубического корня над тензором operand
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
rootn(x, 3)
из IEEE-754. - Для комплексных чисел: комплексный кубический корень.
- Для квантованных типов:
dequantize_op_quantize(cbrt, operand, type(result))
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
клетка
Семантика
Выполняет поэлементную ячейку тензора operand
и создает тензор result
. Реализует операцию roundToIntegralTowardPositive
из спецификации IEEE-754. Для квантованных типов выполняется dequantize_op_quantize(ceil, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
холецкий
Семантика
Вычисляет разложение Холецкого пакета матриц.
Более формально, для всех i
в index_space(result)
, result[i0, ..., iR-3, :, :]
является разложением Холецкого a[i0, ..., iR-3, :, :]
, в виде нижне-треугольной (если lower
имеет true
) или верхнетреугольной (если lower
является false
) матрицы. Выходные значения в противоположном треугольнике, т.е. строгом верхнем треугольнике или строгом нижнем треугольнике соответственно, определяются реализацией.
Если существует i
, где входная матрица не является эрмитовой положительно определенной матрицей, то поведение не определено.
Для квантованных типов выполняется dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | a | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1-С3) |
(И2) | lower | 0-мерная тензорная константа типа i1 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Примеры
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
зажим
Семантика
Зажимает каждый элемент тензора operand
между минимальным и максимальным значением и создает result
тензор. Более формально, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element)
, где min_element = rank(min) = 0 ? min[] : min[result_index]
, max_element = rank(max) = 0 ? max[] : max[result_index]
. Для квантованных типов выполняется dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Наложение порядка на комплексные числа требует удивительной семантики, поэтому в будущем мы планируем удалить поддержку комплексных чисел для этой операции ( #560 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | min | тензорный или потензорный квантованный тензор | (С1), (С3) |
(И2) | operand | тензорный или потензорный квантованный тензор | (С1-С4) |
(И3) | max | тензорный или потензорный квантованный тензор | (С2), (С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С4) |
Ограничения
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
коллективное_вещание
Семантика
В каждой группе процессов в сетке процессов StableHLO отправьте значение тензора operand
из исходного процесса в целевые процессы и создайте тензор result
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
, еслиchannel_id <= 0
. -
cross_partition(replica_groups)
еслиchannel_id > 0
.
После этого result@process
определяется следующим образом:
-
operand@process_groups[i, 0]
если существуетi
такой, что процесс находится вprocess_groupsprocess_groups[i]
. -
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
иначе.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С3) |
(И2) | replica_groups | вариатическое число одномерных тензорных констант типа si64 | (С1), (С2) |
(И3) | channel_id | константа типа si64 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С3) |
Ограничения
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, гдеN
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_partitions
если используетсяcross_partition
.
-
- (C3)
type(result) = type(operand)
.
Примеры
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
коллективный_пермуте
Семантика
Внутри каждой группы процессов в сетке процессов StableHLO отправляет значение тензора operand
из исходного процесса в целевой процесс и создает тензор result
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(source_target_pairs)
, еслиchannel_id <= 0
. -
cross_partition(source_target_pairs)
, еслиchannel_id > 0
.
После этого result@process
определяется следующим образом:
-
operand@process_groups[i, 0]
, если существуетi
такой,process_groups[i, 1] = process
. -
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
иначе.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С5) |
(И2) | source_target_pairs | 2-мерная тензорная константа типа si64 | (С1-С4) |
(И3) | channel_id | константа типа si64 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, гдеN
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_partitions
если используетсяcross_partition
.
-
- (C5)
type(result) = type(operand)
.
Примеры
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
сравнивать
Семантика
Выполняет поэлементное сравнение тензоров lhs
и rhs
в соответствии с comparison_direction
и compare_type
и создает result
тензор.
Значения comparison_direction
и compare_type
имеют следующую семантику:
Для логических и целочисленных типов элементов:
-
EQ
:lhs = rhs
. -
NE
:lhs != rhs
. -
GE
:lhs >= rhs
. -
GT
:lhs > rhs
. -
LE
:lhs <= rhs
. -
LT
:lhs < rhs
.
Для типов элементов с плавающей запятой с compare_type = FLOAT
оператор реализует следующие операции IEEE-754:
-
EQ
:compareQuietEqual
. -
NE
:compareQuietNotEqual
. -
GE
:compareQuietGreaterEqual
. -
GT
:compareQuietGreater
. -
LE
:compareQuietLessEqual
. -
LT
:compareQuietLess
.
Для типов элементов с плавающей запятой с compare_type = TOTALORDER
оператор использует комбинацию операций totalOrder
и compareQuietEqual
из IEEE-754.
Для сложных типов элементов лексикографическое сравнение пар (real, imag)
выполняется с использованием предоставленных comparison_direction
и compare_type
. Наложение порядка на комплексные числа требует удивительной семантики, поэтому в будущем мы планируем удалить поддержку комплексных чисел, когда comparison_direction
— GE
, GT
, LE
или LT
( #560 ).
Для квантованных типов. выполняет dequantize_compare(lhs, rhs, comparison_direction)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1-С3) |
(И2) | rhs | тензорный или потензорный квантованный тензор | (С1-С2) |
(И3) | comparison_direction | перечисление EQ , NE , GE , GT , LE и LT | |
(И4) | compare_type | перечисление FLOAT , TOTALORDER , SIGNED и UNSIGNED | (С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор логического типа | (С2) |
Ограничения
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
определяется как:-
SIGNED
, еслиis_signed_integer(element_type(lhs))
. -
UNSIGNED
, еслиis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
. -
FLOAT
илиTOTALORDER
, еслиis_float(element_type(lhs))
. -
FLOAT
, еслиis_complex(element_type(lhs))
.
-
Примеры
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
сложный
Семантика
Выполняет поэлементное преобразование в комплексное значение из пары действительных и мнимых значений, lhs
и rhs
, и создает result
тензор.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор типа f32 или f64 | (С1-С3) |
(И2) | rhs | тензор типа f32 или f64 | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор комплексного типа | (С2), (С3) |
Ограничения
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
имеет типcomplex<E>
, гдеE = element_type(lhs)
.
Примеры
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
композитный
Семантика
Инкапсулирует операцию, состоящую из других операций StableHLO, принимающую inputs
и composite_attributes
и выдающую results
. Семантика операции реализуется атрибутом decomposition
. composite
операцию можно заменить ее декомпозицией без изменения семантики программы. В тех случаях, когда встраивание декомпозиции не обеспечивает одинаковую семантику операций, предпочтительнее использовать custom_call
.
Поле version
(по умолчанию 0
) используется для обозначения изменения семантики композиции.
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | inputs | вариативное число значений |
(И2) | name | константа типа string |
(И3) | composite_attributes | словарь атрибутов |
(И4) | decomposition | константа типа string |
(И5) | version | константа типа si32 |
Выходы
Имя | Тип |
---|---|
results | вариативное число значений |
Ограничения
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Примеры
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
объединять
Семантика
Объединяет inputs
по dimension
в том же порядке, что и заданные аргументы, и создает result
тензор. Более формально, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, где:
-
id = d0 + ... + dk-1 + kd
. -
d
равноdimension
, аd0
,... — этоd
-е размерностьinputs
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (С1-С6) |
(И2) | dimension | константа типа si64 | (С2), (С4), (С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С5-С6) |
Ограничения
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
за исключениемdim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
за исключением:-
dim(result, dimension) = dim(inputs[0], dimension) + ...
.
-
Примеры
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
постоянный
Семантика
Создает output
тензор из постоянного value
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | value | постоянный | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output | тензор или квантованный тензор | (С1) |
Ограничения
- (C1)
type(value) = type(output)
.
Примеры
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
конвертировать
Семантика
Выполняет поэлементное преобразование одного типа элемента в другой в тензоре operand
и создает тензор result
.
Для преобразования логического значения в любой поддерживаемый тип значение false
преобразуется в ноль, а значение true
преобразуется в единицу. Для любого поддерживаемого преобразования типа в логическое значение нулевое значение преобразуется в false
, а ненулевые значения преобразуются в true
. Ниже показано, как это работает для сложных типов.
Для преобразований, включающих целое число в целое число , целое число в число с плавающей запятой или число с плавающей запятой в число с плавающей запятой , если исходное значение может быть точно представлено в целевом типе, значение результата будет таким же точным представлением. В противном случае поведение будет определено позднее ( #180 ).
Для преобразований с плавающей запятой в целое число дробная часть усекается. Если усеченное значение не может быть представлено в целевом типе, поведение будет определено позднее ( #180 ).
Преобразование комплексного значения в комплексное происходит так же, как и преобразование чисел с плавающей запятой при преобразовании действительных и мнимых частей.
Для преобразований «комплексный тип в любой другой» и «любой другой тип в комплекс» исходное мнимое значение игнорируется или конечное мнимое значение обнуляется соответственно. Преобразование вещественной части следует за преобразованиями с плавающей запятой.
В принципе, эта операция могла бы выражать деквантование (преобразование квантованных тензоров в регулярные тензоры), квантование (преобразование регулярных тензоров в квантованные тензоры) и реквантование (преобразование между квантованными тензорами), но на данный момент для этого у нас есть специальные операции — uniform_dequantize
для первый вариант использования и uniform_quantize
для второго и третьего вариантов использования. В будущем эти две операции могут быть объединены в convert
( #1576 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор | (С1) |
Ограничения
- (C1)
shape(operand) = shape(result)
.
Примеры
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
свертка
Семантика
Вычисляет скалярное произведение между окнами lhs
и срезами rhs
и выдает result
. На следующей диаграмме показано, как элементы result
вычисляются на основе lhs
и rhs
на конкретном примере.
Более формально, рассмотрим следующее переформулирование входных данных в терминах lhs
чтобы иметь возможность выражать окна lhs
:
-
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
. -
lhs_window_strides = lhs_shape(1, window_strides, 1)
. -
lhs_padding = lhs_shape([0, 0], padding, [0, 0])
. -
lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
. -
lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
В этом рефрейминге используются следующие вспомогательные функции:
-
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
. -
result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
. -
permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
гдеj[d] = i[permutation[d]]
.
Если feature_group_count = 1
и batch_group_count = 1
, то для всех output_spatial_index
в index_space(dim(result, output_spatial_dimensions...))
, result[result_shape(:, output_spatial_index, :)] = dot_product
где:
-
padding_value = constant(0, element_type(lhs))
. -
padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
. -
lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
. -
lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
. -
reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Эта функция, похоже, не используется, поэтому в будущем мы планируем ее удалить ( #1181 ). -
dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Если feature_group_count > 1
:
-
lhses = split(lhs, feature_group_count, input_feature_dimension)
. -
rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
. -
results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
. -
result = concatenate(results, output_feature_dimension)
.
Если batch_group_count > 1
:
-
lhses = split(lhs, batch_group_count, input_batch_dimension)
. -
rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
. -
results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
. -
result = concatenate(results, output_feature_dimension)
.
Для квантованных типов выполняет dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result))
.
Для гибридных квантованных типов выполняет hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1), (С10-С11), (С14) (С25), (С27-С28), (С31-С32), (С34) |
(И2) | rhs | тензор или квантованный тензор | (С1), (С14-С16), (С25), (С27-С29), (С31-С34) |
(И3) | window_strides | 1-мерная тензорная константа типа si64 | (С2-С3), (С25) |
(И4) | padding | 2-мерная тензорная константа типа si64 | (С4), (С25) |
(И5) | lhs_dilation | 1-мерная тензорная константа типа si64 | (С5-С6), (С25) |
(И6) | rhs_dilation | 1-мерная тензорная константа типа si64 | (С7-С8), (С25) |
(I7) | window_reversal | 1-мерная тензорная константа типа i1 | (С9) |
(И8) | input_batch_dimension | константа типа si64 | (С10), (С13), (С25) |
(I9) | input_feature_dimension | константа типа si64 | (С11), (С13-С14) |
(I10) | input_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С12), (С13), (С25) |
(I11) | kernel_input_feature_dimension | константа типа si64 | (С14), (С18) |
(I12) | kernel_output_feature_dimension | константа типа si64 | (С15-С16), (С18), (С25), (С29) |
(I13) | kernel_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С17-С18), (С25) |
(I14) | output_batch_dimension | константа типа si64 | (С20), (С25) |
(I15) | output_feature_dimension | константа типа si64 | (С20), (С25), (С30) |
(I16) | output_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С19-С20), (С25) |
(I17) | feature_group_count | константа типа si64 | (С11), (С14), (С16), (С21), (С23) |
(I18) | batch_group_count | константа типа si64 | (С10), (С15), (С22), (С23), (С25) |
(I19) | precision_config | вариативное количество перечислений DEFAULT , HIGH и HIGHEST | (С24) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С25-С28), (С30), (С32-34) |
Ограничения
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Учитывая
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:-
is_unique(input_dimensions)
. -
0 <= input_dimensions < N
.
-
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Учитывая
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:-
is_unique(kernel_dimensions)
. -
0 <= kernel_dimensions < N
.
-
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Учитывая
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:-
is_unique(output_dimensions)
. -
0 <= output_dimensions < N
.
-
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
определяется как:-
dim(lhs, input_batch_dimension) / batch_group_count
еслиresult_dim = output_batch_dimension
. -
dim(rhs, kernel_output_feature_dimension)
еслиresult_dim = output_feature_dimension
. -
num_windows
в противном случае, где: -
output_spatial_dimensions[spatial_dim] = result_dim
. -
lhs_dim = input_spatial_dimensions[spatial_dim]
. -
rhs_dim = kernel_spatial_dimensions[spatial_dim]
. -
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
. -
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
. -
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
. -
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
. -
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
-
- (C26)
rank(result) = N
. - Если в операции используются неквантованные тензоры:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Если в операции используются квантованные тензоры:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Если
is_per_axis_quantized(rhs)
, тоquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Если
is_per_axis_quantized(result)
, тоquantization_dimension(result) = output_feature_dimension
. - Если
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Если
is_per_tensor_quantized(rhs)
, тоis_per_tensor_quantized(result)
. - Если
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Примеры
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
косинус
Семантика
Выполняет поэлементную операцию косинуса над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
cos
из IEEE-754. - Для комплексных чисел: комплексный косинус.
- Для квантованных типов:
dequantize_op_quantize(cosine, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Семантика
Выполняет поэлементный подсчет количества ведущих нулевых битов в тензоре operand
и создает result
тензор.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(operand) = type(result)
.
Примеры
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Семантика
Инкапсулирует определяемую реализацией операцию call_target_name
, которая принимает inputs
и called_computations
и выдает results
. has_side_effect
, backend_config
и api_version
могут использоваться для предоставления дополнительных метаданных, определяемых реализацией.
На данный момент эта операция содержит довольно неорганизованный набор метаданных, отражающий органическое развитие аналогичной операции в компиляторе XLA. В будущем мы планируем унифицировать эти метаданные ( #741 ).
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | inputs | вариативное число значений |
(И2) | call_target_name | константа типа string |
(И3) | has_side_effect | константа типа i1 |
(И4) | backend_config | константа типа string или словаря атрибутов |
(И5) | api_version | константа типа si32 |
(И6) | called_computations | вариативное число констант типа string |
Выходы
Имя | Тип |
---|---|
results | вариативное число значений |
Примеры
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
разделять
Семантика
Выполняет поэлементное деление тензоров делимого lhs
и rhs
делителя и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для целых чисел: целочисленное деление, которое дает алгебраическое частное с отбрасыванием любой дробной части.
- Для поплавков:
division
от IEEE-754. - Для комплексных чисел: комплексное деление.
- Для квантованных типов:
-
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Семантика
Вычисляет скалярное произведение между срезами lhs
и rhs
и создает тензор result
.
Более формально, result[result_index] = dot_product
, где:
-
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
. -
rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
. -
result_batching_index + result_lhs_index + result_rhs_index = result_index
гдеsize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
иsize(result_rhs_index) = size(rhs_result_dimensions)
. -
transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
. -
transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
. -
reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
. -
transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
. -
transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
. -
reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
. -
dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Для квантованных типов выполняет dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Для гибридных квантованных типов выполняется hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
управляет компромиссом между скоростью и точностью вычислений на серверных компонентах ускорителя. Это может быть одно из следующих (на данный момент семантика этих значений перечисления недостаточно определена, но мы планируем решить эту проблему в #755 ):
-
DEFAULT
: Самый быстрый расчет, но наименее точное приближение к исходному числу. -
HIGH
: более медленный расчет, но более точное приближение к исходному числу. -
HIGHEST
: Самый медленный расчет, но наиболее точное приближение к исходному числу.
DotAlgorithm
определяет основные свойства алгоритма, используемого для реализации операции с точкой, что также определяет точность. Если поля атрибутов алгоритма установлены, то для precision_config
должно быть DEFAULT
. DotAlgorithms
не имеют значения по умолчанию, поскольку параметры по умолчанию определяются реализацией. Таким образом, для всех полей алгоритма с точкой можно установить значение None
, чтобы указать алгоритм с пустой точкой, который вместо этого будет использовать значение precision_config
.
Поля DotAlgorithm
включают:
-
lhs_precision_type
иrhs_precision_type
— точность, до которой округляются левая и правая части операции. Типы точности не зависят от типов хранения входных и выходных данных. -
accumulation_type
точность, используемая для накопления. -
lhs_component_count
,rhs_component_count
иnum_primitive_operations
применяются, когда мы выполняем алгоритм, который разлагает LHS и/или RHS на несколько компонентов и выполняет несколько «примитивных» точечных операций над этими значениями — обычно для эмуляции более высокой точности (например, использование типа данных искусственного интеллекта bfloat16). Для вычислений более высокой точности : bf16_6x tf32_3x и т. д.). Для алгоритмов без декомпозиции эти значения должны быть установлены равными1
. -
allow_imprecise_accumulation
, чтобы указать, разрешено ли накопление с более низкой точностью для некоторых шагов (например,CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Пример атрибутов DotAlgorithm
:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Реализации должны решить, какие комбинации поддерживаются. В общем, не гарантируется, что каждый алгоритм поддерживается на каждом типе ускорителя потребителем StableHLO. Если данный алгоритм не поддерживается, следует выдать ошибку, а не возвращаться к альтернативе. Проверка StableHLO обеспечит максимальную проверку, предотвращая использование алгоритмов, о поддержке которых неизвестно ни на каком оборудовании.
См. xla_data.proto > Algorithm
для получения информации о некоторых поддерживаемых значениях алгоритма. Заявка № 2483 описывает план создания централизованного документа по поддерживаемым алгоритмам с помощью серверной части.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С5-С6), (С9-С10), (С12-С14), (С17-С18), (С20) |
(И2) | rhs | тензор или квантованный тензор | (С7-С10), (С12-С20) |
(И3) | lhs_batching_dimensions | 1-мерная тензорная константа типа si64 | (С1), (С3), (С5), (С9), (С12) |
(И4) | rhs_batching_dimensions | 1-мерная тензорная константа типа si64 | (С1), (С4), (С7), (С9) |
(И5) | lhs_contracting_dimensions | 1-мерная тензорная константа типа si64 | (С2), (С3), (С6), (С10) |
(И6) | rhs_contracting_dimensions | 1-мерная тензорная константа типа si64 | (С2), (С4), (С8), (С10), (С16) |
(I7) | precision_config | вариативное количество перечислений DEFAULT , HIGH и HIGHEST | (С11), (С21) |
(И8) | lhs_precision_type | FloatType или TensorFloat32 | (С21) |
(I9) | rhs_precision_type | FloatType или TensorFloat32 | (С21) |
(I10) | accumulation_type | FloatType или TensorFloat32 | (С21) |
(I11) | lhs_component_count | константа типа si32 | (С21), (С22) |
(I12) | rhs_component_count | константа типа si32 | (С21), (С23) |
(I13) | num_primitive_operations | константа типа si32 | (С21), (С24) |
(I14) | allow_imprecise_accumulation | константа типа bool | (С21) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С12), (С14), (С18-С20) |
Ограничения
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Если в операции используются неквантованные тензоры:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Если в операции используются квантованные тензоры:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Если
is_per_axis_quantized(rhs)
, тоquantization_dimension(rhs)
не находится вrhs_contracting_dimensions
. - Если
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Если
is_per_tensor_quantized(rhs)
, тоis_per_tensor_quantized(result)
. - Если
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Если
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Примеры
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
Dynamic_broadcast_in_dim
Семантика
Эта операция функционально идентична операции Broadcast_in_dim , но форма результата задается динамически через output_dimensions
.
Операция также принимает необязательные known_expanding_dimensions
, known_nonexpanding_dimensions
для выражения статических знаний о поведении измерений при расширении. Если не указано иное, предполагается, что все размеры могут расширяться.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С2), (С5-С6), (С9) |
(И2) | output_dimensions | 1-мерный тензор целочисленного типа | (С7) |
(И3) | broadcast_dimensions | Одномерный постоянный тензор целочисленного типа | (С2-С6) |
(И4) | known_expanding_dimensions | Одномерный постоянный тензор целочисленного типа | (С8-С9) |
(И5) | known_nonexpanding_dimensions | Одномерный постоянный тензор целочисленного типа | (С8-С9) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1), (С3), (С5-С7) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
,scales(operand)
иzero_points(operand)
могут отличаться отquantization_dimension(result)
,scales(result)
иzero_points(result)
соответственно, в противном случае.
-
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Для всех
d
вaxes(operand)
:-
dim(operand, d) = 1
или -
dim(operand, d) = dim(result, broadcast_dimensions[d])
.
-
- (C6) Если
is_per_axis_quantized(result)
:-
quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
. - Если
dim(operand, quantization_dimension(operand)) = 1
, тоscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
-
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
.
Примеры
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
динамический_конв
Семантика
Эта операция функционально идентична операции свертки , но заполнение задается динамически с помощью padding
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1), (С10-С11), (С14) (С25), (С26-С27), (С30-С31), (С33) |
(И2) | rhs | тензор или квантованный тензор | (С1), (С14-С16), (С26-С28), (С30-С33) |
(И3) | padding | 2-мерный тензор целочисленного типа | (С4) |
(И4) | window_strides | 1-мерная тензорная константа типа si64 | (С2-С3) |
(И5) | lhs_dilation | 1-мерная тензорная константа типа si64 | (С5-С6) |
(И6) | rhs_dilation | 1-мерная тензорная константа типа si64 | (С7-С8) |
(I7) | window_reversal | 1-мерная тензорная константа типа i1 | (С9) |
(И8) | input_batch_dimension | константа типа si64 | (С10), (С13) |
(I9) | input_feature_dimension | константа типа si64 | (С11), (С13-С14) |
(I10) | input_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С12), (С13) |
(I11) | kernel_input_feature_dimension | константа типа si64 | (С14), (С18) |
(I12) | kernel_output_feature_dimension | константа типа si64 | (С15-С16), (С18), (С28) |
(I13) | kernel_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С17-С18) |
(I14) | output_batch_dimension | константа типа si64 | (С20) |
(I15) | output_feature_dimension | константа типа si64 | (С20), (С29) |
(I16) | output_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С19-С20) |
(I17) | feature_group_count | константа типа si64 | (С11), (С14), (С16), (С21), (С23) |
(I18) | batch_group_count | константа типа si64 | (С10), (С15), (С22), (С23) |
(I19) | precision_config | вариативное количество перечислений DEFAULT , HIGH и HIGHEST | (С24) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С25-С27), (С29), (С31-С33) |
Ограничения
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Учитывая
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:-
is_unique(input_dimensions)
. -
0 <= input_dimensions < N
.
-
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Учитывая
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:-
is_unique(kernel_dimensions)
. -
0 <= kernel_dimensions < N
.
-
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Учитывая
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:-
is_unique(output_dimensions)
. -
0 <= output_dimensions < N
.
-
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
определяется как:-
dim(lhs, input_batch_dimension) / batch_group_count
еслиresult_dim = output_batch_dimension
. -
dim(rhs, kernel_output_feature_dimension)
еслиresult_dim = output_feature_dimension
. -
num_windows
в противном случае, где: -
output_spatial_dimensions[spatial_dim] = result_dim
. -
lhs_dim = input_spatial_dimensions[spatial_dim]
. -
rhs_dim = kernel_spatial_dimensions[spatial_dim]
. -
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
. -
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
. -
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
. -
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
. -
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
-
- (C26)
rank(result) = N
. - Если в операции используются неквантованные тензоры:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Если в операции используются квантованные тензоры:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Если
is_per_axis_quantized(rhs)
, тоquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Если
is_per_axis_quantized(result)
, тоquantization_dimension(result) = output_feature_dimension
. - Если
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Если
is_per_tensor_quantized(rhs)
, тоis_per_tensor_quantized(result)
. - Если
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Примеры
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
динамический_сбор
Семантика
Эта операция функционально идентична операции сбора , при этом slice_sizes
задаются динамически как значение.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1), (С7), (С10-С12), (С14) |
(И2) | start_indices | тензор целочисленного типа | (С2), (С3), (С13) |
(И3) | slice_sizes | 1-мерный тензор целочисленного типа | (С8), (С11-С13) |
(И4) | offset_dims | 1-мерная тензорная константа типа si64 | (С1), (С4-С5), (С13) |
(И5) | collapsed_slice_dims | 1-мерная тензорная константа типа si64 | (С1), (С6-С8), (С13) |
(И6) | start_index_map | 1-мерная тензорная константа типа si64 | (С3), (С9), (С10) |
(I7) | index_vector_dim | константа типа si64 | (С2), (С3), (С13) |
(И8) | indices_are_sorted | константа типа i1 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С5), (С13-С14) |
Ограничения
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
где:-
batch_dim_sizes = shape(start_indices)
за исключением того, что размер измеренияstart_indices
соответствующийindex_vector_dim
не включается. -
offset_dim_sizes = shape(slice_sizes)
за исключением того, что размеры размеров вslice_sizes
соответствующиеcollapsed_slice_dims
, не включены. -
combine
помещаетbatch_dim_sizes
в оси, соответствующиеbatch_dims
, иoffset_dim_sizes
в оси, соответствующиеoffset_dims
.
-
- (C14)
element_type(operand) = element_type(result)
.
Примеры
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
динамическая_йота
Семантика
Эта операция функционально идентична операции iota op, но форма результата задается динамически через output_shape
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | output_shape | 1-мерный тензор целочисленного типа | (С1), (С2) |
(И2) | iota_dimension | si64 | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С2) |
Ограничения
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Примеры
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
динамический_пад
Семантика
Эта операция функционально идентична операции Pad , но с edge_padding_low
, edge_padding_high
и interior_padding
, заданными динамически как значения.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1), (С2), (С4) |
(И2) | padding_value | 0-мерный тензор или потензорный квантованный тензор | (С1) |
(И3) | edge_padding_low | 1-мерный тензор целочисленного типа | (С1), (С4) |
(И4) | edge_padding_high | 1-мерный тензор целочисленного типа | (С1), (С4) |
(И5) | interior_padding | 1-мерный тензор целочисленного типа | (С2-С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С3-С6) |
Ограничения
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Примеры
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
Dynamic_reshape
Семантика
Эта операция функционально идентична операции изменения формы , но форма результата задается динамически через output_shape
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С3) |
(И2) | output_shape | 1-мерный тензор целочисленного типа | (С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1-С4) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
иquantization_dimension(result)
могут отличаться, в противном случае.
-
- (C2)
size(operand) = size(result)
. - (C3) Если
is_per_axis_quantized(operand)
:-
reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
. -
dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
. -
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
-
- (C4)
size(output_shape) = rank(result)
.
Примеры
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
динамический_срез
Семантика
Извлекает срез из operand
используя динамически вычисляемые начальные индексы, и создает result
тензор. start_indices
содержит начальные индексы среза для каждого измерения, подлежащего потенциальной корректировке, а slice_sizes
содержит размеры среза для каждого измерения. Более формально, result[result_index] = operand[operand_index]
, где:
-
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
. -
operand_index = adjusted_start_indices + result_index
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1), (С2), (С4) |
(И2) | start_indices | вариативное число 0-мерных тензоров целочисленного типа | (С2), (С3) |
(И3) | slice_sizes | 1-мерная тензорная константа типа si64 | (С2), (С4), (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1), (С5) |
Ограничения
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Примеры
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
Dynamic_update_slice
Семантика
Создает тензор result
, который равен тензору operand
, за исключением того, что срез, начинающийся с start_indices
обновляется значениями в update
. Более формально, result[result_index]
определяется как:
-
update[update_index]
if0 <= update_index < shape(update)
где:-
adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
. -
update_index = result_index - adjusted_start_indices
.
-
-
operand[result_index]
в противном случае.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1-С4), (С6) |
(И2) | update | тензорный или потензорный квантованный тензор | (С2), (С3), (С6) |
(И3) | start_indices | вариативное число 0-мерных тензоров целочисленного типа | (С4), (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Примеры
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
экспоненциальный
Семантика
Выполняет поэлементную экспоненциальную операцию над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
exp
IEEE-754. - Для комплексных чисел: комплексная экспонента.
- Для квантованных типов:
dequantize_op_quantize(exponential, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
экспоненциальный_минус_один
Семантика
Выполняет поэлементную экспоненциальную операцию минус одна операция над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
expm1
из IEEE-754. - Для комплексных чисел: комплексная экспонента минус один.
- Для квантованных типов:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
фф
Семантика
Выполняет прямое и обратное преобразование Фурье для реальных и сложных входных/выходных данных.
fft_type
— одно из следующих:
-
FFT
: прямое БПФ от комплекса к комплексу. -
IFFT
: Обратное комплексное БПФ. -
RFFT
: прямое БПФ от действительного к комплексному. -
IRFFT
: обратное БПФ от действительного к комплексному (т.е. принимает комплексное значение, возвращает вещественное).
Более формально, учитывая функцию fft
, которая принимает на вход одномерные тензоры комплексных типов, создает одномерные тензоры тех же типов, что и на выходе, и вычисляет дискретное преобразование Фурье:
Для fft_type = FFT
result
определяется как окончательный результат серии L вычислений, где L = size(fft_length)
. Например, для L = 3
:
-
result1[i0, ..., :] = fft(operand[i0, ..., :])
. -
result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
. -
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Кроме того, учитывая функцию ifft
, которая имеет ту же сигнатуру типа и вычисляет обратную функцию fft
:
Для fft_type = IFFT
result
определяется как обратный вычислениям для fft_type = FFT
. Например, для L = 3
:
-
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
. -
result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
. -
result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Кроме того, учитывая функцию rfft
, которая принимает одномерные тензоры типов с плавающей запятой, создает одномерные тензоры комплексных типов с той же семантикой с плавающей запятой и работает следующим образом:
-
rfft(real_operand) = truncated_result
где -
complex_operand... = (real_operand..., 0.0)
. -
complex_result = fft(complex_operand)
. -
truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Когда дискретное преобразование Фурье вычисляется для вещественных операндов, первые N/2 + 1
элементов результата однозначно определяют остальную часть результата, поэтому результат rfft
усекается, чтобы избежать вычисления избыточных элементов).
Для fft_type = RFFT
result
определяется как окончательный результат серии L вычислений, где L = size(fft_length)
. Например, для L = 3
:
-
result1[i0, ..., :] = rfft(operand[i0, ..., :])
. -
result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
. -
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Наконец, дана функция irfft
, которая имеет ту же сигнатуру типа и вычисляет обратную функцию rfft
:
Для fft_type = IRFFT
result
определяется как обратный вычислениям для fft_type = RFFT
. Например, для L = 3
:
-
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
. -
result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
. -
result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа | (С1), (С2), (С4), (С5) |
(И2) | fft_type | перечисление FFT , IFFT , RFFT и IRFFT | (С2), (С5) |
(И3) | fft_length | 1-мерная тензорная константа типа si64 | (С1), (С3), (С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа | (С2), (С4), (С5) |
Ограничения
- (C1)
size(fft_length) <= rank(operand)
. - (C2) Отношения между типами
operand
иresult
элементов различаются:- Если
fft_type = FFT
,element_type(operand)
иelement_type(result)
имеют один и тот же сложный тип. - Если
fft_type = IFFT
,element_type(operand)
иelement_type(result)
имеют один и тот же сложный тип. - Если
fft_type = RFFT
,element_type(operand)
— это тип с плавающей запятой, аelement_type(result)
— это сложный тип с той же семантикой с плавающей запятой. - Если
fft_type = IRFFT
,element_type(operand)
— это сложный тип, аelement_type(result)
— это тип с плавающей запятой той же семантики с плавающей запятой.
- Если
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Если среди
operand
иresult
есть тензорноеreal
с плавающей запятой, тоshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
за исключением:- Если
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Если
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Если
Примеры
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
пол
Семантика
Выполняет поэлементное определение тензора operand
и создает тензор result
. Реализует операцию roundToIntegralTowardNegative
из спецификации IEEE-754. Для квантованных типов выполняется dequantize_op_quantize(floor, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
собирать
Семантика
Собирает срезы тензора operand
по смещениям, указанным в start_indices
и создает result
тензор.
На следующей диаграмме показано, как элементы result
сопоставляются с элементами operand
на конкретном примере. На диаграмме выбраны несколько примеров индексов result
и подробно объяснено, каким индексам operand
они соответствуют.
Более формально, result[result_index] = operand[operand_index]
, где:
-
batch_dims = [d for d in axes(result) and d not in offset_dims]
. -
batch_index = result_index[batch_dims...]
. -
start_index
определяется как:-
start_indices[bi0, ..., :, ..., biN]
гдеbi
— отдельные элементы вbatch_index
, а:
вставляется в индексindex_vector_dim
, еслиindex_vector_dim
<rank(start_indices)
. -
[start_indices[batch_index]]
в противном случае.
-
- Для
d_operand
вaxes(operand)
,-
full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
еслиd_operand = start_index_map[d_start]
. -
full_start_index[d_operand] = 0
в противном случае.
-
- Для
d_operand
вaxes(operand)
,-
full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
, еслиd_operand = operand_batching_dims[i_batching]
иd_start = start_indices_batching_dims[i_batching]
. -
full_batching_index[d_operand] = 0
в противном случае.
-
-
offset_index = result_index[offset_dims...]
. -
full_offset_index = [oi0, ..., 0, ..., oiN]
, гдеoi
— отдельные элементы вoffset_index
, а0
вставляется в индексы изcollapsed_slice_dims
operand_batching_dims
. -
operand_index = full_start_index + full_batching_index + full_offset_index
.
Если indices_are_sorted
имеет true
, то реализация может предположить, что start_indices
сортируются относительно start_index_map
, в противном случае поведение не определено. Более формально, для всех i1 < i2
из indices(result)
, full_start_index(i1) <= full_start_index(i2)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1), (С8), (С11), (С17), (С19-С21), (С23) |
(И2) | start_indices | тензор целочисленного типа | (С2-С3), (С14), (С17), (С22) |
(И3) | offset_dims | 1-мерная тензорная константа типа si64 | (С1), (С4-С5), (С22) |
(И4) | collapsed_slice_dims | 1-мерная тензорная константа типа si64 | (С1), (С6-С9), (С22) |
(И5) | operand_batching_dims | 1-мерная тензорная константа типа si64 | (С1), (С6), (С10-С12), (С16-С18), (С22) |
(И6) | start_indices_batching_dims | 1-мерная тензорная константа типа si64 | (С13-С17) |
(I7) | start_index_map | 1-мерная тензорная константа типа si64 | (С3), (С18-С19) |
(И8) | index_vector_dim | константа типа si64 | (С2-С3), (С15), (С22) |
(I9) | slice_sizes | 1-мерная тензорная константа типа si64 | (С9), (С12), (С20-С22) |
(I10) | indices_are_sorted | константа типа i1 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С5), (С22-С23) |
Ограничения
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
где:-
batch_dim_sizes = shape(start_indices)
за исключением того, что размер измеренияstart_indices
соответствующийindex_vector_dim
не включается. -
offset_dim_sizes = slice_sizes
за исключением того, что размеры измерений вslice_sizes
соответствующиеcollapsed_slice_dims
иoperand_batching_dims
не включены. -
combine
помещаетbatch_dim_sizes
в оси, соответствующиеbatch_dims
, иoffset_dim_sizes
в оси, соответствующиеoffset_dims
.
-
- (C23)
element_type(operand) = element_type(result)
.
Примеры
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Семантика
Возвращает размер заданного dimension
operand
. Более формально, result = dim(operand, dimension)
. Семантика касается только компонента формы типа. Тип элемента может быть любым.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1) |
(И2) | dimension | константа типа si64 | (С1) |
Выходы
Имя | Тип |
---|---|
result | 0-мерный тензор типа si32 |
Ограничения
- (C1)
0 <= dimension < rank(operand)
.
Примеры
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Семантика
Извлекает элемент в index
позиции кортежа operand
и выдает result
. Более формально, result = operand[index]
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | кортеж | (С1), (С2) |
(И2) | index | константа типа si32 | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | любой поддерживаемый тип | (С2) |
Ограничения
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Примеры
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
если
Семантика
Создает результат выполнения ровно одной функции из true_branch
или false_branch
в зависимости от значения pred
. Более формально, result = pred ? true_branch() : false_branch()
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | pred | 0-мерный тензор типа i1 | |
(И2) | true_branch | функция | (С1-С3) |
(И3) | false_branch | функция | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С3) |
Ограничения
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Примеры
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
изображение
Семантика
Извлекает мнимую часть поэлементно из operand
и создает result
тензор. Более формально, для каждого элемента x
: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой | (С1), (С2) |
Ограничения
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
определяется как:-
complex_element_type(element_type(operand))
ifis_complex(operand)
. -
element_type(operand)
в противном случае.
-
Примеры
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
подача
Семантика
Считывает данные с источника питания и выдает results
.
Семантика infeed_config
определяется реализацией.
results
состоят из значений полезной нагрузки, которые идут первыми, и токена, который идет последним. В будущем мы планируем разделить полезную нагрузку и токен на два отдельных вывода для повышения ясности ( #670 ).
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | token | token |
(И2) | infeed_config | константа типа string |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С1-С3) |
Ограничения
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
илиis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Примеры
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
йота
Семантика
Заполняет output
тензор значениями в порядке возрастания, начиная с нуля, по измерению iota_dimension
. Более формально,
output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | iota_dimension | si64 | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
0 <= iota_dimension < rank(output)
.
Примеры
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Семантика
Выполняет поэлементную проверку, является ли значение x
конечным (т. е. не является ни +Inf, -Inf, ни NaN), и создает тензор y
. Реализует операцию isFinite
из спецификации IEEE-754. Для квантованных типов результат всегда true
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | x | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
y | тензор логического типа | (С1) |
Ограничения
- (C1)
shape(x) = shape(y)
.
Примеры
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
бревно
Семантика
Выполняет поэлементную операцию логарифма над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
log
IEEE-754. - Для комплексных чисел: комплексный логарифм.
- Для квантованных типов:
dequantize_op_quantize(log, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Семантика
Выполняет поэлементный логарифм плюс одну операцию над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
logp1
из IEEE-754. - Для комплексных чисел: комплексный логарифм плюс один.
- Для квантованных типов:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
логистический
Семантика
Выполняет поэлементную логистическую операцию над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
division(1, addition(1, exp(-x)))
из IEEE-754. - Для комплексных чисел: сложная логистика.
- Для квантованных типов:
dequantize_op_quantize(logistic, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
карта
Семантика
Применяет computation
функции карты к inputs
по dimensions
и создает result
тензор.
Более формально, result[result_index] = computation(inputs...[result_index])
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (С1-С4) |
(И2) | dimensions | 1-мерная тензорная константа типа si64 | (С3) |
(И3) | computation | функция | (С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1), (С4) |
Ограничения
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
- (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
имеет тип(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
гдеEi = element_type(inputs[i])
иE' = element_type(result)
.
Примеры
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
максимум
Семантика
Выполняет поэлементную операцию max над тензорами lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое ИЛИ.
- Для целых чисел: целое число максимум.
- Для чисел с плавающей запятой:
maximum
из IEEE-754. - Для комплексных чисел: лексикографический максимум для пары
(real, imaginary)
. Наложение порядка на комплексные числа требует удивительной семантики, поэтому в будущем мы планируем удалить поддержку комплексных чисел для этой операции ( #560 ). - Для квантованных типов:
-
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензорный или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
минимум
Семантика
Выполняет поэлементную операцию min над тензорами lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое И.
- Для целых чисел: целочисленный минимум.
- Для плавающих чисел:
minimum
IEEE-754. - Для комплексных чисел: лексикографический минимум для пары
(real, imaginary)
. Наложение порядка на комплексные числа требует удивительной семантики, поэтому в будущем мы планируем удалить поддержку комплексных чисел для этой операции ( #560 ). - Для квантованных типов:
-
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензорный или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
умножать
Семантика
Выполняет поэлементное произведение двух тензоров lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое И.
- Для целых чисел: целочисленное умножение.
- Для чисел с плавающей запятой:
multiplication
из IEEE-754. - Для комплексных чисел: комплексное умножение.
- Для квантованных типов:
-
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензорный или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
отрицать
Семантика
Выполняет поэлементное отрицание тензора operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для целых чисел со знаком: целочисленное отрицание.
- Для целых чисел без знака: преобразование битов в целое число со знаком, отрицание целого числа, обратное преобразование целых чисел без знака.
- Для чисел с плавающей запятой:
negate
IEEE-754. - Для комплексных чисел: комплексное отрицание.
- Для квантованных типов:
dequantize_op_quantize(negate, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
нет
Семантика
Выполняет поэлементное НЕ тензорного operand
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое НЕ.
- Для целых чисел: побитовое НЕ.
Аргументы
Имя | Тип | Ограничения |
---|---|---|
operand | тензор логического или целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор логического или целочисленного типа | (С1) |
Ограничения
- (C1)
type(operand) = type(result)
.
Примеры
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
оптимизация_барьер
Семантика
Гарантирует, что операции, создающие operand
выполняются перед любыми операциями, которые зависят от result
, и предотвращает перемещение операций через барьер преобразованиями компилятора. В остальном операция является тождественной, т.е. result = operand
.
Аргументы
Имя | Тип | Ограничения |
---|---|---|
operand | вариативное число тензоров, по-тензорные квантованные тензоры или токены | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | вариативное число тензоров, по-тензорные квантованные тензоры или токены | (С1) |
Ограничения
- (C1)
type(operand...) = type(result...)
.
Примеры
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
или
Семантика
Выполняет поэлементное ИЛИ двух тензоров lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое ИЛИ.
- Для целых чисел: побитовое ИЛИ.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного или логического типа | (С1) |
(И2) | rhs | тензор целочисленного или логического типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного или логического типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
перекормить
Семантика
Записывает inputs
на выход и создает токен result
.
Семантика outfeed_config
определяется реализацией.
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | inputs | вариативное число тензоров или квантованных тензоров |
(И2) | token | token |
(И3) | outfeed_config | константа типа string |
Выходы
Имя | Тип |
---|---|
result | token |
Примеры
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
подушечка
Семантика
Расширяет operand
путем заполнения вокруг тензора, а также между элементами тензора с заданным padding_value
.
edge_padding_low
и edge_padding_high
определяют количество дополнений, добавляемых в нижнем конце (рядом с индексом 0) и верхнем конце (рядом с самым высоким индексом) каждого измерения соответственно. Величина заполнения может быть отрицательной, при этом абсолютное значение отрицательного заполнения указывает количество элементов, которые необходимо удалить из указанного измерения.
interior_padding
определяет количество отступов, добавляемых между любыми двумя элементами в каждом измерении, которое не может быть отрицательным. Внутреннее заполнение происходит перед заполнением краев, так что заполнение отрицательных краев удалит элементы из операнда с внутренним заполнением.
Более формально, result[result_index]
определяется как:
-
operand[operand_index]
еслиresult_index = edge_padding_low + operand_index * (interior_padding + 1)
. -
padding_value
в противном случае.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1), (С2), (С4) |
(И2) | padding_value | 0-мерный тензор или потензорный квантованный тензор | (С1) |
(И3) | edge_padding_low | 1-мерная тензорная константа типа si64 | (С1), (С4) |
(И4) | edge_padding_high | 1-мерная тензорная константа типа si64 | (С1), (С4) |
(И5) | interior_padding | 1-мерная тензорная константа типа si64 | (С2-С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С3-С6) |
Ограничения
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Примеры
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
идентификатор_раздела
Семантика
Создает partition_id
текущего процесса.
Выходы
Имя | Тип |
---|---|
result | 0-мерный тензор типа ui32 |
Примеры
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
попкнт
Семантика
Выполняет поэлементный подсчет количества бит, установленных в тензоре operand
, и создает result
тензор.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(operand) = type(result)
.
Примеры
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
власть
Семантика
Выполняет поэлементное возведение тензора lhs
с помощью тензора rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для целых чисел: возведение в степень целого числа.
- Для поплавков:
pow
из IEEE-754. - Для комплексных чисел: комплексное возведение в степень.
- Для квантованных типов:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
настоящий
Семантика
Извлекает действительную часть поэлементно из operand
и создает result
тензор. Более формально, для каждого элемента x
: real(x) = is_complex(x) ? real_part(x) : x
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой | (С1), (С2) |
Ограничения
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
определяется как:-
complex_element_type(element_type(operand))
ifis_complex(operand)
. -
element_type(operand)
в противном случае.
-
Примеры
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
получение
Семантика
Получает данные из канала с channel_id
и выдает results
.
Если is_host_transfer
имеет true
, операция передает данные с хоста. В противном случае он передает данные с другого устройства. Это означает, что это определяется реализацией. Этот флаг дублирует информацию, предоставленную в channel_type
, поэтому в будущем мы планируем сохранить только один из них ( #666 ).
results
состоят из значений полезной нагрузки, которые идут первыми, и токена, который идет последним. В будущем мы планируем разделить полезную нагрузку и токен на два отдельных вывода для повышения ясности ( #670 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | token | token | (С4) |
(И2) | channel_id | константа типа si64 | |
(И3) | channel_type | перечисление DEVICE_TO_DEVICE и HOST_TO_DEVICE | (С1) |
(И4) | is_host_transfer | константа типа i1 | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С2-С4) |
Ограничения
- (C1)
channel_type
определяется как:-
HOST_TO_DEVICE
еслиis_host_transfer = true
, -
DEVICE_TO_DEVICE
в противном случае.
-
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
илиis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Примеры
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
уменьшать
Семантика
Применяет body
функции сокращения к inputs
и init_values
вдоль dimensions
и создает тензоры results
.
Порядок сокращений определяется реализацией, а это означает, что body
и init_values
должны образовывать моноид, чтобы гарантировать, что операция дает одинаковые результаты для всех входных данных во всех реализациях. Однако это условие не выполняется для многих популярных сокращений. Например, сложение чисел с плавающей запятой для body
и нуля для init_values
на самом деле не образует моноид, поскольку сложение чисел с плавающей запятой не является ассоциативным.
Более формально, results...[j0, ..., jR-1] = reduce(input_slices_converted)
, где:
-
input_slices = inputs...[j0, ..., :, ..., jR-1]
, где:
вставляются вdimensions
. -
input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
. -
init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
. -
reduce(input_slices_converted) = exec(schedule)
для некоторогоschedule
двоичного дерева, где:-
exec(node) = body(exec(node.left), exec(node.right))
. -
exec(leaf) = leaf.value
.
-
-
schedule
— это определяемое реализацией полное двоичное дерево, упорядоченный обход которого состоит из:- Значения
input_slices_converted...[index]
для всехindex
вindex_space(input_slices_converted)
в возрастающем лексикографическом порядкеindex
. - Перемежается с определяемым реализацией количеством
init_values_converted
в позициях, определенных реализацией.
- Значения
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (С1-С4), (С6), (С7) |
(И2) | init_values | Вариадическое число 0-мерных тензоров или потензорных квантованных тензоров | (С2), (С3) |
(И3) | dimensions | 1-мерная тензорная константа типа si64 | (С4), (С5), (С7) |
(И4) | body | функция | (С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С3), (С7), (С8) |
Ограничения
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
имеет тип(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
гдеis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
за исключением того, что размерыinputs...
соответствующиеdimensions
не включены. - (C8)
element_type(results[i]) = Ei
для всехi
в[0,N)
.
Примеры
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
уменьшить_точность
Семантика
Выполняет поэлементное преобразование operand
в другой тип с плавающей запятой, который использует exponent_bits
и mantissa_bits
, и обратно в исходный тип с плавающей запятой и создает output
тензор.
Более формально:
- Биты мантиссы исходного значения обновляются для округления исходного значения до ближайшего значения, которое можно представить с помощью
mantissa_bits
с использованием семантикиroundToIntegralTiesToEven
. - Затем, если
mantissa_bits
меньше количества битов мантиссы исходного значения, биты мантиссы усекаются доmantissa_bits
. - Затем, если биты экспоненты промежуточного результата не помещаются в диапазон, предоставленный
exponent_bits
, промежуточный результат переполняется до бесконечности, используя исходный знак, или уменьшается до нуля, используя исходный знак. - Для квантованных типов выполняет
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
(И2) | exponent_bits | константа типа si32 | (С2) |
(И3) | mantissa_bits | константа типа si32 | (С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Примеры
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
уменьшить_разброс
Семантика
Внутри каждой группы процессов в сетке процессов StableHLO выполняет сокращение, используя computations
, над значениями тензора operand
каждого процесса, разбивает результат сокращения по scatter_dimension
на части и распределяет разделенные части между процессами для получения result
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
еслиchannel_id <= 0 and use_global_device_ids = false
. -
cross_replica_and_partition(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = false
. -
flattened_ids(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = true
.
Затем внутри каждой process_group
:
-
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
. -
parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
. -
result@receiver = parts@sender[receiver_index]
для всехsender
вprocess_group
, гдеreceiver_index = process_group.index(receiver)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1), (С2), (С7), (С8) |
(И2) | scatter_dimension | константа типа si64 | (С1), (С2), (С8) |
(И3) | replica_groups | 2-мерная тензорная константа типа si64 | (С3-С5) |
(И4) | channel_id | константа типа si64 | (С6) |
(И5) | use_global_device_ids | константа типа i1 | (С6) |
(И6) | computation | функция | (С7) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С8-С9) |
Ограничения
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_replicas
если используетсяcross_replica_and_partition
. -
num_processes
если используетсяflattened_ids
.
-
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Если
use_global_device_ids = true
, тоchannel_id > 0
. - (C7)
computation
имеет тип(tensor<E>, tensor<E>) -> (tensor<E>)
гдеis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
за исключением:-
dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
-
- (C9)
element_type(result) = E
.
Примеры
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
уменьшить_окно
Семантика
Применяет body
функции сокращения к окнам inputs
и init_values
и выдает results
.
На следующей диаграмме показано, как элементы results...
вычисляются на основе inputs...
на конкретном примере.
Более формально, results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(см. сокращение ), где:
-
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
. -
window_start = result_index * window_strides
. -
window_end = window_start + (window_dimensions - 1) * window_dilations + 1
. -
windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (С1-С4), (С6), (С8), (С10), (С12), (С13), (С15) |
(И2) | init_values | вариатическое число 0-мерных тензоров или потензорных квантованных тензоров | (С1), (С13) |
(И3) | window_dimensions | 1-мерная тензорная константа типа si64 | (С4), (С5), (С15) |
(И4) | window_strides | 1-мерная тензорная константа типа si64 | (С6), (С7), (С15) |
(И5) | base_dilations | 1-мерная тензорная константа типа si64 | (С8), (С9), (С15) |
(И6) | window_dilations | 1-мерная тензорная константа типа si64 | (С10), (С11), (С15) |
(I7) | padding | 2-мерная тензорная константа типа si64 | (С12), (С15) |
(И8) | body | функция | (С13) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С1), (С14-С16) |
Ограничения
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
имеет тип(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
гдеis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
где:-
dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
. -
padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
. -
dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
. -
is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
. -
num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
-
- (C16)
element_type(results[i]) = Ei
для всехi
в[0,N)
.
Примеры
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
остаток
Семантика
Вычисляет поэлементный остаток тензоров делимого lhs
и rhs
делителя и создает result
тензор.
Более формально, знак результата берется из делимого, а абсолютное значение результата всегда меньше абсолютного значения делителя. Остаток рассчитывается как lhs - d * rhs
, где d
определяется выражением:
- Для целых чисел:
stablehlo.divide(lhs, rhs)
. - Для чисел с плавающей запятой:
division(lhs, rhs)
из IEEE-754 с атрибутом округленияroundTowardZero
. - Для комплексных чисел: подлежит уточнению ( #997 ).
- Для квантованных типов:
-
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
-
Для типов элементов с плавающей точкой эта операция в отличие от remainder
операции из спецификации IEEE-754, где d
является интегральным значением, ближайшим к точным значениям lhs/rhs
со связями до ровности.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
Replica_id
Семантика
Создает replica_id
текущего процесса.
Выходы
Имя | Тип |
---|---|
result | 0-мерный тензор типа ui32 |
Примеры
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
изменить форму
Семантика
Выполняет изменение тензора operand
до result
. Концептуально, это равносильно тому же каноническому представлению, но потенциально изменяет форму, например, от tensor<2x3xf32>
до tensor<3x2xf32>
или tensor<6xf32>
.
Более формально, result[result_index] = operand[operand_index]
, где result_index
и operand_index
имеют одинаковую позицию в лексикографическом упорядочении index_space(result)
и index_space(operand)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1-С3) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
иquantization_dimension(result)
могут отличаться, в противном случае.
-
- (C2)
size(operand) = size(result)
. - (C3) Если
is_per_axis_quantized(operand)
:-
reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
. -
dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
. -
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
-
Примеры
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
обеспечить регресс
Семантика
Отправляет порядок элементов в operand
вдоль указанных dimensions
и дает тензор result
. Более формально, result[result_index] = operand[operand_index]
где:
-
operand_index[d] = dim(result, d) - result_index[d] - 1
еслиd
вdimensions
. -
operand_index[d] = result_index[d]
в противном случае.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1), (С3) |
(И2) | dimensions | 1-мерная тензорная константа типа si64 | (С2), (С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1), (С3) |
Ограничения
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Примеры
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
звонок
Семантика
Генерирует случайные числа, используя алгоритм rng_distribution
и дает result
тензор заданной shape
.
Если rng_distribution = UNIFORM
, то случайные числа генерируются после равномерного распределения по интервалу [a, b)
. Если a >= b
, поведение не определен.
Если rng_distribution = NORMAL
, то случайные числа генерируются после нормального распределения со средним = a
и стандартным отклонениями = b
. Если b < 0
, поведение не определен.
Точный способ, как генерируются случайные числа, определяется реализацией. Например, они могут быть или не быть детерминированными, и они могут или не могут использовать скрытое состояние.
В разговорах со многими заинтересованными сторонами этот OP появился как эффективно устаревший, поэтому в будущем мы планируем изучить его удаление ( #597 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | a | 0-мерный тензор целочисленного, логического или плавающего типа | (С1), (С2) |
(И2) | b | 0-мерный тензор целочисленного, логического или плавающего типа | (С1), (С2) |
(И3) | shape | 1-мерная тензорная константа типа si64 | (С3) |
(И4) | rng_distribution | перевозить UNIFORM и NORMAL | (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного, логического или плавающего типа | (С1-С3) |
Ограничения
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Если
rng_distribution = NORMAL
, тоis_float(a)
. - (C3)
shape(result) = shape
.
Примеры
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Семантика
Возвращает output
заполненный равномерными случайными битами и обновленным выходным состоянием output_state
используя алгоритм генератора псевдордомов rng_algorithm
с учетом начального состояния initial_state
. Выход гарантированно является детерминированной функцией initial_state
, но не гарантированно будет детерминированным между реализациями.
rng_algorithm
является одним из следующих:
-
DEFAULT
: определяемый реализацией алгоритм. -
THREE_FRY
: определяемый реализацией вариант алгоритма Threefry.* -
PHILOX
: определяемый реализацией вариант алгоритма Philox.*
* См.: Salmon et al. SC 2011. Параллельные случайные числа: просто 1, 2, 3.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | rng_algorithm | enum of DEFAULT , THREE_FRY и PHILOX | (С2) |
(И2) | initial_state | 1-мерный тензор типа ui64 | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output_state | 1-мерный тензор типа ui64 | (С1) |
output | Тензор целочисленного или типа с плавающей точкой |
Ограничения
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
определяется как:- Реализация DEFATION IF
rng_algorithm = DEFAULT
. -
2
Еслиrng_algorithm = THREE_FRY
. -
2
или3
, еслиrng_algorithm = PHILOX
.
- Реализация DEFATION IF
Примеры
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Семантика
Выполняет элемент, округление к ближайшему целому числу, разрывает связи с нулем, на тензоре operand
и дает тензор result
. Реализует операцию roundToIntegralTiesToAway
из спецификации IEEE-754. Для квантованных типов выполняет dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Семантика
Выполняет элементное округление в сторону ближайшего целого числа, разбивая связи с ровным целым числом, на тензоре operand
и дает тензор result
. Реализует операцию roundToIntegralTiesToEven
из спецификации IEEE-754. Для квантовых типов выполняет dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Семантика
Выполняет элементную операцию взаимного квадратного корня на тензоре operand
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавков:
rSqrt
от IEEE-754. - Для сложных чисел: сложный взаимный квадратный корень.
- Для квантовых типов:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
разбрасывать
Семантика
Получает тензоры results
, которые равны тенденциям inputs
, за исключением того, что несколько срезов, указанных с помощью scatter_indices
обновляются с updates
значений с использованием update_computation
.
На следующей диаграмме показано, как элементы в updates...
Карта по элементам в results...
используя конкретный пример. Диаграмма выбирает несколько примеров updates...
Индексы и подробно объясняет, какие results...
индексы, с которыми они соответствуют.
Более формально для всех update_index
в index_space(updates[0])
:
-
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
. -
update_scatter_index = update_index[update_scatter_dims...]
. -
start_index
определяется как:-
scatter_indices[si0, ..., :, ..., siN]
, гдеsi
- это отдельные элементы вupdate_scatter_index
и:
вставлены в индексеindex_vector_dim
, еслиindex_vector_dim
<rank(scatter_indices)
. -
[scatter_indices[update_scatter_index]]
в противном случае.
-
- Для
d_input
вaxes(inputs[0])
,-
full_start_index[d_input] = start_index[d_start]
еслиd_input = scatter_dims_to_operand_dims[d_start]
. -
full_start_index[d_input] = 0
в противном случае.
-
- Для
d_input
вaxes(inputs[0])
,-
full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
, еслиd_input = input_batching_dims[i_batching]
иd_start = scatter_indices_batching_dims[i_batching]
. -
full_batching_index[d_input] = 0
в противном случае.
-
-
update_window_index = update_index[update_window_dims...]
. -
full_window_index = [wi0, ..., 0, ..., wiN]
, гдеwi
являются отдельными элементами вupdate_window_index
, а0
вставлено по индексам изinserted_window_dims
иinput_batching_dims
. -
result_index = full_start_index + full_batching_index + full_window_index
.
Учитывая это, results = exec(schedule, inputs)
, где:
-
schedule
-это определяемая реализацией перестановкаindex_space(updates[0])
. -
exec([update_index, ...], results) = exec([...], updated_results)
где:- Если
result_index
находится в границах дляshape(results...)
-
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
-
updated_values = update_computation(results...[result_index], updates_converted)
-
updated_results
- это копияresults
сresults...[result_index]
установлен наupdated_values...
- В противном случае
-
updated_results = results
.
- Если
-
exec([], results) = results
.
Если indices_are_sorted
является true
, то реализация может предположить, что scatter_indices
сортируется по отношению к scatter_dims_to_operand_dims
, в противном случае поведение не определен. Более формально для всех i1 < i2
из indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Если unique_indices
true
, то реализация может предположить, что все индексы result_index
разбросаны, являются уникальными. Если unique_indices
true
, но рассеянные показатели не являются уникальными, то поведение не определен.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(И2) | scatter_indices | тензор целочисленного типа | (C4), (C15), (C19), (C22) |
(И3) | updates | вариативное число тензоров или потензорные квантованные тензоры | (C3-C6), (C8) |
(И4) | update_window_dims | 1-мерная тензорная константа типа si64 | (C2), (C4), (C7-C8) |
(И5) | inserted_window_dims | 1-мерная тензорная константа типа si64 | (C2), (C4), (C9-C11) |
(И6) | input_batching_dims | 1-мерная тензорная константа типа si64 | (C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims | 1-мерная тензорная константа типа si64 | (C14-C18) |
(И8) | scatter_dims_to_operand_dims | 1-мерная тензорная константа типа si64 | (C19-C21) |
(I9) | index_vector_dim | константа типа si64 | (C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted | константа типа i1 | |
(I11) | unique_indices | константа типа i1 | |
(I12) | update_computation | функция | (C23) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (C24-C25) |
Ограничения
- (C1)
same(shape(inputs...))
. - (C2) `rank (inputs [0]) = size (update_window_dims) + size (inserted_window_dims)
- размер (input_batching_dims) `.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
где:-
update_scatter_dim_sizes = shape(scatter_indices)
за исключением того, что размер размерностиscatter_indices
, соответствующийindex_vector_dim
, не включен. -
update_window_dim_sizes <= shape(inputs[0])
за исключением того, что размеры измерений вinputs[0]
соответствующиеinserted_window_dims
иinput_batching_dims
не включены. -
combine
putsupdate_scatter_dim_sizes
по осям, соответствующимupdate_scatter_dims
иupdate_window_dim_sizes
по осям, соответствующимupdate_window_dims
.
-
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
имеет тип(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, гдеis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
для всехi
в[0,N)
.
Примеры
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
выбирать
Семантика
Создает тензор result
, в котором каждый элемент выбран из тензора on_true
или on_false
на основе значения соответствующего элемента pred
. Более формально, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index]
, где pred_element = rank(pred) = 0 ? pred[] : pred[result_index]
. Для квантовых типов выполняет dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | pred | Тензор типа i1 | (С1) |
(И2) | on_true | тензорный или потензорный квантованный тензор | (С1-С2) |
(И3) | on_false | тензорный или потензорный квантованный тензор | (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С2) |
Ограничения
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Примеры
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Семантика
Разбросает значения из тензора source
, используя scatter
на основе результата reduce_window
input
тензора, используя select
и дает тензор result
.
На следующей диаграмме показано, как элементы в result
вычисляются из operand
и source
используя конкретный пример.
Более формально:
selected_values = reduce_window_without_init(...)
со следующими входами:-
inputs = [operand].
-
window_dimensions
,window_strides
иpadding
, которые используются как есть. -
base_dilations = windows_dilations = 1
. -
body
определяется как:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
где
E = element_type(operand)
иreduce_window_without_init
работает точно так же, какreduce_window
, за исключением того, чтоschedule
базовогоreduce
(см. DEMUT ) не включает значения init. В настоящее время не указано, что происходит, если в соответствующем окне нет значений ( #731 ).-
result[result_index] = reduce([source_values], [init_value], [0], scatter)
Где:-
source_values = [source[source_index] for source_index in source_indices]
. -
selected_index(source_index) = operand_index
ifselected_values[source_index]
имеет элементoperand
отoperand_index
. -
source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1-C4), (C6), (C8-C11) |
(И2) | source | тензорный или потензорный квантованный тензор | (С1), (С2) |
(И3) | init_value | 0-мерный тензор или квансор на квансор | (С3) |
(И4) | window_dimensions | 1-мерная тензорная константа типа si64 | (С2), (С4), (С5) |
(И5) | window_strides | 1-мерная тензорная константа типа si64 | (C2), (C6), (C7) |
(И6) | padding | 2-мерная тензорная константа типа si64 | (C2), (C8) |
(I7) | select | функция | (С9) |
(И8) | scatter | функция | (C10) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (C11-C12) |
Ограничения
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
, где:-
padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
. -
is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
. -
num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
-
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
ISESS TYPE(tensor<E>, tensor<E>) -> tensor<i1>
гдеE = element_type(operand)
. - (C10)
scatter
имеет тип(tensor<E>, tensor<E>) -> tensor<E>
гдеis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Примеры
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
отправлять
Семантика
Отправляет inputs
в канал channel_id
и создает токен result
.
Если is_host_transfer
является true
, то операция передает данные на хост. В противном случае он передает данные на другое устройство. Что это значит, определяется реализацией. Этот флаг дублирует информацию, представленную в channel_type
, поэтому в будущем мы планируем сохранить только один из них ( #666 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | Вариальное количество тензоров или квантовых тензоров | |
(И2) | token | token | |
(И3) | channel_id | константа типа si64 | |
(И4) | channel_type | enum DEVICE_TO_DEVICE и DEVICE_TO_HOST | (С1) |
(И5) | is_host_transfer | константа типа i1 | (С1) |
Выходы
Имя | Тип |
---|---|
result | token |
Ограничения
- (C1)
channel_type
определяется как:-
DEVICE_TO_HOST
ifis_host_transfer = true
, -
DEVICE_TO_DEVICE
иначе.
-
Примеры
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Семантика
Выполняет элементную операцию левой сдвиги на тензоре lhs
по количеству битов rhs
и дает тензор result
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа | (С1) |
(И2) | rhs | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Семантика
Выполняет элементную арифметическую операцию правого сдвига на тензоре lhs
по количеству битов rhs
и дает тензор result
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа | (С1) |
(И2) | rhs | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Семантика
Выполняет элементную логическую операцию правого смены на тензоре lhs
по количеству битов rhs
и дает тензор result
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа | (С1) |
(И2) | rhs | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
знак
Семантика
Возвращает признак элемента operand
и дает тензор result
. Более формально, для каждого элемента x
семантика может быть выражена с использованием синтаксиса Python следующим образом:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Для квантовых типов выполняет dequantize_op_quantize(sign, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целого числа со знаком, с плавающей запятой или комплексного типа или по-тензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целого числа со знаком, с плавающей запятой или комплексного типа или по-тензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
синус
Семантика
Выполняет элементную работу синуса на тензоре operand
и получает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавков:
sin
от IEEE-754. - Для сложных чисел: сложный синус.
- Для квантовых типов:
dequantize_op_quantize(sine, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
кусочек
Семантика
Извлекает срез из operand
используя статически выпускаемые начальные индексы и дает тензор result
. start_indices
содержат начальные индексы среза для каждого измерения, limit_indices
содержат конечные индексы (исключительно) для среза для каждого измерения, а strides
содержат шаги для каждого измерения.
Более формально, result[result_index] = operand[operand_index]
, где operand_index = start_indices + result_index * strides
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1-С3), (С5) |
(И2) | start_indices | 1-мерная тензорная константа типа si64 | (C2), (C3), (C5) |
(И3) | limit_indices | 1-мерная тензорная константа типа si64 | (C2), (C3), (C5) |
(И4) | strides | 1-мерная тензорная константа типа si64 | (С2), (С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1), (С5) |
Ограничения
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Примеры
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
сортировать
Семантика
Сортируют 1-мерные срезы inputs
вдоль измерения dimension
вместе, в соответствии с comparator
и дает results
.
В отличие от аналогичных входов в других операциях, dimension
допускает отрицательные значения, а семантика, описанная ниже. В будущем это может быть запрещено по причинам последовательности ( #1377 ).
Если is_stable
верно, то сортировка стабильна, то есть относительный порядок элементов, которые считаются равными компаратором, сохраняется. Для случая, когда существует один вход, два элемента e1
и e2
считаются равными компаратором, если и только в том случае, если comparator(e1, e2) = comparator(e2, e1) = false
. Смотрите формализацию ниже для того, как это обобщается до нескольких входов.
Более формально для всех result_index
в index_space(results[0])
:
-
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
. -
result_slice = [ri0, ..., :, ..., riR-1]
, гдеriN
являются отдельными элементами вresult_index
, и:
вставлен приadjusted_dimension
. -
inputs_together = (inputs[0]..., ..., inputs[N-1]...)
. -
results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
. - где
sort
сортировки 1-мерного среза в неотложном порядке, ожидая, чтоcomparator_together
возвращаетtrue
если аргумент левой стороны меньше, чем правый второй аргумент. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (C1-C5) |
(И2) | dimension | константа типа si64 | (С4) |
(И3) | is_stable | константа типа i1 | |
(И4) | comparator | функция | (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С2), (С3) |
Ограничения
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, гдеR = rank(inputs[0])
. - (C5)
comparator
имеет тип(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, гдеEi = element_type(inputs[i])
.
Примеры
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
кврт
Семантика
Выполняет элементную квадратную работу квадратного корня на тензоре operand
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавков:
squareRoot
от IEEE-754. - Для сложных чисел: сложный квадратный корень.
- Для квантовых типов:
dequantize_op_quantize(sqrt, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
вычесть
Семантика
Выполняет элементную вычитание двух тензоров lhs
и rhs
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для целых чисел: целочисленное вычитание.
- Для поплавков:
subtraction
из IEEE-754. - Для сложных чисел: сложное вычитание.
- Для квантованных типов:
-
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
загар
Семантика
Выполняет элементную касательную операцию на тензоре operand
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавков:
tan
от IEEE-754. - Для сложных чисел: сложный касательный.
- Для квантовых типов:
dequantize_op_quantize(tan, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
Тан
Семантика
Выполняет элементную гиперболическую касательную операцию на тензоре operand
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавков:
tanh
от IEEE-754. - Для сложных чисел: сложная гиперболическая касательная.
- Для квантованных типов:
-
dequantize_op_quantize(tanh, operand, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
транспонировать
Семантика
Пересталкивает размеры тензора operand
с использованием permutation
и дает тензор result
. Более формально, result[result_index] = operand[operand_index]
, где result_index[d] = operand_index[permutation[d]]
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С4) |
(И2) | permutation | 1-мерная тензорная константа типа si64 | (С2-С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (C1), (C3-C4) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
иquantization_dimension(result)
могут отличаться, в противном случае.
-
- (C2)
permutation
является перестановкойrange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Если
is_per_axis_quantized(result)
, тоquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Примеры
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
Triangular_Solve
Семантика
Решает партии систем линейных уравнений с матрицами нижних или верхних треугольных коэффициентов.
Более формально, данный a
и b
, result[i0, ..., iR-3, :, :]
-это решение для op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
, когда left_side
true
или x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
Когда left_side
является false
, решение для переменной x
, где op(a)
определяется transpose_a
, которая может быть одним из следующих:
-
NO_TRANSPOSE
: выполнить операцию, используяa
-IS. -
TRANSPOSE
: Выполните операцию на транспонированииa
-
ADJOINT
: выполните операцию на конъюгате транспонированияa
Входные данные считываются только из нижнего треугольника a
, если lower
является true
или верхним треугольником a
, в противном случае. Выходные данные возвращаются в том же треугольнике; Значения в другом треугольнике определяются внедрением.
Если unit_diagonal
является истинной, то реализация может предположить, что диагональные элементы a
равны 1, в противном случае поведение не определен.
Для квантованных типов выполняет dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | a | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1-С3) |
(И2) | b | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1-С4) |
(И3) | left_side | константа типа i1 | (С3) |
(И4) | lower | константа типа i1 | |
(И5) | unit_diagonal | константа типа i1 | |
(И6) | transpose_a | enum NO_TRANSPOSE , TRANSPOSE и ADJOINT |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) Связь между
shape(a)
иshape(b)
определяется следующим образом:-
shape(a)[:-3] = shape(b)[:-3]
. -
dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
-
- (C4)
baseline_type(b) = baseline_type(result)
.
Примеры
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
кортеж
Семантика
Создает result
кортеж из значений val
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | val | вариативное число значений | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | кортеж | (С1) |
Ограничения
- (C1)
result
имеет типtuple<E0, ..., EN-1>
гдеEi = type(val[i])
.
Примеры
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
Uniform_Dequantize
Семантика
Выполняет элементное преобразование квантового тензора operand
в result
тензора с плавающей точкой в соответствии с параметрами квантования, определенными типом operand
.
Более формально, result = dequantize(operand)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | квантованный тензор | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | Тензор типа с плавающей точкой | (С1), (С2) |
Ограничения
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Примеры
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
Uniform_quantize
Семантика
Выполняет элементное преобразование тензора с плавающей точкой или квантового operand
в квантовый тензор- result
в соответствии с параметрами квантования, определенными типом result
.
Более формально,
- Если
is_float(operand)
:-
result = quantize(operand, type(result))
.
-
- Если
is_quantized(operand)
:-
float_result = dequantize(operand)
. -
result = quantize(float_result, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | Тензор плавающей точки или квантового типа | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | квантованный тензор | (С1), (С2) |
Ограничения
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Примеры
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
пока
Семантика
Создает выход из выполнения функции body
0 или более раз, в то время как функция cond
выводит true
. Более формально, семантика может быть выражена с использованием синтаксиса Python следующим образом:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Поведение бесконечной петли - TBD ( #383 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | вариативное число тензоров, квантованных тензоров или токенов | (С1-С3) |
(И2) | cond | функция | (С1) |
(И3) | body | функция | (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С3) |
Ограничения
- (C1)
cond
имеет тип(T0, ..., TN-1) -> tensor<i1>
, гдеTi = type(operand[i])
. - (C2)
body
имеет тип(T0, ..., TN-1) -> (T0, ..., TN-1)
, гдеTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Примеры
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
XOR
Семантика
Выполняет элементный XOR из двух тензоров lhs
и rhs
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для логических: логический Xor.
- Для целых чисел: бить XOR.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор логического или целочисленного типа | (С1) |
(И2) | rhs | тензор логического или целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор логического или целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Диалект Interop
На данный момент программы StableHlo в дикой природе иногда содержат операции, которые не определяются StableHlo.
Модуль, функционирование, вызов и возврат
StableHlo использует операции MLIR вверх по течению для модуле, фанкопа, кольца и returnop. Это было сделано для лучшего взаимосвязи с существующим механизмом MLIR, так как написано много полезных проходов, нацеленных на целевую функцию и модуле, и многие компиляционные трубопроводы ожидают, что эти OPS будут присутствовать. Гарантии полной совместимости применяются к этим OPS. Если что -либо изменится в этих операциях несовместимым образом (т.е. удаление), эквиваленты стабильногохлового периода будут добавлены для сохранения совместимости.
Хло
CHLO OPSET содержит операции более высокого уровня, которые разлагаются на stablehlo. В настоящее время для CHLO нет гарантий совместимости. Для гарантий совместимости пропуск Chlo-легализ-к стабильному расту должен использоваться до сериализации.
Форма Операции
В сообществе является общий вариант использования, чтобы использовать определенные операции из Core Mlir Dialects в программах Dynamic StableHlo для выполнения вычислений формы. Чаще всего они включают в себя диалекты shape
, такие как shape_of
или num_elements
, tensor
диалекты , такие как dim
или from_elements
, и тип index
встроенного.
Dynamism RFC> O2 обозначает их как из -за масштаба, однако некоторая поддержка типов index
включена для целей взаимодействия. Для этих OPS или типов нет гарантий совместимости. Пропуск Shape-Legalize-to StableHlo может быть использован для преобразования этих операций в полностью поддерживаемые StableHlo Ops.
Устаревшие операции
Есть несколько операций StableHlo, которые были унаследованы от MHLO , которые устарели и на выходе из стабильной. Полная информация об этих удалениях можно найти в Clainhlo v1.0 #2283 . Проблема трекера для этих детекций - #2340 .
Эти операции делятся на несколько категорий:
- «Не в HLO» категории операций StableHlo-они изначально были частью Opset StableHlo, но позже были признаны не подходящими для него хорошо:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
( #3 ) . - Неиспользуемые OPS - эти операции могли быть полезны в какой -то момент, но OPS были либо недостаточно развиты, либо трубопроводы, использующие эти OPS, были рефакторированы, чтобы больше не требовать их. Это включает в себя
map
,tuple
( #598 ),get_tuple_element
,rng
,complex
сравнения № 560 и Convolutionwindow_reversal
( #1181 ).
Некоторые из этих OPS можно легко удалить, учитывая, что они могут быть выражены с использованием существующих OPS ( broadcast
, create_token
, cross-replica-sum
, dot
, unary_einsum
) и будут удалены после существующих проходов окна со совместимыми (6 месяцев). Другие все еще изучаются для удаления ( einsum
, get_tuple_element
, map
, rng
torch_index_select
, tuple
, complex
сравнения, window_reversal
). В ожидании обратной связи сообщества эти OPS будут либо удалены, либо добавлены в спецификацию с полной поддержкой. Пока эти фьючерсы на OPS не будут известны, они гарантируются только 6 месяцев совместимости.
Исполнение
Последовательное выполнение
Программа StableHLO выполняется путем предоставления входных значений для main
функции и вычисления выходных значений. Выходные значения функции вычисляются путем выполнения графика OPS, конутированного в соответствующем return
OP.
Порядок выполнения определяется внедрением, если он выровнен с DataFlow, то есть, если OPS выполняется до их использования. В stablehlo все побочные операции потребляют один токен и создают один токен (множественные токены могут быть мультиплексированы в один токен через after_all
), поэтому порядок выполнения побочных эффектов также выровнен с DataFlow. Например, в приведенной ниже программе существуют два возможных ордера на выполнение: %0
→ %1
→ %2
→ return
и %1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Более формально, процесс стабильногохлоя - это комбинация: 1) программа StableHlo, 2) статусы операции (еще не выполненные, уже выполненные) и 3) промежуточные значения, над которыми работает процесс. Процесс начинается с входных значений к main
функции, прогрессирует через график статусов операции обновления OPS и промежуточных значений и заканчивается выходными значениями. Дальнейшая формализация - TBD ( #484 ).
Параллельное исполнение
Программы StableHLO могут быть выполнены параллельно, организованные в 2D -сетку процесса num_replicas
с помощью num_partitions
, которые имеют тип ui32
.
В сетке процесса CTABLEHLO num_replicas * num_partitions
Процессов stableHLO выполняется одновременно. Каждый процесс имеет уникальный process_id = (replica_id, partition_id)
, где replica_id
в replica_ids = range(num_replicas)
и partition_id
в partition_ids = range(num_partitions)
, которые оба имеют тип ui32
.
Размер процесса сетки известен для каждой программы (в будущем мы планируем сделать ее явной частью программ № 650 #650 ), и позиция в сетке процесса известна для каждого процесса. Каждый процесс имеет доступ к своей позиции в сетке процесса через replica_id
и partition_id
Ops.
В рамках сетки процесса программы могут быть одинаковыми (в стиле «одиночная программа, многочисленные данные»), все могут быть разными (в стиле «Многочисленные программы, множественные данные») или что -то между ними. В будущем мы планируем ввести поддержку другим идиомам определения параллельных программ StableHlo, включая GSPMD ( #619 ).
В рамках сетки процесса процессы в основном не зависят друг от друга - они имеют отдельные статусы операции, отдельные значения ввода/промежуточных/выходных данных, и большинство OPS выполняются отдельно между процессами, за исключением небольшого числа коллективных операций, описанных ниже .
Учитывая, что выполнение большинства OPS использует только значения из одного и того же процесса, обычно однозначно ссылаться на эти значения по их именам. Однако при описании семантики коллективных операций это недостаточно, и это приводит к тому, что нотация name@process_id
для обозначения name
значения в определенном процессе. (С этой точки зрения, неквалифицированное name
можно рассматривать как сокращение для name@(replica_id(), partition_id())
).
Порядок выполнения в разных процессах определяется внедрением, за исключением синхронизации, введенной в соответствии с точечной связи и коллективными операциями, как описано ниже.
Общение с точки зрения
Процессы стабильногохла могут общаться друг с другом по каналам стабильных контейнеров . Канал представлен положительным идентификатором типа si64
. Через различные операции можно отправить значения на каналы и получать их по каналам.
Дальнейшая формализация, например, откуда поступают эти идентификаторы каналов, как программы процессов осведомлены о них и какую синхронизацию их вводятся TBD ( #484 ).
Потоковая связь
Каждый процесс stablehlo имеет доступ к двум потоковым интерфейсам:
- Дополнение , которое можно прочитать от.
- Outfeed , который может быть написан.
В отличие от каналов, которые используются для общения между процессами и, следовательно, имеют процессы на обоих концах, у инфинов и находов есть другая конечная реализация.
Дальнейшая формализация, например, как потоковая коммуникация влияет на порядок выполнения и какую синхронизацию вводится им, это TBD ( #484 ).
Коллективные операции
В StableHlo есть шесть коллективных операций: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
и reduce_scatter
. Все эти OPS разделили процессы в сетке процесса StableHlo на группы процессов StableHlo и выполняют совместные вычисления в каждой группе процессов, независимо от других групп процессов.
В каждой группе процессов Collective Ops может представлять барьер синхронизации. Дальнейшая формализация, например, подробно описывает, когда именно происходит эта синхронизация, как именно процессы попадают в этот барьер, и что произойдет, если они этого не делают, это TBD ( #484 ).
Если группа процессов включает в себя перекрестную связь, то есть в группе процессов есть процессы, идентификаторы раздела которых различны, то выполнение коллективного OP нуждается в канале, а коллективный OP должен обеспечить положительный channel_id
Type si64
. Cross-Replica Communication не нуждается в каналах.
Вычисления, выполняемые коллективными операциями, специфичны для отдельных OPS и описаны в отдельных разделах OP выше. Тем не менее, стратегии, с помощью которых сетка процесса разделяется на группы процессов, разделяются между этими операциями и описаны в этом разделе. Более формально, StableHlo поддерживает следующие четыре стратегии.
Cross_replica
Только перекрестные коммуникации происходят в каждой группе процессов. Эта стратегия принимает replica_groups
- список списков идентификаторов реплик - и вычисляет картезианский продукт replica_groups
по partition_ids
. replica_groups
должны иметь уникальные элементы и охватывать все replica_ids
. Более формально, используя синтаксис Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Например, для replica_groups = [[0, 1], [2, 3]]
и num_partitions = 2
, cross_replica
будет создавать [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
Cross_partition
Только перекрестные коммуникации происходят в каждой группе процессов. Эта стратегия принимает partition_groups
- список списков идентификаторов раздела - и вычисляет картезианский продукт partition_groups
по replica_ids
. partition_groups
должен иметь уникальные элементы и охватывать все partition_ids
. Более формально, используя синтаксис Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Например, для partition_groups = [[0, 1]]
и num_replicas = 4
, cross_partition
будет создавать [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
В каждой группе процессов могут произойти как перекрестная реплика, так и перекрестная коммуникация. Эта стратегия принимает replica_groups
- список списков идентификаторов реплик - и вычисляет картезианские продукты каждой replica_group
по partition_ids
. replica_groups
должны иметь уникальные элементы и охватывать все replica_ids
. Более формально, используя синтаксис Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Например, для replica_groups = [[0, 1], [2, 3]]
и num_partitions = 2
, cross_replica_and_partition
будет создавать [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
сглаженные_ид
Эта стратегия принимает flattened_id_groups
- список списков «сглаженных» идентификаторов процесса в форме replica_id * num_partitions + partition_id
- и превращает их в идентификаторы процесса. flattened_id_groups
должны иметь уникальные элементы и охватывать все process_ids
. Более формально, используя синтаксис Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Например, для flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
и num_partitions = 2
, flattened_ids
будут создавать [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Точность
На данный момент StableHlo не предоставляет гарантий о численной точности, но это может измениться в будущем ( #1156 ).
Семантика выполнения квантовой операции
Интерпретация квантованных операций StableHlo может варьироваться в зависимости от требований и возможностей оборудования. Например, некоторые аппаратные обеспечения могут предположить интерпретировать квантованные операции с использованием стратегии «отбрасывать, выполнять операцию с плавающей точкой и, наконец, квантовать». Другие могут выполнить все вычисления с целой арифметикой. Consequently, the interpretation of quantized StableHLO operations is exclusively determined by the specific implementation. The interpretation of hybrid quantization ( #1575 ) should be based on the it's semantics as prescribed in the specification (via 1792 ).
Ошибки
StableHLO programs are validated through an extensive set of constraints for individual ops, which rules out many classes of errors prior to run time. However, error conditions are still possible, eg through integer overflows, out-of-bounds accesses, etc. Unless explicitly called out, all these errors result in implementation-defined behavior, but this may change in the future ( #1157 ).
Floating-point exceptions
As an exception to this rule, floating-point exceptions in StableHLO programs have well-defined behavior. Operations which result in exceptions defined by the IEEE-754 standard (invalid operation, division-by-zero, overflow, underflow, or inexact exceptions) produce default results (as defined in the standard) and continue execution without raising the corresponding status flag; similar to raiseNoFlag
exception handling from the standard. Exceptions for nonstandard operations (eg complex arithmetic and certain transcendental functions) are implementation-defined.
Shape mismatches
StableHLO supports dynamically-shaped tensors. However, shapes have to agree at runtime, otherwise the behavior is undefined. StableHLO does not explicitly provide an op that can assert that a tensor has a given shape at runtime. Generating correct code is the responsibility of the producer.
As a specific example, the below program is valid. However, at runtime, the exact shapes of %arg0
and %arg1
will have to be the same, otherwise the behavior of the program is undefined:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
For describing syntax, this document is using the modified ISO flavor of EBNF syntax ( ISO/IEC 14977:1996 , Wikipedia ), with two modifications: 1) rules are defined using ::=
rather than =
,
2) concatenation is expressed using juxtaposition rather than ,
.
For describing semantics (ie within "Types", "Constants" and "Ops" sections), we are using formulas which are based on Python syntax extended with support for concisely expressing array operations as described below. This works well for small snippets of code, but in rare cases when larger snippets of code are needed, we use vanilla Python syntax which is always introduced explicitly.
Формулы
Let's explore how formulas work based on an example from the dot_general
specification. One of the constraints for this operation looks as follows: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
The names used in this formula come from two sources: 1) global functions, ie dim
, 2) member definitions of the corresponding program element, ie lhs
, lhs_batching_dimensions
, rhs
and rhs_batching_dimensions
inputs defined in the "Inputs" section of dot_general
.
As mentioned above, the syntax of this formula is Python-based with some conciseness-oriented extensions. To make sense of the formula, let's transform it into vanilla Python syntax.
A) In these formulas, we are using =
to represent equality, so the first step towards obtaining Python syntax is replacing =
with ==
, as follows: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Also, these formulas support ellipses ( ...
) which turn scalar expressions into tensor expressions. In a nutshell, f(xs...)
roughly means "for each scalar x
in the tensor xs
, compute a scalar f(x)
and then return all these scalar results together as a tensor result". In vanilla Python syntax, our example formula turns into: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Thanks to ellipses, it is often possible to avoid working at the level of individual scalars. However, in some tricky cases, lower-level semi-informal syntax may be used like in the start_indices[bi0, ..., :, ..., biN]
formula from the gather
specification. In the service of conciseness, we don't provide an exact formalism for translating such syntax to vanilla Python, in hopes that it is still intuitively understandable on case-by-case basis. Please let us know if some specific formulas look opaque, and we'll try to improve them.
Also, you will notice that formulas use ellipses to expand all sorts of lists, including tensors, lists of tensors (which eg can arise from a variadic number of tensors), etc. This is another area where we don't provide an exact formalism (eg lists are not even part of the StableHLO type system) and instead rely on intuitive understandability.
C) The final noteworthy notational vehicle that we employ is implicit broadcasting. While the StableHLO opset doesn't support implicit broadcasting, the formulas do, also in the service of conciseness. In a nutshell, if a scalar is used in a context where a tensor is expected, the scalar is broadcasted to the expected shape.
To continue the dot_general
example, here's another constraint: 0 <= lhs_batching_dimensions < rank(lhs)
. As defined in the dot_general
specification, lhs_batching_dimensions
is a tensor, however both 0
and rank(lhs)
are scalars. After we apply implicit broadcasting, the formula will become [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
When applied to a particular dot_general
operation, this formula will evaluate to a tensor of booleans. When formulas are used as constraints, the constraint holds if the formula evaluates to either true
or to a tensor which only has true
elements.
Names
In formulas, lexical scope includes: 1) global functions, 2) member definitions,
3) local definitions. The list of global functions is provided below. The list of element definitions depends on the program element that the notation is applied to:
- For operations, member definitions include names introduced in "Inputs" and "Outputs" sections.
- For everything else, member definitions include structural parts of the program element, named after the corresponding EBNF non-terminals. Most of the time, the names of these structural parts are obtained by converting the names of the non-terminals to snake case (eg
IntegerLiteral
=>integer_literal
), but sometimes names get abbreviated in the process (egQuantizationStorageType
=>storage_type
) in which case the names are introduced explicitly similarly to "Inputs" / "Outputs" sections in operation specifications. - Additionally, member definitions always include
self
to refer to the corresponding program element.
Ценности
When formulas are evaluated, they work with the following types of values: 1) Value
(actual values, eg dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
; they always know their types), 2) Placeholder
(future values, eg lhs
, rhs
or result
; their actual values are not known yet, only their types are known), 3) Type
(types as defined in the "Types" section), 4) Function
(global functions as defined in the "Functions" section).
Depending on the context, names may be referring to different values. More specifically, the "Semantics" section for ops (and equivalents for other program elements) defines runtime logic, so all inputs are available as Value
. In contrast, the "Constraints" section for ops (and equivalents) defines "compile-time" logic, ie something that is typically executed before runtime, so only constant inputs are available as Value
and other inputs are available only as Placeholder
.
Names | In "Semantics" | In "Constraints" |
---|---|---|
Global functions | Function | Function |
Constant inputs | Value | Value |
Non-constant inputs | Value | Placeholder |
Выходы | Value | Placeholder |
Local definitions | Depends on the definition | Depends on the definition |
Let's consider an example transpose
operation:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
For this operation, permutation
is a constant, so it's available as a Value
in both semantics and constraints. In contrast, operand
and result
are available as a Value
in semantics but only as a Placeholder
in constraints.
Функции
Construction of types
There are no functions that can be used to construct types. Instead, we directly use type syntax because it's typically more concise. Eg (tensor<E>, tensor<E>) -> (tensor<E>)
rather than function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Functions on types
-
element_type
is defined on tensor types and quantized tensor types and returns, respectively, theTensorElementType
orQuantizedTensorElementType
part of the correspondingTensorType
orQuantizedTensorType
.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
is a shortcut foris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
is a shortcut foris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
checks if typex
can be promoted to typey
. Whenx
andy
areQuantizedTensorElementType
s, the promotion is applied only to thestorage_type
. This specific version of promotion is currently used in context of reduction computation (refer to RFC for more details).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
is a shortcut foris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Available for all types. For example,is_float(x)
returnstrue
ifx
is aFloatType
. Ifx
is a value or placeholder, this function is a shortcut foris_type_name(type(x))
.max_value(x: Type) -> Value
returns the maximum value of anTensorElementType
. Ifx
is not anTensorElementType
, returnsNone
.min_value(x: Type) -> Value
returns the minimum possible value of anTensorElementType
. Ifx
is not anTensorElementType
, returnsNone
.member_name(x: Value | Placeholder | Type) -> Any
. Available for all member definitionsmember_name
of all types. For example,tensor_element_type(x)
returns theTensorElementType
part of a correspondingTensorType
. Ifx
is a value or placeholder, this function is a shortcut formember_name(type(x))
. Ifx
is not a type that has an appropriate member, or a value or a placeholder of such a type, returnsNone
.is_empty_algorithm(*args: Type)
checks if all dot algorithm fields are set toNone
. This is needed since dot algorithms have implementation defined default behaviors, so specifying a default value would be incorrect.
Construction of values
-
operation_name(*xs: Value | Type) -> Value
. Available for all operations. For example,add(lhs, rhs)
takes two tensor valueslhs
andrhs
and returns the output of evaluating theadd
operation with these inputs. For some operations egbroadcast_in_dim
, types of their outputs are "load-bearing", ie needed to evaluate an operation. In this case, the function takes these types as arguments.
Functions on values
All Python's operators and functions are available. Eg both subscription and slicing notations from Python are available to index into tensors, quantized tensors and tuples.
to_destination_type(x: Value, destination_type: Type) -> Value
is defined on tensors and returns the converted value ofx
based on thetype(x)
anddestination_type
as follows:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
There is early discussion on merging convert
, uniform_quantize
and uniform_dequantize
operations ( #1576 ). After the merge we do not need the above function and can use the operation name for convert
instead.
is_nan(x: Value) -> Value
is defined on tensors and returnstrue
if all elements ofx
areNaN
orfalse
otherwise. Ifx
is not a tensor, returnsNone
.is_sorted(x: Value) -> Value
is defined on tensors and returnstrue
if elements ofx
are sorted in ascending order with respect to the ascending lexicographical order of their indices orfalse
otherwise. Ifx
is not a tensor, returnsNone
.is_unique(x: Value) -> Value
is defined on tensors and returnstrue
ifx
doesn't have duplicate elements orfalse
otherwise. Ifx
is not a tensor, returnsNone
.member_name(x: Value) -> Any
is defined for all member definitionsmember_name
of all values. For example,real_part(x)
returns theRealPart
part of a correspondingComplexConstant
. Ifx
is not a value that has an appropriate member, returnsNone
.same(x: Value) -> Value
is defined on tensors and returnstrue
if elements ofx
are all equal to each other orfalse
otherwise. If the tensor doesn't have elements, that counts as "all equal to each other", ie the function returnstrue
. Ifx
is not a tensor, returnsNone
.split(x: Value, num_results: Value, axis: Value) -> Value
is defined on tensors and returnsnum_results
slices ofx
along the axisaxis
. Ifx
is not a tensor ordim(x, axis) % num_results != 0
, returnsNone
.is_defined_in_parent_scope(x: Value) -> Value
is defined on strings and returnstrue
ifx
is the name of a function defined in the same scope as the parent function of the relevant op.is_namespaced_op_name(x: Value) -> Value
is defined on strings and returnstrue
ifx
is a valid op name, that is it respects the following regular expression:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Shape computations
axes(x: Value | Placeholder | Type) -> Value
is a shortcut forrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
is a shortcut forshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
is a shortcut forlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
is defined on tensors and returnssize(x)
indices for the correspondingTensorType
sorted in ascending lexicographical order, ie[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Ifx
is not a tensor type, a quantized tensor type, or a value or a placeholder of one of these types, returnsNone
.rank(x: Value | Placeholder | Type) -> Value
is a shortcut forsize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
is defined in the "Functions on types" section viamember_name
.size(x: Value | Placeholder | Type) -> Value
is a shortcut forreduce(lambda x, y: x * y, shape(x))
.
Quantization computations
def baseline_element_type(x: Value | Placeholder | Type) -> Type
is a shortcut forelement_type(baseline_type(x))
.baseline_type
is defined on tensor types and quantized tensor types and transforms them to a "baseline", ie a type with the same shape but with the quantization parameters of the element type reset to default values. This is used as a handy trick to compare both tensor and quantized tensor types uniformly, which is needed quite often. For quantized types, this enables comparing types ignoring the quantization parameters, that is,shape
,storage_type
,expressed_type
,storage_min
,storage_max
, andquantization_dimension
(for per-axis quantized type) must all match, butscales
andzero points
may differ.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
-
dequantize
is defined on quantized tensor types and turns them into floating-point tensor types. This happens via converting quantized elements which represent integer values of the storage type into corresponding floating-point values of the expressed type using the zero point and scale associated with the quantized element type.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
-
quantize
is defined on floating-point tensor types and turns them into quantized tensor types. This happens via converting floating-point values of the expressed type into corresponding integer values of the storage type using the zero point and scale associated with the quantized element type.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
-
dequantize_op_quantize
is used to specify element-wise computations on quantized tensors. It dequantizes, ie turns quantized elements into their expressed types, then performs an operation, and then quantizes, ie turns the results back into their storage types. At the moment, this function only works for per-tensor quantization. Per-axis quantization is work in progress ( #1574 ).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
-
hybrid_dequantize_then_op
is used to specify weight-only quantization for hybrid op which accepts lhs in floating-point and rhs in quantized types. It dequantizes quantized inputs into their expressed types and performs computation in float. Element type of float lhs tensor and expressed type of quantized rhs tensor should be identical.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Grid computations
cross_partition(replica_groups: Value) -> Value
. See the "cross_replica" section above.cross_replica(replica_groups: Value) -> Value
. See the "cross_replica" section above.cross_replica_and_partition(replica_groups: Value) -> Value
. See the "cross_replica_and_partition" section above.flattened_ids(replica_groups: Value) -> Value
. See the "flattened_ids" section above.
Dynamism
StableHLO values can have dynamic dimension sizes, eg tensor<?xi64>
. However, StableHLO values cannot have a dynamic number of dimensions (unranked dynamism, eg tensor<*xi64>
). Operands and results are allowed to use dynamic dimension sizes, even if there are constraints on the sizes. Constraints will be verified statically if possible, otherwise they are deferred to runtime and mismatches will result in undefined behavior. See below for examples.
Shape mismatches for unary elementwise operations
Consider the following toy program:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Such a program is unusual, because it is not common to know the shape of the result but not the shape of the input. Nonetheless, this is a valid StableHLO program. It is not possible to statically validate the abs
operation in this program, because the exact shape of the operand is unknown. However, the shapes are certainly compatible, and this can be checked statically: ?
could turn out to be 2
at runtime, and there would be no issue. Однако, ?
could also turn out to be some other integer, in which case the behavior is undefined.
Note that if a dimension size is dynamic in the result, there cannot be undefined behavior. Indeed, there is no "expected" size, so there cannot be a mismatch.
Shape mismatches for binary elementwise operations
Consider the following toy program:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
When it comes to binary elementwise operations, the shapes of the inputs and the result must agree at runtime. At compile time, static dimensions must be equal, otherwise they merely need to be compatible. If any dimension is dynamic in the inputs, then there could be undefined behavior at runtime, because the dynamic size may not match the corresponding size in the other operand (be it static or dynamic). If all the inputs are static, then whether the result is dynamic or not does not matter: statically known dimensions will be checked statically, and dynamic dimensions do not impose any constraints.
Shape mismatches for ops that take their output shape as an operand
Consider the following toy program:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
The values in the shape operand at runtime must match the shape of the result, otherwise the behavior is undefined. That is, at runtime %arg0
must have a value of dense<[3, 4]> : tensor<2xi32>
. If the shape operand is constant, this can be verified statically. If the result shape is fully dynamic, then there cannot be a mismatch.
StableHLO — это набор операций для операций высокого уровня (HLO) в моделях машинного обучения (ML). StableHLO работает как уровень переносимости между различными платформами машинного обучения и компиляторами машинного обучения: платформы машинного обучения, создающие программы StableHLO, совместимы с компиляторами машинного обучения, которые используют программы StableHLO.
Наша цель — упростить и ускорить разработку машинного обучения за счет большей совместимости между различными платформами машинного обучения (такими как TensorFlow, JAX и PyTorch) и компиляторами машинного обучения (такими как XLA и IREE). С этой целью в этом документе представлена спецификация языка программирования StableHLO.
Данная спецификация содержит три основных раздела. Во-первых, раздел «Программы» описывает структуру программ StableHLO, которые состоят из функций StableHLO, которые сами состоят из операций StableHLO. В этой структуре раздел Ops определяет семантику отдельных операций. Раздел «Выполнение» предоставляет семантику для всех этих операций, выполняемых вместе в программе. Наконец, в разделе «Обозначения» обсуждаются обозначения, используемые в спецификации.
Чтобы просмотреть спецификацию предыдущего выпуска StableHLO, откройте репозиторий интересующего выпуска с тегом . Например, StableHLO v0.19.0 Spec . Чтобы просмотреть изменения, произошедшие при каждом второстепенном обновлении версии StableHLO, обратитесь к журналу версий в VhloDialect.td .
Программы
Program ::= {Func}
Программы StableHLO состоят из произвольного количества функций StableHLO. Ниже приведен пример программы с функцией @main
, которая имеет 3 входа ( %image
, %weights
и %bias
) и 1 выход. Тело функции имеет 6 операций.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Функции
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Функции StableHLO (которые также называются именованными функциями ) имеют идентификатор, входы/выходы и тело. В будущем мы планируем ввести дополнительные метаданные для функций для достижения лучшей совместимости с HLO ( #425 , #626 , #740 , #744 ).
Идентификаторы
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Идентификаторы StableHLO похожи на идентификаторы во многих языках программирования, но имеют две особенности: 1) все идентификаторы имеют символы, которые различают разные типы идентификаторов, 2) идентификаторы значений могут быть полностью числовыми, чтобы упростить создание программ StableHLO.
Типы
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Типы StableHLO подразделяются на типы значений (которые также называются типами первого класса ), которые представляют значения StableHLO, и типы, не являющиеся значениями , которые описывают другие элементы программы. Типы StableHLO похожи на типы во многих языках программирования, при этом основной особенностью является предметно-ориентированный характер StableHLO, что приводит к некоторым необычным результатам (например, скалярные типы не являются типами значений).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Тензорные типы представляют собой тензоры, т.е. многомерные массивы. У них есть форма и тип элемента , где форма представляет собой неотрицательные или неизвестные размеры размеров в порядке возрастания соответствующих размеров (которые также называются осями ), пронумерованных от 0
до R-1
. Число измерений R
называется рангом . Например, tensor<2x3xf32>
— это тип тензора с формой 2x3
и типом элемента f32
. Он имеет два измерения (или, другими словами, две оси) — 0-е измерение и 1-е измерение — размеры которых равны 2 и 3. Его ранг равен 2.
Формы могут быть частично или полностью неизвестными (динамическими), например, tensor<?x2xf64>
частично неизвестен, а tensor<?x?xf64>
полностью неизвестен. Размеры динамических размеров обозначаются знаком ?
. Фигуры не могут быть лишены ранжирования.
В будущем мы планируем изучить расширение типов тензоров за пределы размеров размеров и типов элементов, например, включив макеты ( #629 ) и разреженность ( #1078 ).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Имя | Тип | Ограничения |
---|---|---|
storage_type | целочисленный тип | (С1-С3), (С8) |
storage_min | целочисленная константа | (С1), (С3), (С7) |
storage_max | целочисленная константа | (С2), (С3), (С7) |
expressed_type | тип с плавающей запятой | (С4) |
quantization_dimension | необязательная целочисленная константа | (С10-С12) |
scales | вариативное число констант с плавающей запятой | (С4-С6), (С9), (С10), (С13) |
zero_points | вариативное число целочисленных констант | (С7-С9) |
Типы квантованных элементов представляют собой целочисленные значения типа хранения в диапазоне от storage_min
до storage_max
(включительно), которые соответствуют значениям с плавающей запятой выраженного типа . Для данного целочисленного значения i
соответствующее значение f
с плавающей запятой может быть вычислено как f = (i - zero_point) * scale
, где scale
и zero_point
называются параметрами квантования . storage_min
и storage_max
не являются обязательными в грамматике, но имеют значения по умолчанию min_value(storage_type)
и max_value(storage_type)
соответственно. Типы квантованных элементов имеют следующие ограничения:
- (C1)
type(storage_min) = storage_type
. - (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Если
is_empty(quantization_dimension)
, тоsize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
На данный момент QuantizationScale
представляет собой константу с плавающей запятой, но существует большой интерес к целочисленным шкалам, представленным множителями и сдвигами. Мы планируем изучить это в ближайшем будущем ( #1404 ).
Продолжается обсуждение семантики QuantizationZeroPoint
, включая тип, значения и может ли быть только одна или потенциально несколько нулевых точек в типе квантованного тензора. По результатам этого обсуждения спецификация нулевых точек может измениться в будущем ( #1405 ).
Другое продолжающееся обсуждение касается семантики QuantizationStorageMin
и QuantizationStorageMax
, чтобы определить, следует ли налагать какие-либо ограничения на эти значения и на значения квантованных тензоров ( #1406 ).
Наконец, мы планируем изучить представление неизвестных масштабов и нулевых точек аналогично тому, как мы планируем изучить представление неизвестных размеров измерений ( #1407 ).
Типы квантованных тензоров представляют собой тензоры с квантованными элементами. Эти тензоры точно такие же, как и обычные тензоры, за исключением того, что их элементы имеют типы квантованных элементов вместо обычных типов элементов.
В квантованных тензорах квантование может быть потензорным , то есть иметь один scale
и zero_point
для всего тензора, или может быть поосевым , то есть иметь несколько scales
и zero_points
, одну пару на срез определенного измерения quantization_dimension
. Более формально, в тензоре t
с поосевым квантованием существуют dim(t, quantization_dimension)
срезы quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
и т. д. Все элементы в i
-м срезе используют scales[i]
и zero_points[i]
в качестве параметров квантования. Типы квантованных тензоров имеют следующие ограничения:
- Для потензорного квантования:
- Никаких дополнительных ограничений.
- Для поосевого квантования:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Типы токенов представляют собой токены, т.е. непрозрачные значения, создаваемые и потребляемые некоторыми операциями. Токены используются для установления порядка выполнения операций, как описано в разделе «Выполнение» .
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Типы кортежей представляют собой кортежи, т. е. гетерогенные списки. Кортежи — это устаревшая функция, которая существует только для совместимости с HLO. В HLO кортежи используются для представления переменных входных и выходных данных. В StableHLO изначально поддерживаются переменные входные и выходные данные, и единственное использование кортежей в StableHLO — это всестороннее представление HLO ABI, где, например, T
, tuple<T>
и tuple<tuple<T>>
могут существенно отличаться в зависимости от конкретной реализации. . В будущем мы планируем внести изменения в HLO ABI, которые могут позволить нам удалить типы кортежей из StableHLO ( #598 ).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Типы элементов представляют собой элементы тензорных типов. В отличие от многих языков программирования, эти типы не являются первоклассными в StableHLO. Это означает, что программы StableHLO не могут напрямую представлять значения этих типов (в результате идиоматично представлять скалярные значения типа T
с помощью 0-мерных тензорных значений типа tensor<T>
).
- Тип Boolean представляет логические значения
true
иfalse
. - Целочисленные типы могут быть знаковыми (
si
) или беззнаковыми (ui
) и иметь одну из поддерживаемых разрядностей (2
,4
,8
,16
,32
или64
). Знаковые типыsiN
представляют целые значения от-2^(N-1)
до2^(N-1)-1
включительно, а беззнаковые типыuiN
представляют целочисленные значения от0
до2^N-1
включительно. - Типы с плавающей запятой могут быть одним из следующих:
-
f8E3M4
,f8E4M3
иf8E5M2
8-битные числа с плавающей запятой в соответствии с соглашениями IEEE-754. - Типы
f8E4M3FN
иf8E5M2
соответствующие соответственно кодировкамE4M3
иE5M2
формата FP8, описанным в разделе «Форматы FP8 для глубокого обучения» . - Типы
f8E4M3FNUZ
иf8E5M2FNUZ
соответствующие кодировкамE4M3
иE5M2
форматов FP8, описанных в 8-битных числовых форматах для глубоких нейронных сетей . - Тип
f8E4M3B11FNUZ
, соответствующий кодировкеE4M3
форматов FP8, описанных в разделе «Обучение и вывод гибридных 8-битных чисел с плавающей запятой (HFP8) для глубоких нейронных сетей» . - Тип
bf16
, соответствующий форматуbfloat16
, описанному в BFloat16: Секрет высокой производительности на Cloud TPU . - Типы
f16
,f32
иf64
соответствующие соответственноbinary16
(«половинная точность»),binary32
(«одинарная точность»)binary64
(«двойная точность»), описанным в стандарте IEEE 754 . - Тип
tf32
соответствует формату TensorFloat32 и имеет ограниченную поддержку в StableHLO. - Типы
f4E2M1FN
,f6E2M3FN
,f6E3M2FN
иf8E8M0FNU
MX (микромасштабирование) описаны в Спецификации форматов микромасштабирования OCP .
-
- Сложные типы представляют собой комплексные значения, которые имеют действительную и мнимую части одного и того же типа элемента . Поддерживаемые сложные типы:
complex<f32>
(обе части имеют типf32
) иcomplex<f64>
(обе части имеют типf64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Типы функций представляют как именованные, так и анонимные функции. У них есть типы ввода (список типов в левой части ->
) и типы вывода (список типов в правой части ->
). Во многих языках программирования типы функций являются первоклассными, но не в StableHLO.
StringType ::= 'string'
Тип String представляет собой последовательность байтов. В отличие от многих языков программирования, строковый тип не является первым классом в StableHLO и используется только для указания статических метаданных для элементов программы.
Операции
Операции StableHLO (которые также называются ops ) представляют собой закрытый набор операций высокого уровня в моделях машинного обучения. Как обсуждалось выше, синтаксис StableHLO во многом основан на MLIR, который не обязательно является наиболее эргономичной альтернативой, но, возможно, лучше всего подходит для цели StableHLO по созданию большей совместимости между платформами ML и компиляторами ML.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Операции StableHLO (которые также называются ops ) имеют имя, входы/выходы и подпись. Название состоит из stablehlo.
префикс и мнемоника , которая однозначно идентифицирует одну из поддерживаемых операций. Ниже приведен полный список всех поддерживаемых операций.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Операционные операторы потребляют входные данные и производят выходные данные . Входные данные подразделяются на входные значения (вычисляемые во время выполнения), входные функции (предоставляемые статически, поскольку в StableHLO функции не являются значениями первого класса) и входные атрибуты (также предоставляемые статически). Вид входных и выходных данных, потребляемых и производимых операцией, зависит от ее мнемоники. Например, операция add
потребляет 2 входных значения и производит 1 выходное значение. Для сравнения, операция select_and_scatter
использует 3 входных значения, 2 входные функции и 3 входных атрибута.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Функции ввода (которые также называются анонимными функциями ) очень похожи на именованные функции, за исключением того, что: 1) они не имеют идентификатора (отсюда и название «анонимные»), 2) они не объявляют типы вывода (типы вывода выводится из операции return
внутри функции).
Синтаксис функций ввода включает в себя неиспользуемую в настоящее время часть (см. раздел « Unused
» выше), которая предназначена для совместимости с MLIR. В MLIR существует более общая концепция «регионов», которая может состоять из нескольких «блоков» операций, соединенных вместе посредством операций перехода. Эти блоки имеют идентификаторы, соответствующие Unused
продукции, чтобы их можно было отличить друг от друга. В StableHLO нет прыжковых операций, поэтому соответствующая часть синтаксиса MLIR не используется (но все еще существует).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Входные атрибуты имеют имя и значение, которое является одной из поддерживаемых констант. Они являются основным способом указания статических метаданных для элементов программы. Например, операция concatenate
использует dimension
атрибута, чтобы указать измерение, по которому объединяются его входные значения. Аналогично, операция slice
использует несколько атрибутов, таких как start_indices
и limit_indices
для указания границ, которые используются для среза входного значения.
На данный момент существующие программы StableHLO иногда содержат атрибуты, не описанные в этом документе. В будущем мы планируем либо включить эти атрибуты в опсет StableHLO, либо запретить их появление в программах StableHLO. А пока вот список этих атрибутов:
-
layout
( #629 ). -
mhlo.frontend_attributes
( #628 ). -
mhlo.sharding
( #619 ). -
output_operand_aliases
( #740 ). - Метаданные местоположения ( #594 ).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
Сигнатура операции состоит из типов всех входных значений (список типов в левой части ->
) и типов всех выходных значений (список типов в правой части ->
). Строго говоря, входные типы избыточны, и выходные типы также почти всегда избыточны (поскольку для большинства операций StableHLO типы выходных данных могут быть выведены из входных данных). Тем не менее, подпись op намеренно является частью синтаксиса StableHLO для совместимости с MLIR.
Ниже приведен пример операции, мнемоника которой — select_and_scatter
. Он использует 3 входных значения ( %operand
, %source
и %init_value
), 2 входные функции и 3 входных атрибута ( window_dimensions
, window_strides
и padding
). Обратите внимание, что подпись операции включает только типы ее входных значений (но не типы входных функций и атрибутов, которые предоставляются в строке).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Константы
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Константы StableHLO имеют литерал и тип, которые вместе представляют значение StableHLO. Обычно тип является частью синтаксиса константы, за исключением случаев, когда он однозначен (например, логическая константа однозначно имеет тип i1
, тогда как целочисленная константа может иметь несколько возможных типов).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Булевы константы представляют логические значения true
и false
. Булевы константы имеют тип i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Целочисленные константы представляют целочисленные значения посредством строк, в которых используется десятичная или шестнадцатеричная запись. Другие системы счисления, например двоичная или восьмеричная, не поддерживаются. Целочисленные константы имеют следующие ограничения:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Константы с плавающей запятой представляют значения с плавающей запятой в виде строк, в которых используется десятичная или экспоненциальная запись. Кроме того, шестнадцатеричная запись может использоваться для непосредственного указания базовых битов в формате с плавающей запятой соответствующего типа. Константы с плавающей запятой имеют следующие ограничения:
- (C1) Если используется нешестнадцатеричная система записи,
is_wellformed(float_literal, float_type)
. - (C2) Если используется шестнадцатеричная запись,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Комплексные константы представляют комплексные значения с использованием списков вещественной части (идет первой) и мнимой части (идет второй). Например, (1.0, 0.0) : complex<f32>
представляет 1.0 + 0.0i
, а (0.0, 1.0) : complex<f32>
представляет 0.0 + 1.0i
. Порядок, в котором эти части затем сохраняются в памяти, определяется реализацией. Комплексные константы имеют следующие ограничения:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Тензорные константы представляют значения тензора с использованием вложенных списков, заданных с помощью нотации NumPy. Например, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
представляет значение тензора со следующим сопоставлением индексов с элементами: {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
, {1, 2} => 6
. Порядок, в котором эти элементы затем сохраняются в памяти, определяется реализацией. Тензорные константы имеют следующие ограничения:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, где:-
has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
. -
has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
-
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, где:-
has_shape(element_literal: Syntax, []) = true
. -
has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
. - в противном случае
false
.
-
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Квантованные тензорные константы представляют квантованные тензорные значения с использованием тех же обозначений, что и тензорные константы, с элементами, заданными как константы их типа хранения. Квантованные тензорные константы имеют следующие ограничения:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Строковые литералы состоят из байтов, заданных с помощью символов ASCII и escape-последовательностей. Они не зависят от кодировки, поэтому интерпретация этих байтов определяется реализацией. Строковые литералы имеют тип string
.
Операции
пресс
Семантика
Выполняет поэлементную операцию abs над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для целых чисел со знаком: целочисленный модуль.
- Для поплавков:
abs
из IEEE-754. - Для комплексных чисел: комплексный модуль.
- Для квантованных типов:
dequantize_op_quantize(abs, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целого числа со знаком, с плавающей запятой или комплексного типа или по-тензорный квантованный тензор | (С1-С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целого числа со знаком или типа с плавающей запятой или потензорный квантованный тензор | (С1-С2) |
Ограничения
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
определяется как:-
complex_element_type(element_type(operand))
ifis_complex(operand)
. -
baseline_element_type(operand)
в противном случае.
-
Примеры
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
добавлять
Семантика
Выполняет поэлементное сложение двух тензоров lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое ИЛИ.
- Для целых чисел: сложение целых чисел.
- Для поплавков:
addition
из IEEE-754. - Для комплексных чисел: комплексное сложение.
- Для квантованных типов:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор или квантованный тензор | (С1-С6) |
(И2) | rhs | тензор или квантованный тензор | (С1-С5), (С7) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1-С7) |
Ограничения
- Если в операции используются неквантованные тензоры:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Если в операции используются квантованные тензоры:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Если
is_per_axis_quantized(lhs)
, тоquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Если
is_per_axis_quantized(rhs)
, тоquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Примеры
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
после всего
Семантика
Гарантирует, что операции, производящие inputs
, выполняются до выполнения любых операций, зависящих от result
. Выполнение этой операции ничего не делает, она существует только для того, чтобы установить зависимости данных от result
до inputs
.
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | inputs | вариативное число token |
Выходы
Имя | Тип |
---|---|
result | token |
Примеры
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
все_собрать
Семантика
Внутри каждой группы процессов в сетке процессов StableHLO объединяет значения тензоров operands
каждого процесса по all_gather_dim
и создает тензоры results
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
еслиchannel_id <= 0 and use_global_device_ids = false
. -
cross_replica_and_partition(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = false
. -
flattened_ids(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = true
.
Затем внутри каждой process_group
:
-
operands...@receiver = [operand@sender for sender in process_group]
для всехreceiver
вprocess_group
. -
results...@process = concatenate(operands...@process, all_gather_dim)
для всехprocess
вprocess_group
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operands | вариативное число тензоров или потензорные квантованные тензоры | (С1), (С6) |
(И2) | all_gather_dim | константа типа si64 | (С1), (С6) |
(И3) | replica_groups | 2-мерная тензорная константа типа si64 | (С2-С4) |
(И4) | channel_id | константа типа si64 | (С5) |
(И5) | use_global_device_ids | константа типа i1 | (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С6) |
Ограничения
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_replicas
если используетсяcross_replica_and_partition
. -
num_processes
если используетсяflattened_ids
.
-
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Если
use_global_device_ids = true
, тоchannel_id > 0
. - (C6)
type(results...) = type(operands...)
за исключением:-
dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
-
Примеры
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Семантика
В каждой группе процессов в сетке процессов StableHLO применяет computation
функции сокращения к значениям тензоров operands
каждого процесса и создает тензоры results
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
еслиchannel_id <= 0 and use_global_device_ids = false
. -
cross_replica_and_partition(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = false
. -
flattened_ids(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = true
.
Затем внутри каждой process_group
:
-
results...@process[result_index] = exec(schedule)
для некоторогоschedule
двоичного дерева, где:-
exec(node)
=computation(exec(node.left), exec(node.right))
. -
exec(leaf)
=leaf.value
.
-
-
schedule
— это двоичное дерево, определяемое реализацией, порядок обхода которого равенto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operands | вариативное число тензоров или потензорные квантованные тензоры | (С5), (С6) |
(И2) | replica_groups | вариатическое число одномерных тензорных констант типа si64 | (С1-С3) |
(И3) | channel_id | константа типа si64 | (С4) |
(И4) | use_global_device_ids | константа типа i1 | (С4) |
(И5) | computation | функция | (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С6-С7) |
Ограничения
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_replicas
если используетсяcross_replica_and_partition
. -
num_processes
если используетсяflattened_ids
.
-
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Если
use_global_device_ids = true
, тоchannel_id > 0
. - (C5)
computation
имеет тип(tensor<E>, tensor<E>) -> (tensor<E>)
гдеis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Примеры
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Семантика
Внутри каждой группы процессов в сетке процессов StableHLO значения тензоров operands
разделяются по split_dimension
на части, распределяются разделенные части между процессами, объединяются разбросанные части по concat_dimension
и выдаются тензоры results
. Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
, еслиchannel_id <= 0
. -
cross_partition(replica_groups)
еслиchannel_id > 0
.
Затем внутри каждой process_group
:
-
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
для всехsender
вprocess_group
. -
scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
, гдеreceiver_index = process_group.index(receiver)
. -
results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operands | вариативное число тензоров или потензорные квантованные тензоры | (С1-С3), (С9) |
(И2) | split_dimension | константа типа si64 | (С1), (С2), (С9) |
(И3) | concat_dimension | константа типа si64 | (С3), (С9) |
(И4) | split_count | константа типа si64 | (С2), (С4), (С8), (С9) |
(И5) | replica_groups | 2-мерная тензорная константа типа si64 | (С5-С8) |
(И6) | channel_id | константа типа si64 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С9) |
Ограничения
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_partitions
если используетсяcross_partition
.
-
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
кроме случая, когдаsplit_dimension != concat_dimension
:-
dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
. -
dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
-
Примеры
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
и
Семантика
Выполняет поэлементное «И» двух тензоров lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое И.
- Для целых чисел: побитовое И.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор логического или целочисленного типа | (С1) |
(И2) | rhs | тензор логического или целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор логического или целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
Атан2
Семантика
Выполняет поэлементную операцию atan2 над тензорами lhs
и rhs
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
atan2
из IEEE-754. - Для комплексных чисел: комплекс atan2.
- Для квантованных типов:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
пакет_норм_град
Семантика
Вычисляет градиенты нескольких входных данных batch_norm_training
с обратным распространением ошибки от grad_output
и создает тензоры grad_operand
, grad_scale
и grad_offset
. Более формально эту операцию можно выразить как декомпозицию существующих операций StableHLO с использованием синтаксиса Python следующим образом:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Для квантованных типов выполняет dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1-С3), (С5) |
(И2) | scale | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4), (С5) |
(И3) | mean | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
(И4) | variance | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
(И5) | grad_output | тензор типа с плавающей запятой или потензорный квантованный тензор | (С2), (С3) |
(И6) | epsilon | константа типа f32 | |
(I7) | feature_index | константа типа si64 | (С1), (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
grad_operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С2), (С3) |
grad_scale | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
grad_offset | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
Ограничения
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
иgrad_offset
имеют одинаковыйbaseline_element_type
. - (C3)
operand
,grad_output
иgrad_operand
имеют одинаковую форму. - (C4)
scale
,mean
,variance
,grad_scale
иgrad_offset
имеют одинаковую форму. - (C5)
size(scale) = dim(operand, feature_index)
.
Примеры
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
пакетная норма_inference
Семантика
Нормализует тензор operand
по всем измерениям, кроме измерения feature_index
, и создает result
тензор. Более формально эту операцию можно выразить как декомпозицию существующих операций StableHLO с использованием синтаксиса Python следующим образом:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Для квантованных типов выполняет dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1-С7) |
(И2) | scale | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С3) |
(И3) | offset | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С4) |
(И4) | mean | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С5) |
(И5) | variance | Одномерный тензор с плавающей запятой или потензорный квантованный тип | (С2), (С6) |
(И6) | epsilon | константа типа f32 | |
(I7) | feature_index | константа типа si64 | (С1), (С3-С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С2), (С7) |
Ограничения
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
иresult
имеют один и тот жеbaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
Batch_norm_training
Семантика
Вычисляет среднее значение и дисперсию по всем измерениям, за исключением измерения feature_index
, и нормализует тензор operand
, создавая output
, тензоры batch_mean
и batch_var
. Более формально эту операцию можно выразить как декомпозицию существующих операций StableHLO с использованием синтаксиса Python следующим образом:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Для квантованных типов выполняет dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
(И2) | scale | Одномерный тензор с плавающей запятой или потензорный квантователь | (С2), (С3) |
(И3) | offset | Одномерный тензор с плавающей запятой или потензорный квантователь | (С2), (С4) |
(И4) | epsilon | константа типа f32 | (С1), (С3-С6) |
(И5) | feature_index | константа типа si64 | (С1), (С3-С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output | тензор типа с плавающей запятой или потензорный квантованный тензор | (С7) |
batch_mean | Одномерный тензор с плавающей запятой или потензорный квантователь | (С2), (С5) |
batch_var | Одномерный тензор с плавающей запятой или потензорный квантователь | (С2), (С6) |
Ограничения
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
иoutput
имеют один и тот жеbaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Примеры
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Семантика
Выполняет операцию побитового преобразования тензора operand
и создает тензор result
, в котором биты всего тензора operand
переинтерпретируются с использованием типа тензора result
.
Более формально, учитывая E = element_type(operand)
, E' = element_type(result)
и R = rank(operand)
:
- Если
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Если
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Если
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
возвращает представление данного значения в памяти, и его поведение определяется реализацией, поскольку точное представление тензоров определяется реализацией, а точное представление типов элементов также определяется реализацией.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1-С2) |
Ограничения
- (C1) Учитывая
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
иR = rank(operand)
:- Если
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Если
num_bits(E') < num_bits(E)
: -
rank(result) = R + 1
. -
dim(result, i) = dim(operand, i)
для всех0 <= i < R
. -
dim(result, R) * num_bits(E') = num_bits(E)
. - Если
num_bits(E') > num_bits(E)
: -
rank(result) = R - 1
. -
dim(result, i) = dim(operand, i)
для всех0 <= i < R
. -
dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Если
- (C2) Если
is_complex(operand) or is_complex(result)
, тоis_complex(operand) and is_complex(result)
.
Примеры
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
Broadcast_in_dim
Семантика
Расширяет размеры и/или ранг входного тензора путем дублирования данных в тензоре operand
и создает result
тензор. Более формально, result[result_index] = operand[operand_index]
где для всех d
в axes(operand)
:
-
operand_index[d] = 0
еслиdim(operand, d) = 1
. -
operand_index[d] = result_index[broadcast_dimensions[d]]
в противном случае.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С2), (С5-С6) |
(И2) | broadcast_dimensions | 1-мерная тензорная константа типа si64 | (С2-С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1), (С3), (С5-С6) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
,scales(operand)
иzero_points(operand)
могут отличаться отquantization_dimension(result)
,scales(result)
иzero_points(result)
соответственно, в противном случае.
-
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Для всех
d
вaxes(operand)
:-
dim(operand, d) = 1
или -
dim(operand, d) = dim(result, broadcast_dimensions[d])
.
-
- (C6) Если
is_per_axis_quantized(result)
:-
quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
. - Если
dim(operand, quantization_dimension(operand)) = 1
, тоscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
-
Примеры
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
случай
Семантика
Производит результат выполнения ровно одной функции из branches
в зависимости от значения index
. Более формально, result = selected_branch()
где:
-
selected_branch = branches[index]
if0 <= index < size(branches)
. -
selected_branch = branches[-1]
в противном случае.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | index | 0-мерный тензор типа si32 | |
(И2) | branches | вариативное число функций | (С1-С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С4) |
Ограничения
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Примеры
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CBRT
Семантика
Выполняет поэлементную операцию кубического корня над тензором operand
и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
rootn(x, 3)
из IEEE-754. - Для комплексных чисел: комплексный кубический корень.
- Для квантованных типов:
dequantize_op_quantize(cbrt, operand, type(result))
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
клетка
Семантика
Выполняет поэлементную ячейку тензора operand
и создает тензор result
. Реализует операцию roundToIntegralTowardPositive
из спецификации IEEE-754. Для квантованных типов выполняется dequantize_op_quantize(ceil, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
холецкий
Семантика
Вычисляет разложение Холецкого пакета матриц.
Более формально, для всех i
в index_space(result)
, result[i0, ..., iR-3, :, :]
является разложением Холецкого a[i0, ..., iR-3, :, :]
, в виде нижне-треугольной (если lower
имеет true
) или верхнетреугольной (если lower
является false
) матрицы. Выходные значения в противоположном треугольнике, т.е. строгом верхнем треугольнике или строгом нижнем треугольнике соответственно, определяются реализацией.
Если существует i
, где входная матрица не является эрмитовой положительно определенной матрицей, то поведение не определено.
Для квантованных типов выполняется dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | a | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1-С3) |
(И2) | lower | 0-мерная тензорная константа типа i1 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Примеры
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
зажим
Семантика
Зажимает каждый элемент тензора operand
между минимальным и максимальным значением и создает result
тензор. Более формально, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element)
, где min_element = rank(min) = 0 ? min[] : min[result_index]
, max_element = rank(max) = 0 ? max[] : max[result_index]
. Для квантованных типов выполняется dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Наложение порядка на комплексные числа требует удивительной семантики, поэтому в будущем мы планируем удалить поддержку комплексных чисел для этой операции ( #560 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | min | тензорный или потензорный квантованный тензор | (С1), (С3) |
(И2) | operand | тензорный или потензорный квантованный тензор | (С1-С4) |
(И3) | max | тензорный или потензорный квантованный тензор | (С2), (С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С4) |
Ограничения
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
коллективное_вещание
Семантика
В каждой группе процессов в сетке процессов StableHLO отправьте значение тензора operand
из исходного процесса в целевые процессы и создайте тензор result
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
, еслиchannel_id <= 0
. -
cross_partition(replica_groups)
еслиchannel_id > 0
.
После этого result@process
определяется следующим образом:
-
operand@process_groups[i, 0]
если существуетi
такой, что процесс находится вprocess_groupsprocess_groups[i]
. -
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
иначе.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С3) |
(И2) | replica_groups | вариатическое число одномерных тензорных констант типа si64 | (С1), (С2) |
(И3) | channel_id | константа типа si64 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С3) |
Ограничения
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, гдеN
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_partitions
если используетсяcross_partition
.
-
- (C3)
type(result) = type(operand)
.
Примеры
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
коллективный_пермуте
Семантика
Внутри каждой группы процессов в сетке процессов StableHLO отправляет значение тензора operand
из исходного процесса в целевой процесс и создает тензор result
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(source_target_pairs)
, еслиchannel_id <= 0
. -
cross_partition(source_target_pairs)
, еслиchannel_id > 0
.
После этого result@process
определяется следующим образом:
-
operand@process_groups[i, 0]
, если существуетi
такой,process_groups[i, 1] = process
. -
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
иначе.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С5) |
(И2) | source_target_pairs | 2-мерная тензорная константа типа si64 | (С1-С4) |
(И3) | channel_id | константа типа si64 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, гдеN
определяется как:-
num_replicas
если используетсяcross_replica
. -
num_partitions
если используетсяcross_partition
.
-
- (C5)
type(result) = type(operand)
.
Примеры
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
сравнивать
Семантика
Выполняет поэлементное сравнение тензоров lhs
и rhs
в соответствии с comparison_direction
и compare_type
и создает result
тензор.
Значения comparison_direction
и compare_type
имеют следующую семантику:
Для логических и целочисленных типов элементов:
-
EQ
:lhs = rhs
. -
NE
:lhs != rhs
. -
GE
:lhs >= rhs
. -
GT
:lhs > rhs
. -
LE
:lhs <= rhs
. -
LT
:lhs < rhs
.
Для типов элементов с плавающей запятой с compare_type = FLOAT
оператор реализует следующие операции IEEE-754:
-
EQ
:compareQuietEqual
. -
NE
:compareQuietNotEqual
. -
GE
:compareQuietGreaterEqual
. -
GT
:compareQuietGreater
. -
LE
:compareQuietLessEqual
. -
LT
:compareQuietLess
.
Для типов элементов с плавающей запятой с compare_type = TOTALORDER
оператор использует комбинацию операций totalOrder
и compareQuietEqual
из IEEE-754.
Для сложных типов элементов лексикографическое сравнение пар (real, imag)
выполняется с использованием предоставленных comparison_direction
и compare_type
. Наложение порядка на комплексные числа требует удивительной семантики, поэтому в будущем мы планируем удалить поддержку комплексных чисел, когда comparison_direction
— GE
, GT
, LE
или LT
( #560 ).
Для квантованных типов. выполняет dequantize_compare(lhs, rhs, comparison_direction)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1-С3) |
(И2) | rhs | тензорный или потензорный квантованный тензор | (С1-С2) |
(И3) | comparison_direction | перечисление EQ , NE , GE , GT , LE и LT | |
(И4) | compare_type | перечисление FLOAT , TOTALORDER , SIGNED и UNSIGNED | (С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор логического типа | (С2) |
Ограничения
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
определяется как:-
SIGNED
, еслиis_signed_integer(element_type(lhs))
. -
UNSIGNED
, еслиis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
. -
FLOAT
илиTOTALORDER
, еслиis_float(element_type(lhs))
. -
FLOAT
, еслиis_complex(element_type(lhs))
.
-
Примеры
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
сложный
Семантика
Выполняет поэлементное преобразование в комплексное значение из пары действительных и мнимых значений, lhs
и rhs
, и создает result
тензор.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор типа f32 или f64 | (С1-С3) |
(И2) | rhs | тензор типа f32 или f64 | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор комплексного типа | (С2), (С3) |
Ограничения
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
имеет типcomplex<E>
, гдеE = element_type(lhs)
.
Примеры
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
композитный
Семантика
Инкапсулирует операцию, состоящую из других операций StableHLO, принимающую inputs
и composite_attributes
и выдающую results
. Семантика операции реализуется атрибутом decomposition
. composite
операцию можно заменить ее декомпозицией без изменения семантики программы. В тех случаях, когда встраивание декомпозиции не обеспечивает одинаковую семантику операций, предпочтительнее использовать custom_call
.
Поле version
(по умолчанию 0
) используется для обозначения изменения семантики композиции.
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | inputs | вариативное число значений |
(И2) | name | константа типа string |
(И3) | composite_attributes | словарь атрибутов |
(И4) | decomposition | константа типа string |
(И5) | version | константа типа si32 |
Выходы
Имя | Тип |
---|---|
results | вариативное число значений |
Ограничения
- (C1)
is_namespaced_op_name(name)
- (C2)
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Примеры
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
объединять
Семантика
Объединяет inputs
по dimension
в том же порядке, что и заданные аргументы, и создает result
тензор. Более формально, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, где:
-
id = d0 + ... + dk-1 + kd
. -
d
равноdimension
, аd0
,... — этоd
-е размерностьinputs
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (С1-С6) |
(И2) | dimension | константа типа si64 | (С2), (С4), (С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С5-С6) |
Ограничения
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
за исключениемdim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
за исключением:-
dim(result, dimension) = dim(inputs[0], dimension) + ...
.
-
Примеры
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
постоянный
Семантика
Создает output
тензор из постоянного value
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | value | постоянный | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output | тензор или квантованный тензор | (С1) |
Ограничения
- (C1)
type(value) = type(output)
.
Примеры
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
конвертировать
Семантика
Выполняет поэлементное преобразование одного типа элемента в другой в тензоре operand
и создает тензор result
.
Для преобразования логического значения в любой поддерживаемый тип значение false
преобразуется в ноль, а значение true
преобразуется в единицу. Для любого поддерживаемого преобразования типа в логическое значение нулевое значение преобразуется в false
, а ненулевые значения преобразуются в true
. Ниже показано, как это работает для сложных типов.
Для преобразований, включающих целое число в целое число , целое число в число с плавающей запятой или число с плавающей запятой в число с плавающей запятой , если исходное значение может быть точно представлено в целевом типе, значение результата будет таким же точным представлением. В противном случае поведение будет определено позднее ( #180 ).
Для преобразований с плавающей запятой в целое число дробная часть усекается. Если усеченное значение не может быть представлено в целевом типе, поведение будет определено позднее ( #180 ).
Преобразование комплексного значения в комплексное происходит так же, как и преобразование чисел с плавающей запятой при преобразовании действительных и мнимых частей.
Для преобразований «комплексный тип в любой другой» и «любой другой тип в комплекс» исходное мнимое значение игнорируется или конечное мнимое значение обнуляется соответственно. Преобразование вещественной части следует за преобразованиями с плавающей запятой.
В принципе, эта операция могла бы выражать деквантование (преобразование квантованных тензоров в регулярные тензоры), квантование (преобразование регулярных тензоров в квантованные тензоры) и реквантование (преобразование между квантованными тензорами), но на данный момент для этого у нас есть специальные операции — uniform_dequantize
для первый вариант использования и uniform_quantize
для второго и третьего вариантов использования. В будущем эти две операции могут быть объединены в convert
( #1576 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор | (С1) |
Ограничения
- (C1)
shape(operand) = shape(result)
.
Примеры
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
свертка
Семантика
Вычисляет скалярное произведение между окнами lhs
и срезами rhs
и выдает result
. На следующей диаграмме показано, как элементы result
вычисляются на основе lhs
и rhs
на конкретном примере.
Более формально, рассмотрим следующее переформулирование входных данных в терминах lhs
чтобы иметь возможность выражать окна lhs
:
-
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
. -
lhs_window_strides = lhs_shape(1, window_strides, 1)
. -
lhs_padding = lhs_shape([0, 0], padding, [0, 0])
. -
lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
. -
lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
В этом рефрейминге используются следующие вспомогательные функции:
-
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
. -
result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
. -
permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
гдеj[d] = i[permutation[d]]
.
Если feature_group_count = 1
и batch_group_count = 1
, то для всех output_spatial_index
в index_space(dim(result, output_spatial_dimensions...))
, result[result_shape(:, output_spatial_index, :)] = dot_product
где:
-
padding_value = constant(0, element_type(lhs))
. -
padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
. -
lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
. -
lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
. -
reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Эта функция, похоже, не используется, поэтому в будущем мы планируем ее удалить ( #1181 ). -
dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Если feature_group_count > 1
:
-
lhses = split(lhs, feature_group_count, input_feature_dimension)
. -
rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
. -
results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
. -
result = concatenate(results, output_feature_dimension)
.
Если batch_group_count > 1
:
-
lhses = split(lhs, batch_group_count, input_batch_dimension)
. -
rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
. -
results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
. -
result = concatenate(results, output_feature_dimension)
.
For quantized types, performs dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result))
.
For hybrid quantized types, performs hybrid_dequantize_then_op( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1), (С10-С11), (С14) (С25), (С27-С28), (С31-С32), (С34) |
(И2) | rhs | тензор или квантованный тензор | (С1), (С14-С16), (С25), (С27-С29), (С31-С34) |
(И3) | window_strides | 1-мерная тензорная константа типа si64 | (С2-С3), (С25) |
(И4) | padding | 2-мерная тензорная константа типа si64 | (С4), (С25) |
(И5) | lhs_dilation | 1-мерная тензорная константа типа si64 | (С5-С6), (С25) |
(И6) | rhs_dilation | 1-мерная тензорная константа типа si64 | (С7-С8), (С25) |
(I7) | window_reversal | 1-мерная тензорная константа типа i1 | (С9) |
(И8) | input_batch_dimension | константа типа si64 | (С10), (С13), (С25) |
(I9) | input_feature_dimension | константа типа si64 | (С11), (С13-С14) |
(I10) | input_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С12), (С13), (С25) |
(I11) | kernel_input_feature_dimension | константа типа si64 | (С14), (С18) |
(I12) | kernel_output_feature_dimension | константа типа si64 | (С15-С16), (С18), (С25), (С29) |
(I13) | kernel_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С17-С18), (С25) |
(I14) | output_batch_dimension | константа типа si64 | (С20), (С25) |
(I15) | output_feature_dimension | константа типа si64 | (С20), (С25), (С30) |
(I16) | output_spatial_dimensions | 1-мерная тензорная константа типа si64 | (С19-С20), (С25) |
(I17) | feature_group_count | константа типа si64 | (С11), (С14), (С16), (С21), (С23) |
(I18) | batch_group_count | константа типа si64 | (С10), (С15), (С22), (С23), (С25) |
(I19) | precision_config | вариативное количество перечислений DEFAULT , HIGH и HIGHEST | (С24) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С25-С28), (С30), (С32-34) |
Ограничения
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Учитывая
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:-
is_unique(input_dimensions)
. -
0 <= input_dimensions < N
.
-
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Учитывая
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:-
is_unique(kernel_dimensions)
. -
0 <= kernel_dimensions < N
.
-
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Учитывая
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:-
is_unique(output_dimensions)
. -
0 <= output_dimensions < N
.
-
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
определяется как:-
dim(lhs, input_batch_dimension) / batch_group_count
еслиresult_dim = output_batch_dimension
. -
dim(rhs, kernel_output_feature_dimension)
еслиresult_dim = output_feature_dimension
. -
num_windows
в противном случае, где: -
output_spatial_dimensions[spatial_dim] = result_dim
. -
lhs_dim = input_spatial_dimensions[spatial_dim]
. -
rhs_dim = kernel_spatial_dimensions[spatial_dim]
. -
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
. -
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
. -
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
. -
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
. -
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
-
- (C26)
rank(result) = N
. - Если в операции используются неквантованные тензоры:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Если в операции используются квантованные тензоры:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Если
is_per_axis_quantized(rhs)
, тоquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Если
is_per_axis_quantized(result)
, тоquantization_dimension(result) = output_feature_dimension
. - Если
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Если
is_per_tensor_quantized(rhs)
, тоis_per_tensor_quantized(result)
. - Если
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Примеры
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
косинус
Семантика
Выполняет поэлементную операцию косинуса над тензором operand
и создает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для чисел с плавающей запятой:
cos
из IEEE-754. - Для комплексных чисел: комплексный косинус.
- Для квантованных типов:
dequantize_op_quantize(cosine, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Семантика
Выполняет поэлементный подсчет количества ведущих нулевых битов в тензоре operand
и создает result
тензор.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(operand) = type(result)
.
Примеры
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Семантика
Инкапсулирует определяемую реализацией операцию call_target_name
, которая принимает inputs
и called_computations
и выдает results
. has_side_effect
, backend_config
и api_version
могут использоваться для предоставления дополнительных метаданных, определяемых реализацией.
На данный момент эта операция содержит довольно неорганизованный набор метаданных, отражающий органическое развитие аналогичной операции в компиляторе XLA. В будущем мы планируем унифицировать эти метаданные ( #741 ).
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | inputs | вариативное число значений |
(И2) | call_target_name | константа типа string |
(И3) | has_side_effect | константа типа i1 |
(И4) | backend_config | константа типа string или словаря атрибутов |
(И5) | api_version | константа типа si32 |
(И6) | called_computations | вариативное число констант типа string |
Выходы
Имя | Тип |
---|---|
results | вариативное число значений |
Примеры
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
разделять
Семантика
Выполняет поэлементное деление тензоров делимого lhs
и rhs
делителя и создает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для целых чисел: целочисленное деление, которое дает алгебраическое частное с отбрасыванием любой дробной части.
- Для поплавков:
division
от IEEE-754. - Для комплексных чисел: комплексное деление.
- Для квантованных типов:
-
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Семантика
Вычисляет скалярное произведение между срезами lhs
и rhs
и создает тензор result
.
Более формально, result[result_index] = dot_product
, где:
-
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
. -
rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
. -
result_batching_index + result_lhs_index + result_rhs_index = result_index
гдеsize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
иsize(result_rhs_index) = size(rhs_result_dimensions)
. -
transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
. -
transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
. -
reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
. -
transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
. -
transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
. -
reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
. -
dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Для квантованных типов выполняет dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Для гибридных квантованных типов выполняется hybrid_dequantize_then_op( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
управляет компромиссом между скоростью и точностью вычислений на серверных компонентах ускорителя. Это может быть одно из следующих (на данный момент семантика этих значений перечисления недостаточно определена, но мы планируем решить эту проблему в #755 ):
-
DEFAULT
: Самый быстрый расчет, но наименее точное приближение к исходному числу. -
HIGH
: более медленный расчет, но более точное приближение к исходному числу. -
HIGHEST
: Самый медленный расчет, но наиболее точное приближение к исходному числу.
DotAlgorithm
определяет основные свойства алгоритма, используемого для реализации операции с точкой, что также определяет точность. Если поля атрибутов алгоритма установлены, то для precision_config
должно быть DEFAULT
. DotAlgorithms
не имеют значения по умолчанию, поскольку параметры по умолчанию определяются реализацией. Таким образом, для всех полей алгоритма с точкой можно установить значение None
, чтобы указать алгоритм с пустой точкой, который вместо этого будет использовать значение precision_config
.
Поля DotAlgorithm
включают:
-
lhs_precision_type
иrhs_precision_type
— точность, до которой округляются левая и правая части операции. Типы точности не зависят от типов хранения входных и выходных данных. -
accumulation_type
точность, используемая для накопления. -
lhs_component_count
,rhs_component_count
иnum_primitive_operations
применяются, когда мы выполняем алгоритм, который разлагает LHS и/или RHS на несколько компонентов и выполняет несколько «примитивных» точечных операций над этими значениями — обычно для эмуляции более высокой точности (например, использование типа данных искусственного интеллекта bfloat16). Для вычислений более высокой точности : bf16_6x tf32_3x и т. д.). Для алгоритмов без декомпозиции эти значения должны быть установлены равными1
. -
allow_imprecise_accumulation
, чтобы указать, разрешено ли накопление с более низкой точностью для некоторых шагов (например,CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Пример атрибутов DotAlgorithm
:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Реализации должны решить, какие комбинации поддерживаются. В общем, не гарантируется, что каждый алгоритм поддерживается на каждом типе ускорителя потребителем StableHLO. Если данный алгоритм не поддерживается, следует выдать ошибку, а не возвращаться к альтернативе. Проверка StableHLO обеспечит максимальную проверку, предотвращая использование алгоритмов, о поддержке которых неизвестно ни на каком оборудовании.
См. xla_data.proto > Algorithm
для получения информации о некоторых поддерживаемых значениях алгоритма. Заявка № 2483 описывает план создания централизованного документа по поддерживаемым алгоритмам с помощью серверной части.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С5-С6), (С9-С10), (С12-С14), (С17-С18), (С20) |
(И2) | rhs | тензор или квантованный тензор | (С7-С10), (С12-С20) |
(И3) | lhs_batching_dimensions | 1-мерная тензорная константа типа si64 | (С1), (С3), (С5), (С9), (С12) |
(И4) | rhs_batching_dimensions | 1-мерная тензорная константа типа si64 | (С1), (С4), (С7), (С9) |
(И5) | lhs_contracting_dimensions | 1-мерная тензорная константа типа si64 | (С2), (С3), (С6), (С10) |
(И6) | rhs_contracting_dimensions | 1-мерная тензорная константа типа si64 | (С2), (С4), (С8), (С10), (С16) |
(I7) | precision_config | вариативное количество перечислений DEFAULT , HIGH и HIGHEST | (С11), (С21) |
(И8) | lhs_precision_type | FloatType или TensorFloat32 | (С21) |
(I9) | rhs_precision_type | FloatType или TensorFloat32 | (С21) |
(I10) | accumulation_type | FloatType или TensorFloat32 | (С21) |
(I11) | lhs_component_count | константа типа si32 | (С21), (С22) |
(I12) | rhs_component_count | константа типа si32 | (С21), (С23) |
(I13) | num_primitive_operations | константа типа si32 | (С21), (С24) |
(I14) | allow_imprecise_accumulation | константа типа bool | (С21) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С12), (С14), (С18-С20) |
Ограничения
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Если в операции используются неквантованные тензоры:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Если в операции используются квантованные тензоры:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Если
is_per_axis_quantized(rhs)
, тоquantization_dimension(rhs)
не находится вrhs_contracting_dimensions
. - Если
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Если
is_per_tensor_quantized(rhs)
, тоis_per_tensor_quantized(result)
. - Если
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Если
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Примеры
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
Dynamic_broadcast_in_dim
Семантика
Эта операция функционально идентична операции Broadcast_in_dim , но форма результата задается динамически через output_dimensions
.
Операция также принимает необязательные known_expanding_dimensions
, known_nonexpanding_dimensions
для выражения статических знаний о поведении измерений при расширении. Если не указано иное, предполагается, что все размеры могут расширяться.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С2), (С5-С6), (С9) |
(И2) | output_dimensions | 1-мерный тензор целочисленного типа | (С7) |
(И3) | broadcast_dimensions | Одномерный постоянный тензор целочисленного типа | (С2-С6) |
(И4) | known_expanding_dimensions | Одномерный постоянный тензор целочисленного типа | (С8-С9) |
(И5) | known_nonexpanding_dimensions | Одномерный постоянный тензор целочисленного типа | (С8-С9) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1), (С3), (С5-С7) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
,scales(operand)
иzero_points(operand)
могут отличаться отquantization_dimension(result)
,scales(result)
иzero_points(result)
соответственно, в противном случае.
-
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Для всех
d
вaxes(operand)
:-
dim(operand, d) = 1
или -
dim(operand, d) = dim(result, broadcast_dimensions[d])
.
-
- (C6) Если
is_per_axis_quantized(result)
:-
quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
. - Если
dim(operand, quantization_dimension(operand)) = 1
, тоscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
-
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
.
Примеры
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
динамический_конв
Семантика
Эта операция функционально идентична операции свертки , но заполнение задается динамически с помощью padding
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1), (С10-С11), (С14) (С25), (С26-С27), (С30-С31), (С33) |
(И2) | rhs | тензор или квантованный тензор | (С1), (С14-С16), (С26-С28), (С30-С33) |
(И3) | padding | 2-мерный тензор целочисленного типа | (С4) |
(И4) | window_strides | 1-мерная тензорная константа типа si64 | (С2-С3) |
(И5) | lhs_dilation | 1-мерная тензорная константа типа si64 | (С5-С6) |
(И6) | rhs_dilation | 1-мерная тензорная константа типа si64 | (C7-C8) |
(I7) | window_reversal | 1-мерная тензорная константа типа i1 | (С9) |
(И8) | input_batch_dimension | константа типа si64 | (C10), (C13) |
(I9) | input_feature_dimension | константа типа si64 | (С11), (С13-С14) |
(I10) | input_spatial_dimensions | 1-мерная тензорная константа типа si64 | (C12), (C13) |
(I11) | kernel_input_feature_dimension | константа типа si64 | (С14), (С18) |
(I12) | kernel_output_feature_dimension | константа типа si64 | (C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions | 1-мерная тензорная константа типа si64 | (C17-C18) |
(I14) | output_batch_dimension | константа типа si64 | (C20) |
(I15) | output_feature_dimension | константа типа si64 | (C20), (C29) |
(I16) | output_spatial_dimensions | 1-мерная тензорная константа типа si64 | (C19-C20) |
(I17) | feature_group_count | константа типа si64 | (С11), (С14), (С16), (С21), (С23) |
(I18) | batch_group_count | константа типа si64 | (C10), (C15), (C22), (C23) |
(I19) | precision_config | вариативное количество перечислений DEFAULT , HIGH и HIGHEST | (С24) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (C25-C27), (C29), (C31-C33) |
Ограничения
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Учитывая
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:-
is_unique(input_dimensions)
. -
0 <= input_dimensions < N
.
-
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Учитывая
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:-
is_unique(kernel_dimensions)
. -
0 <= kernel_dimensions < N
.
-
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Учитывая
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:-
is_unique(output_dimensions)
. -
0 <= output_dimensions < N
.
-
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
определяется как:-
dim(lhs, input_batch_dimension) / batch_group_count
еслиresult_dim = output_batch_dimension
. -
dim(rhs, kernel_output_feature_dimension)
еслиresult_dim = output_feature_dimension
. -
num_windows
в противном случае, где: -
output_spatial_dimensions[spatial_dim] = result_dim
. -
lhs_dim = input_spatial_dimensions[spatial_dim]
. -
rhs_dim = kernel_spatial_dimensions[spatial_dim]
. -
dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
. -
padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
. -
dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
. -
is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
. -
num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
-
- (C26)
rank(result) = N
. - Если в операции используются неквантованные тензоры:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Если в операции используются квантованные тензоры:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Если
is_per_axis_quantized(rhs)
, тоquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Если
is_per_axis_quantized(result)
, тоquantization_dimension(result) = output_feature_dimension
. - Если
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Если
is_per_tensor_quantized(rhs)
, тоis_per_tensor_quantized(result)
. - Если
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Примеры
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
Dynamic_gather
Семантика
Эта операция функционально идентична для сбора OP, с динамически указанными slice_sizes
как значение.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1), (C7), (C10-C12), (C14) |
(И2) | start_indices | тензор целочисленного типа | (C2), (C3), (C13) |
(И3) | slice_sizes | 1-мерный тензор целочисленного типа | (C8), (C11-C13) |
(И4) | offset_dims | 1-мерная тензорная константа типа si64 | (C1), (C4-C5), (C13) |
(И5) | collapsed_slice_dims | 1-мерная тензорная константа типа si64 | (C1), (C6-C8), (C13) |
(И6) | start_index_map | 1-мерная тензорная константа типа si64 | (C3), (C9), (C10) |
(I7) | index_vector_dim | константа типа si64 | (C2), (C3), (C13) |
(И8) | indices_are_sorted | константа типа i1 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (C5), (C13-C14) |
Ограничения
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
Где:-
batch_dim_sizes = shape(start_indices)
за исключением того, что размер размераstart_indices
соответствующийindex_vector_dim
, не включен. -
offset_dim_sizes = shape(slice_sizes)
за исключением того, что размеры измерений вslice_sizes
соответствующиеcollapsed_slice_dims
, не включены. -
combine
ставитbatch_dim_sizes
на оси, соответствующиеbatch_dims
иoffset_dim_sizes
по осям, соответствующимoffset_dims
.
-
- (C14)
element_type(operand) = element_type(result)
.
Примеры
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
Dynamic_iota
Семантика
Эта операция функционально идентична iota OP, но форма результата определяется динамически через output_shape
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | output_shape | 1-мерный тензор целочисленного типа | (С1), (С2) |
(И2) | iota_dimension | si64 | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С2) |
Ограничения
- (C1)
0 <= iota_dimension < size(output_shape)
. - (C2)
rank(result) = size(output_shape)
.
Примеры
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
Dynamic_pad
Семантика
Эта операция функционально идентична для Pad Op, но с edge_padding_low
, edge_padding_high
и interior_padding
, указанным динамически как значения.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1), (C2), (C4) |
(И2) | padding_value | 0-мерный тензор или квансор на квансор | (С1) |
(И3) | edge_padding_low | 1-мерный тензор целочисленного типа | (C1), (C4) |
(И4) | edge_padding_high | 1-мерный тензор целочисленного типа | (C1), (C4) |
(И5) | interior_padding | 1-мерный тензор целочисленного типа | (С2-С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (C3-C6) |
Ограничения
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Примеры
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
Dynamic_Reshape
Семантика
Эта операция функционально идентична RESHAPE OP, но форма результата определяется динамически через output_shape
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С3) |
(И2) | output_shape | 1-мерный тензор целочисленного типа | (С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1-С4) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
иquantization_dimension(result)
могут отличаться, в противном случае.
-
- (C2)
size(operand) = size(result)
. - (C3) Если
is_per_axis_quantized(operand)
:-
reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
. -
dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
. -
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
-
- (C4)
size(output_shape) = rank(result)
.
Примеры
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
Dynamic_slice
Семантика
Извлекает срез из operand
используя динамически вычисленные начальные индексы и дает тензор result
. start_indices
содержит начальные показатели среза для каждого измерения, подверженного потенциальной регулировке, а slice_sizes
содержат размеры среза для каждого измерения. Более формально, result[result_index] = operand[operand_index]
где:
-
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
. -
operand_index = adjusted_start_indices + result_index
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1), (C2), (C4) |
(И2) | start_indices | Вариальное число 0-мерных тензоров типа целочисленного целого типа | (С2), (С3) |
(И3) | slice_sizes | 1-мерная тензорная константа типа si64 | (С2), (С4), (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1), (С5) |
Ограничения
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Примеры
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Семантика
Получает тензор result
, который равен тензору operand
, за исключением того, что срез, начинающийся при start_indices
обновляется с значениями в update
. Более формально, result[result_index]
определяется как:
-
update[update_index]
, если0 <= update_index < shape(update)
где:-
adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
. -
update_index = result_index - adjusted_start_indices
.
-
-
operand[result_index]
в противном случае.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1-C4), (C6) |
(И2) | update | тензорный или потензорный квантованный тензор | (C2), (C3), (C6) |
(И3) | start_indices | Вариальное число 0-мерных тензоров типа целочисленного целого типа | (C4), (C5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Примеры
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
экспоненциальный
Семантика
Выполняет экспоненциальную экспоненциальную операцию на тензоре operand
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавков:
exp
от IEEE-754. - Для сложных чисел: сложная экспонента.
- Для квантовых типов:
dequantize_op_quantize(exponential, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
EXPONELINEL_MINUS_ONE
Семантика
Выполняет экспоненциальную экспоненциальную операцию на One One One One operand
Tensor и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавков:
expm1
от IEEE-754. - Для сложных чисел: сложный экспоненциальный минус один.
- Для квантовых типов:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
фф
Семантика
Выполняет прямое и обратное преобразование Фурье для реальных и сложных входов/выходов.
fft_type
является одним из следующих:
-
FFT
: форвардный комплекс к комплексу FFT. -
IFFT
: FFT обратного комплекса к комплексу. -
RFFT
: Forward Real-Complex FFT. -
IRFFT
: обратный реальный для комплекса БПФ (т.е. сложный, возвращает реальные).
Более формально, учитывая функцию fft
, которая принимает 1-мерные тензоры сложных типов в качестве входных данных, создает 1-мерные тензоры с теми же типами, что и вывод, и вычисляет дискретное преобразование Фурье:
Для fft_type = FFT
result
определяется как конечный результат серии L вычислений, где L = size(fft_length)
. Например, для L = 3
:
-
result1[i0, ..., :] = fft(operand[i0, ..., :])
. -
result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
. -
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Кроме того, учитывая функцию ifft
, которая имеет одинаковую сигнатуру типа, и вычисляет обратный fft
:
Для fft_type = IFFT
result
определяется как обратное вычисления для fft_type = FFT
. Например, для L = 3
:
-
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
. -
result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
. -
result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Кроме того, учитывая функцию rfft
, которая принимает 1-мерные тензоры типов с плавающей точкой, создает 1-мерные тензоры сложных типов одной и той же семантики с плавающей точкой и работает следующим образом:
-
rfft(real_operand) = truncated_result
-
complex_operand... = (real_operand..., 0.0)
. -
complex_result = fft(complex_operand)
. -
truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Когда дискретное преобразование Фурье вычисляется для реальных операндов, первые элементы N/2 + 1
результата однозначно определяют остальную часть результата, поэтому результат rfft
усечен, чтобы избежать вычисления избыточных элементов).
Для fft_type = RFFT
result
определяется как конечный результат серии L вычислений, где L = size(fft_length)
. Например, для L = 3
:
-
result1[i0, ..., :] = rfft(operand[i0, ..., :])
. -
result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
. -
result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Наконец, учитывая функцию irfft
, которая имеет одинаковую подпись типа и вычисляет обратную rfft
:
Для fft_type = IRFFT
result
определяется как обратное вычисления для fft_type = RFFT
. Например, для L = 3
:
-
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
. -
result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
. -
result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | Тензор плавающей точки или сложного типа | (C1), (C2), (C4), (C5) |
(И2) | fft_type | enum of FFT , IFFT , RFFT и IRFFT | (С2), (С5) |
(И3) | fft_length | 1-мерная тензорная константа типа si64 | (C1), (C3), (C4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | Тензор плавающей точки или сложного типа | (С2), (С4), (С5) |
Ограничения
- (C1)
size(fft_length) <= rank(operand)
. - (C2) взаимосвязь между типами
operand
иresult
варьируется:- Если
fft_type = FFT
,element_type(operand)
иelement_type(result)
имеют одинаковый комплексный тип. - Если
fft_type = IFFT
,element_type(operand)
иelement_type(result)
имеют одинаковый комплексный тип. - Если
fft_type = RFFT
,element_type(operand)
-это тип с плавающей точкой, аelement_type(result)
-сложный тип той же семантики с плавающей точкой. - Если
fft_type = IRFFT
,element_type(operand)
-это сложный тип, аelement_type(result)
-тип с плавающей точкой той же семантики с плавающей запятой.
- Если
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Если среди
operand
иresult
, существует тензорreal
типа с плавающей точкой, тоshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
за исключением:- Если
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Если
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Если
Примеры
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
пол
Семантика
Выполняет элементный этаж тензора operand
и дает тензор result
. Реализует roundToIntegralTowardNegative
Operation из спецификации IEEE-754. Для квантовых типов выполняет dequantize_op_quantize(floor, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
собирать
Семантика
Собирает ломтики от operand
из смещений, указанных в start_indices
и дает тензор result
.
На следующей диаграмме показано, как элементы в карте result
на элементах в operand
с использованием конкретного примера. Диаграмма выбирает несколько примеров индексов result
и подробно объясняет, какие индексы operand
они соответствуют.
Более формально, result[result_index] = operand[operand_index]
где:
-
batch_dims = [d for d in axes(result) and d not in offset_dims]
. -
batch_index = result_index[batch_dims...]
. -
start_index
определяется как:-
start_indices[bi0, ..., :, ..., biN]
, гдеbi
являются отдельными элементами вbatch_index
и:
вставлены в индексindex_vector_dim
, еслиindex_vector_dim
<rank(start_indices)
. -
[start_indices[batch_index]]
в противном случае.
-
- Для
d_operand
вaxes(operand)
,-
full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
еслиd_operand = start_index_map[d_start]
. -
full_start_index[d_operand] = 0
в противном случае.
-
- Для
d_operand
вaxes(operand)
,-
full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
, еслиd_operand = operand_batching_dims[i_batching]
иd_start = start_indices_batching_dims[i_batching]
. -
full_batching_index[d_operand] = 0
в противном случае.
-
-
offset_index = result_index[offset_dims...]
. -
full_offset_index = [oi0, ..., 0, ..., oiN]
, гдеoi
являются отдельными элементами вoffset_index
, а0
вставляется по индексам изcollapsed_slice_dims
иoperand_batching_dims
. -
operand_index = full_start_index + full_batching_index + full_offset_index
.
Если indices_are_sorted
является true
, то реализация может предположить, что start_indices
сортируется по отношению к start_index_map
, в противном случае поведение не определен. Более формально для всех i1 < i2
из indices(result)
, full_start_index(i1) <= full_start_index(i2)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(И2) | start_indices | тензор целочисленного типа | (C2-C3), (C14), (C17), (C22) |
(И3) | offset_dims | 1-мерная тензорная константа типа si64 | (C1), (C4-C5), (C22) |
(И4) | collapsed_slice_dims | 1-мерная тензорная константа типа si64 | (C1), (C6-C9), (C22) |
(И5) | operand_batching_dims | 1-мерная тензорная константа типа si64 | (C1), (C6), (C10-C12), (C16-C18), (C22) |
(И6) | start_indices_batching_dims | 1-мерная тензорная константа типа si64 | (C13-C17) |
(I7) | start_index_map | 1-мерная тензорная константа типа si64 | (C3), (C18-C19) |
(И8) | index_vector_dim | константа типа si64 | (C2-C3), (C15), (C22) |
(I9) | slice_sizes | 1-мерная тензорная константа типа si64 | (C9), (C12), (C20-C22) |
(I10) | indices_are_sorted | константа типа i1 |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (C5), (C22-C23) |
Ограничения
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
. - (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
Где:-
batch_dim_sizes = shape(start_indices)
за исключением того, что размер размераstart_indices
соответствующийindex_vector_dim
, не включен. -
offset_dim_sizes = slice_sizes
за исключением того, что размеры размеров вslice_sizes
соответствующиеcollapsed_slice_dims
иoperand_batching_dims
не включены. -
combine
ставитbatch_dim_sizes
на оси, соответствующиеbatch_dims
иoffset_dim_sizes
по осям, соответствующимoffset_dims
.
-
- (C23)
element_type(operand) = element_type(result)
.
Примеры
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Семантика
Производит размер данного dimension
operand
. Более формально, result = dim(operand, dimension)
. Семантика касается только компонента формы типа. Тип элемента может быть чем угодно.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1) |
(И2) | dimension | константа типа si64 | (С1) |
Выходы
Имя | Тип |
---|---|
result | 0-мерный тензор типа si32 |
Ограничения
- (C1)
0 <= dimension < rank(operand)
.
Примеры
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Семантика
Извлекать элемент в положении index
operand
и дает result
. Более формально, result = operand[index]
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | кортеж | (С1), (С2) |
(И2) | index | константа типа si32 | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | любой поддерживаемый тип | (С2) |
Ограничения
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Примеры
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
если
Семантика
Создает выход из выполнения ровно одной функции из true_branch
или false_branch
в зависимости от значения pred
. Более формально, result = pred ? true_branch() : false_branch()
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | pred | 0-мерный тензор типа i1 | |
(И2) | true_branch | функция | (С1-С3) |
(И3) | false_branch | функция | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С3) |
Ограничения
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Примеры
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
изображение
Семантика
Извлекает воображаемую часть, по элементу, из operand
и дает тензор result
. Более формально, для каждого элемента x
: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | Тензор плавающей точки или сложного типа | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | Тензор типа с плавающей точкой | (С1), (С2) |
Ограничения
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
определяется как:-
complex_element_type(element_type(operand))
ifis_complex(operand)
. -
element_type(operand)
в противном случае.
-
Примеры
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
подача
Семантика
Считывает данные из добычи и дает results
.
Семантика infeed_config
определяется реализацией.
results
состоят из значений полезной нагрузки, которые поступают первым, и токен, который наступает последним. В будущем мы планируем разделить полезную нагрузку и токен на два отдельных выхода для улучшения ясности ( #670 ).
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | token | token |
(И2) | infeed_config | константа типа string |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С1-С3) |
Ограничения
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
илиis_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Примеры
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
йота
Семантика
Заполняет output
тенсор со значениями в увеличении порядка, начиная с нуля вдоль измерения iota_dimension
. Более формально,
output[output_index] = constant(is_quantized(output) ? quantize(output_index[iota_dimension], element_type(output)) : output_index[iota_dimension], element_type(output))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | iota_dimension | si64 | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
0 <= iota_dimension < rank(output)
.
Примеры
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Семантика
Выполняет элементы, проверьте ли значение x
конечным (т.е. не +inf, -inf или nan) и производит y
. Реализует операцию isFinite
из спецификации IEEE-754. Для квантованных типов результат всегда true
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | x | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
y | тензор логического типа | (С1) |
Ограничения
- (C1)
shape(x) = shape(y)
.
Примеры
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
бревно
Семантика
Выполняет элементную работу логарифма на тензоре operand
и получает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавок:
log
от IEEE-754. - Для сложных чисел: сложный логарифм.
- Для квантованных типов:
dequantize_op_quantize(log, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Семантика
Выполняет элементный логарифм плюс одна операция на operand
тензоре и получает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавков:
logp1
от IEEE-754. - Для сложных чисел: сложный логарифм плюс один.
- Для квантовых типов:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
логистический
Семантика
Выполняет элементную логистическую операцию на operand
Tensor и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для поплавок:
division(1, addition(1, exp(-x)))
от IEEE-754. - Для сложных чисел: сложная логистика.
- Для квантовых типов:
dequantize_op_quantize(logistic, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
карта
Семантика
Применяет computation
функции карты к inputs
вдоль dimensions
и дает тензор result
.
Более формально, result[result_index] = computation(inputs...[result_index])
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (С1-С4) |
(И2) | dimensions | 1-мерная тензорная константа типа si64 | (С3) |
(И3) | computation | функция | (С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (C1), (C4) |
Ограничения
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
имеет тип(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
гдеEi = element_type(inputs[i])
иE' = element_type(result)
.
Примеры
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
максимум
Семантика
Выполняет элементную работу максимальной работы на Tensors lhs
и rhs
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое ИЛИ.
- Для целых чисел: целое число максимум.
- Для поплавков:
maximum
от IEEE-754. - Для сложных чисел: лексикографический максимум для
(real, imaginary)
пары. Наложение порядка на комплексные числа требует удивительной семантики, поэтому в будущем мы планируем удалить поддержку комплексных чисел для этой операции ( #560 ). - Для квантованных типов:
-
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензорный или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
минимум
Семантика
Выполняет элементную операцию на мин на Tensors lhs
и rhs
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое И.
- Для целых чисел: целочисленное минимум.
- Для поплавков:
minimum
от IEEE-754. - Для сложных чисел: лексикографический минимум для
(real, imaginary)
пары. Наложение порядка на комплексные числа требует удивительной семантики, поэтому в будущем мы планируем удалить поддержку комплексных чисел для этой операции ( #560 ). - Для квантованных типов:
-
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензорный или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
умножать
Семантика
Выполняет элементарный продукт двух тензоров lhs
и rhs
и дает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое И.
- Для целых чисел: целочисленное умножение.
- Для поплавок:
multiplication
от IEEE-754. - Для сложных чисел: сложное умножение.
- Для квантованных типов:
-
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензорный или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензорный или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
отрицать
Семантика
Выполняет элементное отрицание operand
тензора и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для подписанных целых чисел: целочисленное отрицание.
- Для не подписываемых целых чисел: Bitcast для подписанного целого числа, целочисленного отрицания, Bitcast обратно в Unsigned Integer.
- Для поплавков:
negate
от IEEE-754. - Для сложных чисел: сложное отрицание.
- Для квантовых типов:
dequantize_op_quantize(negate, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
нет
Семантика
Выполняет элементы не тензора operand
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для логических: логично нет.
- Для целых чисел: кусочке нет.
Аргументы
Имя | Тип | Ограничения |
---|---|---|
operand | тензор логического или целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор логического или целочисленного типа | (С1) |
Ограничения
- (C1)
type(operand) = type(result)
.
Примеры
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
Optimization_Barrier
Семантика
Гарантирует, что операции, которые производят operand
выполняются до любых операций, которые зависят от result
и предотвращают перемещающиеся преобразования компилятора через барьер. Кроме этого, операция является личности, то есть result = operand
.
Аргументы
Имя | Тип | Ограничения |
---|---|---|
operand | Вариальное количество тензоров, кванторов-кванторов или токенов | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | Вариальное количество тензоров, кванторов-кванторов или токенов | (С1) |
Ограничения
- (C1)
type(operand...) = type(result...)
.
Примеры
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
или
Семантика
Выполняет элементные или два тензора lhs
и rhs
и дает result
тензор. В зависимости от типа элемента выполняет следующие действия:
- Для логических значений: логическое ИЛИ.
- Для целых чисел: бить или.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного или логического типа | (С1) |
(И2) | rhs | тензор целочисленного или логического типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного или логического типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
перекормить
Семантика
Записывает inputs
в Outfeed и производит токен result
.
Семантика outfeed_config
определяется реализацией.
Входы
Этикетка | Имя | Тип |
---|---|---|
(I1) | inputs | Вариальное количество тензоров или квантовых тензоров |
(И2) | token | token |
(И3) | outfeed_config | константа типа string |
Выходы
Имя | Тип |
---|---|
result | token |
Примеры
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
подушечка
Семантика
Расширяется operand
, пробиваясь вокруг тензора, а также между элементами тензора с данным padding_value
.
edge_padding_low
и edge_padding_high
Укажите количество заполнения, добавленное в низком уровне (рядом с индексом 0) и высококачественным (рядом с самым высоким индексом) каждого измерения соответственно. Количество прокладки может быть отрицательным, где абсолютное значение отрицательного прокладки указывает количество элементов для удаления из указанного измерения.
interior_padding
указывает количество прокладки, добавленное между любыми двумя элементами в каждом измерении, которое может быть негативным. Внутренняя прокладка происходит перед набережной, так что отрицательная края набережная удаляет элементы из операнда с внутренними укладками.
Более формально, result[result_index]
определяется как:
-
operand[operand_index]
, еслиresult_index = edge_padding_low + operand_index * (interior_padding + 1)
. -
padding_value
иначе.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1), (C2), (C4) |
(И2) | padding_value | 0-мерный тензор или квансор на квансор | (С1) |
(И3) | edge_padding_low | 1-мерная тензорная константа типа si64 | (C1), (C4) |
(И4) | edge_padding_high | 1-мерная тензорная константа типа si64 | (C1), (C4) |
(И5) | interior_padding | 1-мерная тензорная константа типа si64 | (С2-С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (C3-C6) |
Ограничения
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Примеры
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Семантика
Производит partition_id
текущего процесса.
Выходы
Имя | Тип |
---|---|
result | 0-мерный тензор типа ui32 |
Примеры
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
Popcnt
Семантика
Выполняет элемент количества битов, установленных в тензоре operand
, и дает тензор result
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(operand) = type(result)
.
Примеры
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
власть
Семантика
Выполняет элементное экспонент в тензоре lhs
Tensor rhs
и дает тензор result
. В зависимости от типа элемента выполняет следующие действия:
- Для целых чисел: целочисленное экспонентация.
- Для поплавков:
pow
от IEEE-754. - Для сложных чисел: сложное экспонент.
- Для квантовых типов:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
настоящий
Семантика
Извлекает реальную часть, по элементу, из operand
и дает тензор result
. Более формально для каждого элемента x
: real(x) = is_complex(x) ? real_part(x) : x
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | Тензор плавающей точки или сложного типа | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | Тензор типа с плавающей точкой | (С1), (С2) |
Ограничения
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
определяется как:-
complex_element_type(element_type(operand))
ifis_complex(operand)
. -
element_type(operand)
в противном случае.
-
Примеры
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
получение
Семантика
Получает данные из канала с channel_id
и дает results
.
Если is_host_transfer
является true
, то операция передает данные с хоста. В противном случае он передает данные с другого устройства. Что это значит, определяется реализацией. Этот флаг дублирует информацию, представленную в channel_type
, поэтому в будущем мы планируем сохранить только один из них ( #666 ).
results
состоят из значений полезной нагрузки, которые поступают первым, и токен, который наступает последним. В будущем мы планируем разделить полезную нагрузку и токен на два отдельных выхода для улучшения ясности ( #670 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | token | token | (С4) |
(И2) | channel_id | константа типа si64 | |
(И3) | channel_type | enum DEVICE_TO_DEVICE и HOST_TO_DEVICE | (С1) |
(И4) | is_host_transfer | константа типа i1 | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С2-С4) |
Ограничения
- (C1)
channel_type
определяется как:-
HOST_TO_DEVICE
ifis_host_transfer = true
, -
DEVICE_TO_DEVICE
иначе.
-
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
илиis_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Примеры
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
уменьшать
Семантика
Применяет body
функции восстановления к inputs
и init_values
вдоль dimensions
и дает results
тензоры.
Порядок сокращения определяется реализацией, что означает, что body
и init_values
должны образовывать моноид, чтобы гарантировать, что операция дает одинаковые результаты для всех входов на все реализации. Тем не менее, это условие не содержится для многих популярных сокращений. Например, добавление с плавающей точкой для body
и ноль для init_values
на самом деле не образует моноид, поскольку добавление с плавающей точкой не является ассоциативным.
Более формально, results...[j0, ..., jR-1] = reduce(input_slices_converted)
Где:
-
input_slices = inputs...[j0, ..., :, ..., jR-1]
, где:
вставлены вdimensions
. -
input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
. -
init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
. -
reduce(input_slices_converted) = exec(schedule)
для некоторогоschedule
двоичных деревьев, где:-
exec(node) = body(exec(node.left), exec(node.right))
. -
exec(leaf) = leaf.value
.
-
-
schedule
is an implementation-defined full binary tree whose in-order traversal consists of:-
input_slices_converted...[index]
values, for allindex
inindex_space(input_slices_converted)
in the ascending lexicographic order ofindex
. - Interspersed with an implementation-defined amount of
init_values_converted
at implementation-defined positions.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (C1-C4), (C6), (C7) |
(И2) | init_values | variadic number of 0-dimensional tensors or per-tensor quantized tensors | (С2), (С3) |
(И3) | dimensions | 1-мерная тензорная константа типа si64 | (C4), (C5), (C7) |
(И4) | body | функция | (С6) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (C3), (C7), (C8) |
Ограничения
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
has type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
whereis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
except that the dimension sizes ofinputs...
corresponding todimensions
are not included. - (C8)
element_type(results[i]) = Ei
for alli
in[0,N)
.
Примеры
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Семантика
Performs element-wise conversion of operand
to another floating-point type that uses exponent_bits
and mantissa_bits
and back to the original floating-point type and produces an output
tensor.
More formally:
- The mantissa bits of the original value are updated to round the original value to the nearest value representable with
mantissa_bits
usingroundToIntegralTiesToEven
semantics. - Then, if
mantissa_bits
are smaller than the number of mantissa bits of the original value, the mantissa bits are truncated tomantissa_bits
. - Then, if the exponent bits of the intermediate result don't fit into the range provided by
exponent_bits
, the intermediate result overflows to infinity using the original sign or underflows to zero using the original sign. - For quantized types, performs
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
(И2) | exponent_bits | константа типа si32 | (С2) |
(И3) | mantissa_bits | константа типа si32 | (С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Примеры
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Семантика
Within each process group in the StableHLO process grid, performs reduction, using computations
, over the values of the operand
tensor from each process, splits the reduction result along scatter_dimension
into parts, and scatters the split parts between the processes to produce the result
.
Эта операция разбивает сетку процессов StableHLO на process_groups
, которые определяются следующим образом:
-
cross_replica(replica_groups)
еслиchannel_id <= 0 and use_global_device_ids = false
. -
cross_replica_and_partition(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = false
. -
flattened_ids(replica_groups)
еслиchannel_id > 0 and use_global_device_ids = true
.
Затем внутри каждой process_group
:
-
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
. -
parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
. -
result@receiver = parts@sender[receiver_index]
for allsender
inprocess_group
, wherereceiver_index = process_group.index(receiver)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1), (C2), (C7), (C8) |
(И2) | scatter_dimension | константа типа si64 | (C1), (C2), (C8) |
(И3) | replica_groups | 2-мерная тензорная константа типа si64 | (C3-C5) |
(И4) | channel_id | константа типа si64 | (С6) |
(И5) | use_global_device_ids | константа типа i1 | (С6) |
(И6) | computation | функция | (С7) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С8-С9) |
Ограничения
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
is defined as:-
num_replicas
если используетсяcross_replica
. -
num_replicas
если используетсяcross_replica_and_partition
. -
num_processes
если используетсяflattened_ids
.
-
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) If
use_global_device_ids = true
, thenchannel_id > 0
. - (C7)
computation
has type(tensor<E>, tensor<E>) -> (tensor<E>)
whereis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
except:-
dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
-
- (C9)
element_type(result) = E
.
Примеры
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Семантика
Applies a reduction function body
to windows of inputs
and init_values
and produces results
.
The following diagram shows how elements in results...
are computed from inputs...
using a concrete example.
More formally, results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(see reduce ) where:
-
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
. -
window_start = result_index * window_strides
. -
window_end = window_start + (window_dimensions - 1) * window_dilations + 1
. -
windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
(И2) | init_values | variadic number of 0-dimensional tensors or per-tensor quantized tensors | (C1), (C13) |
(И3) | window_dimensions | 1-мерная тензорная константа типа si64 | (C4), (C5), (C15) |
(И4) | window_strides | 1-мерная тензорная константа типа si64 | (C6), (C7), (C15) |
(И5) | base_dilations | 1-мерная тензорная константа типа si64 | (C8), (C9), (C15) |
(И6) | window_dilations | 1-мерная тензорная константа типа si64 | (C10), (C11), (C15) |
(I7) | padding | 2-мерная тензорная константа типа si64 | (C12), (C15) |
(И8) | body | функция | (C13) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (C1), (C14-C16) |
Ограничения
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
has type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
whereis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
where:-
dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
. -
padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
. -
dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
. -
is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
. -
num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
-
- (C16)
element_type(results[i]) = Ei
for alli
in[0,N)
.
Примеры
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
остаток
Семантика
Performs element-wise remainder of dividend lhs
and divisor rhs
tensors and produces a result
tensor.
More formally, the sign of the result is taken from the dividend, and the absolute value of the result is always less than the divisor's absolute value. The remainder is calculated as lhs - d * rhs
, where d
is given by:
- For integers:
stablehlo.divide(lhs, rhs)
. - For floats:
division(lhs, rhs)
from IEEE-754 with rounding attributeroundTowardZero
. - For complex numbers: TBD ( #997 ).
- Для квантованных типов:
-
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
-
For floating-point element types, this operation is in contrast with the remainder
operation from IEEE-754 specification where d
is an integral value nearest to the exact value of lhs/rhs
with ties to even.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Семантика
Produces replica_id
of the current process.
Выходы
Имя | Тип |
---|---|
result | 0-мерный тензор типа ui32 |
Примеры
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
изменить форму
Семантика
Performs reshape of operand
tensor to a result
tensor. Conceptually, it amounts to keeping the same canonical representation but potentially changing the shape, eg from tensor<2x3xf32>
to tensor<3x2xf32>
or tensor<6xf32>
.
More formally, result[result_index] = operand[operand_index]
where result_index
and operand_index
have the same position in the lexicographic ordering of index_space(result)
and index_space(operand)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (С1-С3) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
иquantization_dimension(result)
могут отличаться, в противном случае.
-
- (C2)
size(operand) = size(result)
. - (C3) Если
is_per_axis_quantized(operand)
:-
reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
. -
dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
. -
reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
-
Примеры
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
обеспечить регресс
Семантика
Reverses the order of elements in the operand
along the specified dimensions
and produces a result
tensor. Более формально, result[result_index] = operand[operand_index]
где:
-
operand_index[d] = dim(result, d) - result_index[d] - 1
ifd
indimensions
. -
operand_index[d] = result_index[d]
otherwise.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1), (С3) |
(И2) | dimensions | 1-мерная тензорная константа типа si64 | (С2), (С3) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1), (С3) |
Ограничения
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Примеры
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
звонок
Семантика
Generates random numbers using the rng_distribution
algorithm and produces a result
tensor of a given shape shape
.
If rng_distribution = UNIFORM
, then the random numbers are generated following the uniform distribution over the interval [a, b)
. If a >= b
, the behavior is undefined.
If rng_distribution = NORMAL
, then the random numbers are generated following the normal distribution with mean = a
and standard deviation = b
. If b < 0
, the behavior is undefined.
The exact way how random numbers are generated is implementation-defined. For example, they may or may not be deterministic, and they may or may not use hidden state.
In conversations with many stakeholders, this op has come up as effectively deprecated, so in the future we are planning to explore removing it ( #597 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | a | 0-dimensional tensor of integer, boolean, or floating-point type | (С1), (С2) |
(И2) | b | 0-dimensional tensor of integer, boolean, or floating-point type | (С1), (С2) |
(И3) | shape | 1-мерная тензорная константа типа si64 | (С3) |
(И4) | rng_distribution | enum of UNIFORM and NORMAL | (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | tensor of integer, boolean, or floating-point type | (С1-С3) |
Ограничения
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) If
rng_distribution = NORMAL
, thenis_float(a)
. - (C3)
shape(result) = shape
.
Примеры
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Семантика
Returns an output
filled with uniform random bits and an updated output state output_state
using the pseudorandom number generator algorithm rng_algorithm
given an initial state initial_state
. The output is guaranteed to be deterministic function of initial_state
, but it is not guaranteed to be deterministic between implementations.
rng_algorithm
is one of the following:
-
DEFAULT
: Implementation-defined algorithm. -
THREE_FRY
: Implementation-defined variant of the Threefry algorithm.* -
PHILOX
: Implementation-defined variant of the Philox algorithm.*
* See: Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | rng_algorithm | enum of DEFAULT , THREE_FRY , and PHILOX | (С2) |
(И2) | initial_state | 1-dimensional tensor of type ui64 | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
output_state | 1-dimensional tensor of type ui64 | (С1) |
output | tensor of integer or floating-point type |
Ограничения
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
is defined as:- implementation-defined if
rng_algorithm = DEFAULT
. -
2
ifrng_algorithm = THREE_FRY
. -
2
or3
ifrng_algorithm = PHILOX
.
- implementation-defined if
Примеры
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Семантика
Performs element-wise rounding towards the nearest integer, breaking ties away from zero, on the operand
tensor and produces a result
tensor. Implements the roundToIntegralTiesToAway
operation from the IEEE-754 specification. For quantized types, performs dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Семантика
Performs element-wise rounding towards the nearest integer, breaking ties towards the even integer, on the operand
tensor and produces a result
tensor. Implements the roundToIntegralTiesToEven
operation from the IEEE-754 specification. For quantized types, performs dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор типа с плавающей запятой или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Семантика
Performs element-wise reciprocal square root operation on operand
tensor and produces a result
tensor. В зависимости от типа элемента выполняет следующие действия:
- For floats:
rSqrt
from IEEE-754. - For complex numbers: complex reciprocal square root.
- For quantized types:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
разбрасывать
Семантика
Produces results
tensors which are equal to inputs
tensors except that several slices specified by scatter_indices
are updated with the values updates
using update_computation
.
The following diagram shows how elements in updates...
map on elements in results...
using a concrete example. The diagram picks a few example updates...
indices and explains in detail which results...
indices they correspond to.
More formally, for all update_index
in index_space(updates[0])
:
-
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
. -
update_scatter_index = update_index[update_scatter_dims...]
. -
start_index
определяется как:-
scatter_indices[si0, ..., :, ..., siN]
wheresi
are individual elements inupdate_scatter_index
and:
is inserted at theindex_vector_dim
index, ifindex_vector_dim
<rank(scatter_indices)
. -
[scatter_indices[update_scatter_index]]
otherwise.
-
- For
d_input
inaxes(inputs[0])
,-
full_start_index[d_input] = start_index[d_start]
ifd_input = scatter_dims_to_operand_dims[d_start]
. -
full_start_index[d_input] = 0
otherwise.
-
- For
d_input
inaxes(inputs[0])
,-
full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
ifd_input = input_batching_dims[i_batching]
andd_start = scatter_indices_batching_dims[i_batching]
. -
full_batching_index[d_input] = 0
otherwise.
-
-
update_window_index = update_index[update_window_dims...]
. -
full_window_index = [wi0, ..., 0, ..., wiN]
wherewi
are individual elements inupdate_window_index
, and0
is inserted at indices frominserted_window_dims
andinput_batching_dims
. -
result_index = full_start_index + full_batching_index + full_window_index
.
Given that, results = exec(schedule, inputs)
, where:
-
schedule
is an implementation-defined permutation ofindex_space(updates[0])
. -
exec([update_index, ...], results) = exec([...], updated_results)
where:- If
result_index
is in bounds forshape(results...)
-
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
-
updated_values = update_computation(results...[result_index], updates_converted)
-
updated_results
is a copy ofresults
withresults...[result_index]
set toupdated_values...
. - В противном случае
-
updated_results = results
.
- If
-
exec([], results) = results
.
If indices_are_sorted
is true
then the implementation can assume that scatter_indices
are sorted with respect to scatter_dims_to_operand_dims
, otherwise the behavior is undefined. Более формально для всех i1 < i2
из indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
If unique_indices
is true
then the implementation can assume that all result_index
indices being scattered to are unique. If unique_indices
is true
but the indices being scattered to are not unique then the behavior is undefined.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(И2) | scatter_indices | тензор целочисленного типа | (C4), (C15), (C19), (C22) |
(И3) | updates | вариативное число тензоров или потензорные квантованные тензоры | (C3-C6), (C8) |
(И4) | update_window_dims | 1-мерная тензорная константа типа si64 | (C2), (C4), (C7-C8) |
(И5) | inserted_window_dims | 1-мерная тензорная константа типа si64 | (C2), (C4), (C9-C11) |
(И6) | input_batching_dims | 1-мерная тензорная константа типа si64 | (C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims | 1-мерная тензорная константа типа si64 | (C14-C18) |
(И8) | scatter_dims_to_operand_dims | 1-мерная тензорная константа типа si64 | (C19-C21) |
(I9) | index_vector_dim | константа типа si64 | (C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted | константа типа i1 | |
(I11) | unique_indices | константа типа i1 | |
(I12) | update_computation | функция | (C23) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (C24-C25) |
Ограничения
- (C1)
same(shape(inputs...))
. - (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
where:-
update_scatter_dim_sizes = shape(scatter_indices)
except that the dimension size ofscatter_indices
corresponding toindex_vector_dim
is not included. -
update_window_dim_sizes <= shape(inputs[0])
except that the dimension sizes ininputs[0]
corresponding toinserted_window_dims
andinput_batching_dims
are not included. -
combine
putsupdate_scatter_dim_sizes
at axes corresponding toupdate_scatter_dims
andupdate_window_dim_sizes
at axes corresponding toupdate_window_dims
.
-
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
has type(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, whereis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
for alli
in[0,N)
.
Примеры
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
выбирать
Семантика
Produces a result
tensor where each element is selected from on_true
or on_false
tensor based on the value of the corresponding element of pred
. More formally, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index]
, where pred_element = rank(pred) = 0 ? pred[] : pred[result_index]
. For quantized types, performs dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | pred | tensor of type i1 | (С1) |
(И2) | on_true | тензорный или потензорный квантованный тензор | (С1-С2) |
(И3) | on_false | тензорный или потензорный квантованный тензор | (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С2) |
Ограничения
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Примеры
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Семантика
Scatters the values from the source
tensor using scatter
based on the outcome of reduce_window
of the input
tensor using select
and produces a result
tensor.
The following diagram shows how elements in result
are computed from operand
and source
using a concrete example.
More formally:
selected_values = reduce_window_without_init(...)
with the following inputs:-
inputs = [operand].
-
window_dimensions
,window_strides
, andpadding
which are used as is. -
base_dilations = windows_dilations = 1
. -
body
is defined as:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
where
E = element_type(operand)
, andreduce_window_without_init
works exactly likereduce_window
, except that theschedule
of the underlyingreduce
(see reduce ) doesn't include init values. It is currently unspecified what happens if the corresponding window doesn't have values ( #731 ).-
result[result_index] = reduce([source_values], [init_value], [0], scatter)
where:-
source_values = [source[source_index] for source_index in source_indices]
. -
selected_index(source_index) = operand_index
ifselected_values[source_index]
has theoperand
element fromoperand_index
. -
source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (C1-C4), (C6), (C8-C11) |
(И2) | source | тензорный или потензорный квантованный тензор | (С1), (С2) |
(И3) | init_value | 0-мерный тензор или квансор на квансор | (С3) |
(И4) | window_dimensions | 1-мерная тензорная константа типа si64 | (С2), (С4), (С5) |
(И5) | window_strides | 1-мерная тензорная константа типа si64 | (C2), (C6), (C7) |
(И6) | padding | 2-мерная тензорная константа типа si64 | (C2), (C8) |
(I7) | select | функция | (С9) |
(И8) | scatter | функция | (C10) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (C11-C12) |
Ограничения
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
where:-
padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
. -
is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
. -
num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
-
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
has type(tensor<E>, tensor<E>) -> tensor<i1>
whereE = element_type(operand)
. - (C10)
scatter
has type(tensor<E>, tensor<E>) -> tensor<E>
whereis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Примеры
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
отправлять
Семантика
Sends inputs
to a channel channel_id
and produces a result
token.
If is_host_transfer
is true
, then the operation transfers data to the host. Otherwise, it transfers data to another device. Что это значит, определяется реализацией. Этот флаг дублирует информацию, представленную в channel_type
, поэтому в будущем мы планируем сохранить только один из них ( #666 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | Вариальное количество тензоров или квантовых тензоров | |
(И2) | token | token | |
(И3) | channel_id | константа типа si64 | |
(И4) | channel_type | enum of DEVICE_TO_DEVICE and DEVICE_TO_HOST | (С1) |
(И5) | is_host_transfer | константа типа i1 | (С1) |
Выходы
Имя | Тип |
---|---|
result | token |
Ограничения
- (C1)
channel_type
определяется как:-
DEVICE_TO_HOST
ifis_host_transfer = true
, -
DEVICE_TO_DEVICE
иначе.
-
Примеры
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Семантика
Performs element-wise left-shift operation on the lhs
tensor by rhs
number of bits and produces a result
tensor.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа | (С1) |
(И2) | rhs | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Семантика
Performs element-wise arithmetic right-shift operation on the lhs
tensor by rhs
number of bits and produces a result
tensor.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа | (С1) |
(И2) | rhs | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Семантика
Performs element-wise logical right-shift operation on the lhs
tensor by rhs
number of bits and produces a result
tensor.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа | (С1) |
(И2) | rhs | тензор целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
знак
Семантика
Returns the sign of the operand
element-wise and produces a result
tensor. More formally, for each element x
, the semantics can be expressed using Python syntax as follows:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
For quantized types, performs dequantize_op_quantize(sign, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор целого числа со знаком, с плавающей запятой или комплексного типа или по-тензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целого числа со знаком, с плавающей запятой или комплексного типа или по-тензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
синус
Семантика
Performs element-wise sine operation on operand
tensor and produces a result
tensor. В зависимости от типа элемента выполняет следующие действия:
- For floats:
sin
from IEEE-754. - For complex numbers: complex sine.
- For quantized types:
dequantize_op_quantize(sine, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
кусочек
Семантика
Extracts a slice from the operand
using statically-computed starting indices and produces a result
tensor. start_indices
contain the starting indices of the slice for each dimension, limit_indices
contain the ending indices (exclusive) for the slice for each dimension, and strides
contain the strides for each dimension.
More formally, result[result_index] = operand[operand_index]
where operand_index = start_indices + result_index * strides
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензорный или потензорный квантованный тензор | (С1-С3), (С5) |
(И2) | start_indices | 1-мерная тензорная константа типа si64 | (C2), (C3), (C5) |
(И3) | limit_indices | 1-мерная тензорная константа типа si64 | (C2), (C3), (C5) |
(И4) | strides | 1-мерная тензорная константа типа si64 | (С2), (С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензорный или потензорный квантованный тензор | (С1), (С5) |
Ограничения
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Примеры
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
сортировать
Семантика
Sorts 1-dimensional slices of inputs
along the dimension dimension
together, according to a comparator
and produces results
.
Unlike similar inputs in other operations, dimension
allows negative values, with the semantics described below. In the future, this may be disallowed for consistency reasons ( #1377 ).
If is_stable
is true, then the sorting is stable, that is, relative order of elements considered to be equal by the comparator is preserved. For the case where there is a single input, two elements e1
and e2
are considered to be equal by the comparator if and only if comparator(e1, e2) = comparator(e2, e1) = false
. See the formalization below for how this generalizes to multiple inputs.
More formally, for all result_index
in index_space(results[0])
:
-
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
. -
result_slice = [ri0, ..., :, ..., riR-1]
whereriN
are individual elements inresult_index
, and:
is inserted atadjusted_dimension
. -
inputs_together = (inputs[0]..., ..., inputs[N-1]...)
. -
results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
. - where
sort
sorts a 1-dimensional slice in non-descending order expecting thatcomparator_together
returnstrue
if the left-hand side argument is less than the right-hand second argument. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | inputs | вариативное число тензоров или потензорные квантованные тензоры | (C1-C5) |
(И2) | dimension | константа типа si64 | (С4) |
(И3) | is_stable | константа типа i1 | |
(И4) | comparator | функция | (С5) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров или потензорные квантованные тензоры | (С2), (С3) |
Ограничения
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, whereR = rank(inputs[0])
. - (C5)
comparator
has type(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, whereEi = element_type(inputs[i])
.
Примеры
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
кврт
Семантика
Performs element-wise square root operation on operand
tensor and produces a result
tensor. В зависимости от типа элемента выполняет следующие действия:
- For floats:
squareRoot
from IEEE-754. - For complex numbers: complex square root.
- For quantized types:
dequantize_op_quantize(sqrt, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
вычесть
Семантика
Performs element-wise subtraction of two tensors lhs
and rhs
and produces a result
tensor. В зависимости от типа элемента выполняет следующие действия:
- For integers: integer subtraction.
- For floats:
subtraction
from IEEE-754. - For complex numbers: complex subtraction.
- Для квантованных типов:
-
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
(И2) | rhs | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор целочисленного типа, с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Примеры
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
загар
Семантика
Performs element-wise tangent operation on the operand
tensor and produces a result
tensor. В зависимости от типа элемента выполняет следующие действия:
- For floats:
tan
from IEEE-754. - For complex numbers: complex tangent.
- For quantized types:
dequantize_op_quantize(tan, operand, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Семантика
Performs element-wise hyperbolic tangent operation on operand
tensor and produces a result
tensor. В зависимости от типа элемента выполняет следующие действия:
- For floats:
tanh
from IEEE-754. - For complex numbers: complex hyperbolic tangent.
- Для квантованных типов:
-
dequantize_op_quantize(tanh, operand, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_type(operand) = baseline_type(result)
.
Примеры
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
транспонировать
Семантика
Permutes the dimensions of operand
tensor using permutation
and produces a result
tensor. More formally, result[result_index] = operand[operand_index]
where result_index[d] = operand_index[permutation[d]]
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | тензор или квантованный тензор | (С1-С4) |
(И2) | permutation | 1-мерная тензорная константа типа si64 | (С2-С4) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор или квантованный тензор | (C1), (C3-C4) |
Ограничения
- (C1)
element_type(result)
определяется как:-
element_type(operand)
, если!is_per_axis_quantized(operand)
. -
element_type(operand)
за исключением того, чтоquantization_dimension(operand)
иquantization_dimension(result)
могут отличаться, в противном случае.
-
- (C2)
permutation
is a permutation ofrange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) If
is_per_axis_quantized(result)
, thenquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Примеры
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Семантика
Solves batches of systems of linear equations with lower or upper triangular coefficient matrices.
More formally, given a
and b
, result[i0, ..., iR-3, :, :]
is the solution to op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
when left_side
is true
or x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
when left_side
is false
, solving for the variable x
where op(a)
is determined by transpose_a
, which can be one of the following:
-
NO_TRANSPOSE
: Perform operation usinga
as-is. -
TRANSPOSE
: Perform operation on transpose ofa
. -
ADJOINT
: Perform operation on conjugate transpose ofa
.
Input data is read only from the lower triangle of a
, if lower
is true
or upper triangle of a
, otherwise. Output data is returned in the same triangle; the values in the other triangle are implementation-defined.
If unit_diagonal
is true, then the implementation can assume that the diagonal elements of a
are equal to 1, otherwise the behavior is undefined.
For quantized types, performs dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result))
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | a | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1-С3) |
(И2) | b | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1-С4) |
(И3) | left_side | константа типа i1 | (С3) |
(И4) | lower | константа типа i1 | |
(И5) | unit_diagonal | константа типа i1 | |
(И6) | transpose_a | enum of NO_TRANSPOSE , TRANSPOSE , and ADJOINT |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор с плавающей запятой или комплексного типа или потензорный квантованный тензор | (С1) |
Ограничения
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) The relationship between
shape(a)
andshape(b)
is defined as follows:-
shape(a)[:-3] = shape(b)[:-3]
. -
dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
-
- (C4)
baseline_type(b) = baseline_type(result)
.
Примеры
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
кортеж
Семантика
Produces a result
tuple from values val
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | val | вариативное число значений | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | кортеж | (С1) |
Ограничения
- (C1)
result
has typetuple<E0, ..., EN-1>
whereEi = type(val[i])
.
Примеры
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Семантика
Performs element-wise conversion of quantized tensor operand
to a floating-point tensor result
according to the quantization parameters defined by the operand
type.
More formally, result = dequantize(operand)
.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | quantized tensor | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | Тензор типа с плавающей точкой | (С1), (С2) |
Ограничения
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Примеры
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Семантика
Performs element-wise conversion of floating-point tensor or quantized tensor operand
to a quantized tensor result
according to the quantization parameters defined by the result
type.
Более формально,
- If
is_float(operand)
:-
result = quantize(operand, type(result))
.
-
- If
is_quantized(operand)
:-
float_result = dequantize(operand)
. -
result = quantize(float_result, type(result))
.
-
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | tensor of floating-point or quantized type | (С1), (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | quantized tensor | (С1), (С2) |
Ограничения
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Примеры
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
пока
Семантика
Produces the output from executing body
function 0 or more times while the cond
function outputs true
. More formally, the semantics can be expressed using Python syntax as follows:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
The behavior of an infinite loop is TBD ( #383 ).
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | operand | вариативное число тензоров, квантованных тензоров или токенов | (С1-С3) |
(И2) | cond | функция | (С1) |
(И3) | body | функция | (С2) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
results | вариативное число тензоров, квантованных тензоров или токенов | (С3) |
Ограничения
- (C1)
cond
has type(T0, ..., TN-1) -> tensor<i1>
, whereTi = type(operand[i])
. - (C2)
body
has type(T0, ..., TN-1) -> (T0, ..., TN-1)
, whereTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Примеры
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Семантика
Performs element-wise XOR of two tensors lhs
and rhs
and produces a result
tensor. В зависимости от типа элемента выполняет следующие действия:
- For booleans: logical XOR.
- For integers: bitwise XOR.
Входы
Этикетка | Имя | Тип | Ограничения |
---|---|---|---|
(I1) | lhs | тензор логического или целочисленного типа | (С1) |
(И2) | rhs | тензор логического или целочисленного типа | (С1) |
Выходы
Имя | Тип | Ограничения |
---|---|---|
result | тензор логического или целочисленного типа | (С1) |
Ограничения
- (C1)
type(lhs) = type(rhs) = type(result)
.
Примеры
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Dialect Interop
At the moment, StableHLO programs in the wild sometimes contain operations that are not defined by StableHLO.
Module, Function, Call and Return
StableHLO uses upstream MLIR operations for ModuleOp, FuncOp, CallOp, and ReturnOp. This was done for better interop with existing MLIR machinery, as many useful passes are written targeting FuncOp and ModuleOp, and many compilation pipelines expect these ops to be present. Full compatibility guarantees are applied to these ops. If anything ever changes about these ops in an incompatible way (ie removal), StableHLO equivalents will be added to preserve compatibility.
CHLO
The CHLO opset contains higher level operations that decompose to StableHLO. Currently there are no compatibility guarantees for CHLO. For compatibility guarantees, the chlo-legalize-to-stablehlo pass must be used prior to serialization.
Shape Operations
It is a common use case in the community to use certain operations from core MLIR dialects in dynamic StableHLO programs to perform shape computations. Most commonly, these include shape
dialect ops like shape_of
or num_elements
, tensor
dialect ops like dim
or from_elements
, and the builtin index
type.
The Dynamism RFC > O2 denotes these as out of scope, however some support for index
types is included for interop purposes. There are no compatibility guarantees for these ops or types. The shape-legalize-to-stablehlo pass can be used to convert these operations to fully supported StableHLO ops.
Deprecated Operations
There are several StableHLO operations that were inherited from MHLO which are deprecated and on the way out of StableHLO. The full details on these removals can be found in the StableHLO v1.0 Cleanup #2283 . The tracker issue for these deprecations is #2340 .
These operations fall into a few categories:
- "Not in HLO" category of StableHLO operations - they were initially part of the StableHLO opset but have been later deemed to not fit it well:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
( #3 ) . - Unused ops - These operations may have been useful at some point, but the ops were either underdeveloped, or the pipelines using these ops have been refactored to not require them anymore. This includes
map
,tuple
( #598 ),get_tuple_element
,rng
,complex
comparisons #560 , and convolutionwindow_reversal
( #1181 ).
Some of these ops can be removed easily given that they can be expressed using existing ops ( broadcast
, create_token
, cross-replica-sum
, dot
, unary_einsum
) and will be removed after the existing compatibilty window passes (6 months). Others are still being explored for removal ( einsum
, get_tuple_element
, map
, rng
torch_index_select
, tuple
, complex
comparisons, window_reversal
). Pending community feedback, these ops will either be removed, or added to the spec with full support. Until these ops futures are known, they are only guaranteed 6 months of compatibility.
Исполнение
Sequential execution
A StableHLO program is executed by providing input values to the main
function and computing output values. Output values of a function are computed by executing the graph of ops rooted in the corresponding return
op.
The execution order is implementation-defined as long as it is aligned with dataflow, ie if ops are executed before their uses. In StableHLO, all side-effecting ops consume one token and produce one token (multiple tokens can be multiplexed into one token via after_all
), so the execution order of side effects is also aligned with dataflow. For example, in the below program there are two possible execution orders: %0
→ %1
→ %2
→ return
and %1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
More formally, a StableHLO process is a combination of: 1) a StableHLO program, 2) operation statuses (not executed yet, already executed), and 3) intermediate values that the process is working on. The process starts with input values to the main
function, progresses through the graph of ops updating operation statuses and intermediate values and finishes with output values. Further formalization is TBD ( #484 ).
Parallel execution
StableHLO programs can be executed in parallel, organized into a 2D process grid of num_replicas
by num_partitions
which both have type ui32
.
In the StableHLO process grid , num_replicas * num_partitions
of StableHLO processes are executing at the same time. Each process has a unique process_id = (replica_id, partition_id)
, where replica_id
in replica_ids = range(num_replicas)
and partition_id
in partition_ids = range(num_partitions)
which both have type ui32
.
The size of the process grid is known statically for every program (in the future, we are planning to make it an explicit part of StableHLO programs #650 ), and the position within the process grid is known statically for every process. Each process has access to its position within the process grid via the replica_id
and partition_id
ops.
Within the process grid, the programs can all be the same (in the "Single Program, Multiple Data" style), can all be different (in the "Multiple Program, Multiple Data" style) or something in between. In the future, we are planning to introduce support for other idioms of defining parallel StableHLO programs, including GSPMD ( #619 ).
Within the process grid, the processes are mostly independent from each other - they have separate operation statuses, separate input/intermediate/output values and most of the ops are executed separately between processes, with the exception of a small number of collective ops described below .
Given that execution of most of the ops is only using values from the same process, it is usually unambiguous to refer to these values by their names. However, when describing semantics of collective ops, that is insufficient, and that gives rise to the notation name@process_id
to refer to the value name
within a particular process. (From that perspective, unqualified name
can be viewed as a shorthand for name@(replica_id(), partition_id())
).
The execution order across processes is implementation-defined, except for the synchronization introduced by point-to-point communication and collective ops as described below.
Point-to-point communication
StableHLO processes can communicate with each other through StableHLO channels . A channel is represented by a positive id of type si64
. Through various ops, it is possible to send values to channels and receive them from channels.
Further formalization, eg where these channel ids are coming from, how processes programs become aware of them and what kind of synchronization is introduced by them, is TBD ( #484 ).
Streaming communication
Every StableHLO process has access to two streaming interfaces:
- Infeed that can be read from.
- Outfeed that can be written to.
Unlike channels, which are used to communicate between processes and therefore have processes at both of their ends, infeeds and outfeeds have their other end implementation-defined.
Further formalization, eg how streaming communication influences execution order and what kind of synchronization is introduced by it, is TBD ( #484 ).
Collective ops
There are six collective ops in StableHLO: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
, and reduce_scatter
. All these ops split the processes in the StableHLO process grid into StableHLO process groups and execute a joint computation within each process group, independently from other process groups.
Within each process group, collective ops may introduce a synchronization barrier. Further formalization, eg elaborating on when exactly this synchronization happens, how exactly the processes arrive at this barrier, and what happens if they don't, is TBD ( #484 ).
If the process group involves cross-partition communication, ie there are processes in the process group whose partition ids are different, then execution of the collective op needs a channel, and the collective op must provide a positive channel_id
of type si64
. Cross-replica communication doesn't need channels.
The computations performed by the collective ops are specific to individual ops and are described in individual op sections above. However, the strategies by which the process grid is split into process groups are shared between these ops and are described in this section. More formally, StableHLO supports the following four strategies.
cross_replica
Only cross-replica communications happen within each process group. This strategy takes replica_groups
- a list of lists of replica ids - and computes a Cartesian product of replica_groups
by partition_ids
. replica_groups
must have unique elements and cover all replica_ids
. More formally, using Python syntax:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
For example, for replica_groups = [[0, 1], [2, 3]]
and num_partitions = 2
, cross_replica
will produce [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Only cross-partition communications happen within each process group. This strategy takes partition_groups
- a list of lists of partition ids - and computes a Cartesian product of partition_groups
by replica_ids
. partition_groups
must have unique elements and cover all partition_ids
. More formally, using Python syntax:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
For example, for partition_groups = [[0, 1]]
and num_replicas = 4
, cross_partition
will produce [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Both cross-replica and cross-partition communications may happen within each process group. This strategy takes replica_groups
- a list of lists of replica ids - and computes Cartesian products of each replica_group
by partition_ids
. replica_groups
must have unique elements and cover all replica_ids
. More formally, using Python syntax:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
For example, for replica_groups = [[0, 1], [2, 3]]
and num_partitions = 2
, cross_replica_and_partition
will produce [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
This strategy takes flattened_id_groups
- a list of lists of "flattened" process ids in the form of replica_id * num_partitions + partition_id
- and turns them into process ids. flattened_id_groups
must have unique elements and cover all process_ids
. More formally, using Python syntax:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
For example, for flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
and num_partitions = 2
, flattened_ids
will produce [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Точность
At the moment, StableHLO does not provide guarantees about numerical accuracy, but this may change in the future ( #1156 ).
Execution semantics of quantized operation
The interpretation of quantized StableHLO operations may vary depending on the hardware requirements and capabilities. For instance, some hardware may opt to interpret quantized operations using a "dequantize, perform floating-point operation, and finally quantize" strategy. Others may perform the entire computation with integer arithmetic. Consequently, the interpretation of quantized StableHLO operations is exclusively determined by the specific implementation. The interpretation of hybrid quantization ( #1575 ) should be based on the it's semantics as prescribed in the specification (via 1792 ).
Ошибки
StableHLO programs are validated through an extensive set of constraints for individual ops, which rules out many classes of errors prior to run time. However, error conditions are still possible, eg through integer overflows, out-of-bounds accesses, etc. Unless explicitly called out, all these errors result in implementation-defined behavior, but this may change in the future ( #1157 ).
Floating-point exceptions
As an exception to this rule, floating-point exceptions in StableHLO programs have well-defined behavior. Operations which result in exceptions defined by the IEEE-754 standard (invalid operation, division-by-zero, overflow, underflow, or inexact exceptions) produce default results (as defined in the standard) and continue execution without raising the corresponding status flag; similar to raiseNoFlag
exception handling from the standard. Exceptions for nonstandard operations (eg complex arithmetic and certain transcendental functions) are implementation-defined.
Shape mismatches
StableHLO supports dynamically-shaped tensors. However, shapes have to agree at runtime, otherwise the behavior is undefined. StableHLO does not explicitly provide an op that can assert that a tensor has a given shape at runtime. Generating correct code is the responsibility of the producer.
As a specific example, the below program is valid. However, at runtime, the exact shapes of %arg0
and %arg1
will have to be the same, otherwise the behavior of the program is undefined:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
For describing syntax, this document is using the modified ISO flavor of EBNF syntax ( ISO/IEC 14977:1996 , Wikipedia ), with two modifications: 1) rules are defined using ::=
rather than =
,
2) concatenation is expressed using juxtaposition rather than ,
.
For describing semantics (ie within "Types", "Constants" and "Ops" sections), we are using formulas which are based on Python syntax extended with support for concisely expressing array operations as described below. This works well for small snippets of code, but in rare cases when larger snippets of code are needed, we use vanilla Python syntax which is always introduced explicitly.
Формулы
Let's explore how formulas work based on an example from the dot_general
specification. One of the constraints for this operation looks as follows: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
The names used in this formula come from two sources: 1) global functions, ie dim
, 2) member definitions of the corresponding program element, ie lhs
, lhs_batching_dimensions
, rhs
and rhs_batching_dimensions
inputs defined in the "Inputs" section of dot_general
.
As mentioned above, the syntax of this formula is Python-based with some conciseness-oriented extensions. To make sense of the formula, let's transform it into vanilla Python syntax.
A) In these formulas, we are using =
to represent equality, so the first step towards obtaining Python syntax is replacing =
with ==
, as follows: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Also, these formulas support ellipses ( ...
) which turn scalar expressions into tensor expressions. In a nutshell, f(xs...)
roughly means "for each scalar x
in the tensor xs
, compute a scalar f(x)
and then return all these scalar results together as a tensor result". In vanilla Python syntax, our example formula turns into: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Thanks to ellipses, it is often possible to avoid working at the level of individual scalars. However, in some tricky cases, lower-level semi-informal syntax may be used like in the start_indices[bi0, ..., :, ..., biN]
formula from the gather
specification. In the service of conciseness, we don't provide an exact formalism for translating such syntax to vanilla Python, in hopes that it is still intuitively understandable on case-by-case basis. Please let us know if some specific formulas look opaque, and we'll try to improve them.
Also, you will notice that formulas use ellipses to expand all sorts of lists, including tensors, lists of tensors (which eg can arise from a variadic number of tensors), etc. This is another area where we don't provide an exact formalism (eg lists are not even part of the StableHLO type system) and instead rely on intuitive understandability.
C) The final noteworthy notational vehicle that we employ is implicit broadcasting. While the StableHLO opset doesn't support implicit broadcasting, the formulas do, also in the service of conciseness. In a nutshell, if a scalar is used in a context where a tensor is expected, the scalar is broadcasted to the expected shape.
To continue the dot_general
example, here's another constraint: 0 <= lhs_batching_dimensions < rank(lhs)
. As defined in the dot_general
specification, lhs_batching_dimensions
is a tensor, however both 0
and rank(lhs)
are scalars. After we apply implicit broadcasting, the formula will become [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
When applied to a particular dot_general
operation, this formula will evaluate to a tensor of booleans. When formulas are used as constraints, the constraint holds if the formula evaluates to either true
or to a tensor which only has true
elements.
Names
In formulas, lexical scope includes: 1) global functions, 2) member definitions,
3) local definitions. The list of global functions is provided below. The list of element definitions depends on the program element that the notation is applied to:
- For operations, member definitions include names introduced in "Inputs" and "Outputs" sections.
- For everything else, member definitions include structural parts of the program element, named after the corresponding EBNF non-terminals. Most of the time, the names of these structural parts are obtained by converting the names of the non-terminals to snake case (eg
IntegerLiteral
=>integer_literal
), but sometimes names get abbreviated in the process (egQuantizationStorageType
=>storage_type
) in which case the names are introduced explicitly similarly to "Inputs" / "Outputs" sections in operation specifications. - Additionally, member definitions always include
self
to refer to the corresponding program element.
Ценности
When formulas are evaluated, they work with the following types of values: 1) Value
(actual values, eg dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
; they always know their types), 2) Placeholder
(future values, eg lhs
, rhs
or result
; their actual values are not known yet, only their types are known), 3) Type
(types as defined in the "Types" section), 4) Function
(global functions as defined in the "Functions" section).
Depending on the context, names may be referring to different values. More specifically, the "Semantics" section for ops (and equivalents for other program elements) defines runtime logic, so all inputs are available as Value
. In contrast, the "Constraints" section for ops (and equivalents) defines "compile-time" logic, ie something that is typically executed before runtime, so only constant inputs are available as Value
and other inputs are available only as Placeholder
.
Names | In "Semantics" | In "Constraints" |
---|---|---|
Global functions | Function | Function |
Constant inputs | Value | Value |
Non-constant inputs | Value | Placeholder |
Выходы | Value | Placeholder |
Local definitions | Depends on the definition | Depends on the definition |
Let's consider an example transpose
operation:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
For this operation, permutation
is a constant, so it's available as a Value
in both semantics and constraints. In contrast, operand
and result
are available as a Value
in semantics but only as a Placeholder
in constraints.
Функции
Construction of types
There are no functions that can be used to construct types. Instead, we directly use type syntax because it's typically more concise. Eg (tensor<E>, tensor<E>) -> (tensor<E>)
rather than function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Functions on types
-
element_type
is defined on tensor types and quantized tensor types and returns, respectively, theTensorElementType
orQuantizedTensorElementType
part of the correspondingTensorType
orQuantizedTensorType
.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
is a shortcut foris_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
is a shortcut foris_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
checks if typex
can be promoted to typey
. Whenx
andy
areQuantizedTensorElementType
s, the promotion is applied only to thestorage_type
. This specific version of promotion is currently used in context of reduction computation (refer to RFC for more details).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
is a shortcut foris_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Available for all types. For example,is_float(x)
returnstrue
ifx
is aFloatType
. Ifx
is a value or placeholder, this function is a shortcut foris_type_name(type(x))
.max_value(x: Type) -> Value
returns the maximum value of anTensorElementType
. Ifx
is not anTensorElementType
, returnsNone
.min_value(x: Type) -> Value
returns the minimum possible value of anTensorElementType
. Ifx
is not anTensorElementType
, returnsNone
.member_name(x: Value | Placeholder | Type) -> Any
. Available for all member definitionsmember_name
of all types. For example,tensor_element_type(x)
returns theTensorElementType
part of a correspondingTensorType
. Ifx
is a value or placeholder, this function is a shortcut formember_name(type(x))
. Ifx
is not a type that has an appropriate member, or a value or a placeholder of such a type, returnsNone
.is_empty_algorithm(*args: Type)
checks if all dot algorithm fields are set toNone
. This is needed since dot algorithms have implementation defined default behaviors, so specifying a default value would be incorrect.
Construction of values
-
operation_name(*xs: Value | Type) -> Value
. Available for all operations. For example,add(lhs, rhs)
takes two tensor valueslhs
andrhs
and returns the output of evaluating theadd
operation with these inputs. For some operations egbroadcast_in_dim
, types of their outputs are "load-bearing", ie needed to evaluate an operation. In this case, the function takes these types as arguments.
Functions on values
All Python's operators and functions are available. Eg both subscription and slicing notations from Python are available to index into tensors, quantized tensors and tuples.
to_destination_type(x: Value, destination_type: Type) -> Value
is defined on tensors and returns the converted value ofx
based on thetype(x)
anddestination_type
as follows:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
There is early discussion on merging convert
, uniform_quantize
and uniform_dequantize
operations ( #1576 ). After the merge we do not need the above function and can use the operation name for convert
instead.
is_nan(x: Value) -> Value
is defined on tensors and returnstrue
if all elements ofx
areNaN
orfalse
otherwise. Ifx
is not a tensor, returnsNone
.is_sorted(x: Value) -> Value
is defined on tensors and returnstrue
if elements ofx
are sorted in ascending order with respect to the ascending lexicographical order of their indices orfalse
otherwise. Ifx
is not a tensor, returnsNone
.is_unique(x: Value) -> Value
is defined on tensors and returnstrue
ifx
doesn't have duplicate elements orfalse
otherwise. Ifx
is not a tensor, returnsNone
.member_name(x: Value) -> Any
is defined for all member definitionsmember_name
of all values. For example,real_part(x)
returns theRealPart
part of a correspondingComplexConstant
. Ifx
is not a value that has an appropriate member, returnsNone
.same(x: Value) -> Value
is defined on tensors and returnstrue
if elements ofx
are all equal to each other orfalse
otherwise. If the tensor doesn't have elements, that counts as "all equal to each other", ie the function returnstrue
. Ifx
is not a tensor, returnsNone
.split(x: Value, num_results: Value, axis: Value) -> Value
is defined on tensors and returnsnum_results
slices ofx
along the axisaxis
. Ifx
is not a tensor ordim(x, axis) % num_results != 0
, returnsNone
.is_defined_in_parent_scope(x: Value) -> Value
is defined on strings and returnstrue
ifx
is the name of a function defined in the same scope as the parent function of the relevant op.is_namespaced_op_name(x: Value) -> Value
is defined on strings and returnstrue
ifx
is a valid op name, that is it respects the following regular expression:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Shape computations
axes(x: Value | Placeholder | Type) -> Value
is a shortcut forrange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
is a shortcut forshape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
is a shortcut forlist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
is defined on tensors and returnssize(x)
indices for the correspondingTensorType
sorted in ascending lexicographical order, ie[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Ifx
is not a tensor type, a quantized tensor type, or a value or a placeholder of one of these types, returnsNone
.rank(x: Value | Placeholder | Type) -> Value
is a shortcut forsize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
is defined in the "Functions on types" section viamember_name
.size(x: Value | Placeholder | Type) -> Value
is a shortcut forreduce(lambda x, y: x * y, shape(x))
.
Quantization computations
def baseline_element_type(x: Value | Placeholder | Type) -> Type
is a shortcut forelement_type(baseline_type(x))
.baseline_type
is defined on tensor types and quantized tensor types and transforms them to a "baseline", ie a type with the same shape but with the quantization parameters of the element type reset to default values. This is used as a handy trick to compare both tensor and quantized tensor types uniformly, which is needed quite often. For quantized types, this enables comparing types ignoring the quantization parameters, that is,shape
,storage_type
,expressed_type
,storage_min
,storage_max
, andquantization_dimension
(for per-axis quantized type) must all match, butscales
andzero points
may differ.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
-
dequantize
is defined on quantized tensor types and turns them into floating-point tensor types. This happens via converting quantized elements which represent integer values of the storage type into corresponding floating-point values of the expressed type using the zero point and scale associated with the quantized element type.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
-
quantize
is defined on floating-point tensor types and turns them into quantized tensor types. This happens via converting floating-point values of the expressed type into corresponding integer values of the storage type using the zero point and scale associated with the quantized element type.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
-
dequantize_op_quantize
is used to specify element-wise computations on quantized tensors. It dequantizes, ie turns quantized elements into their expressed types, then performs an operation, and then quantizes, ie turns the results back into their storage types. At the moment, this function only works for per-tensor quantization. Per-axis quantization is work in progress ( #1574 ).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
-
hybrid_dequantize_then_op
is used to specify weight-only quantization for hybrid op which accepts lhs in floating-point and rhs in quantized types. It dequantizes quantized inputs into their expressed types and performs computation in float. Element type of float lhs tensor and expressed type of quantized rhs tensor should be identical.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Grid computations
cross_partition(replica_groups: Value) -> Value
. See the "cross_replica" section above.cross_replica(replica_groups: Value) -> Value
. See the "cross_replica" section above.cross_replica_and_partition(replica_groups: Value) -> Value
. See the "cross_replica_and_partition" section above.flattened_ids(replica_groups: Value) -> Value
. See the "flattened_ids" section above.
Dynamism
StableHLO values can have dynamic dimension sizes, eg tensor<?xi64>
. However, StableHLO values cannot have a dynamic number of dimensions (unranked dynamism, eg tensor<*xi64>
). Operands and results are allowed to use dynamic dimension sizes, even if there are constraints on the sizes. Constraints will be verified statically if possible, otherwise they are deferred to runtime and mismatches will result in undefined behavior. See below for examples.
Shape mismatches for unary elementwise operations
Consider the following toy program:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Such a program is unusual, because it is not common to know the shape of the result but not the shape of the input. Nonetheless, this is a valid StableHLO program. It is not possible to statically validate the abs
operation in this program, because the exact shape of the operand is unknown. However, the shapes are certainly compatible, and this can be checked statically: ?
could turn out to be 2
at runtime, and there would be no issue. Однако, ?
could also turn out to be some other integer, in which case the behavior is undefined.
Note that if a dimension size is dynamic in the result, there cannot be undefined behavior. Indeed, there is no "expected" size, so there cannot be a mismatch.
Shape mismatches for binary elementwise operations
Consider the following toy program:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
When it comes to binary elementwise operations, the shapes of the inputs and the result must agree at runtime. At compile time, static dimensions must be equal, otherwise they merely need to be compatible. If any dimension is dynamic in the inputs, then there could be undefined behavior at runtime, because the dynamic size may not match the corresponding size in the other operand (be it static or dynamic). If all the inputs are static, then whether the result is dynamic or not does not matter: statically known dimensions will be checked statically, and dynamic dimensions do not impose any constraints.
Shape mismatches for ops that take their output shape as an operand
Consider the following toy program:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
The values in the shape operand at runtime must match the shape of the result, otherwise the behavior is undefined. That is, at runtime %arg0
must have a value of dense<[3, 4]> : tensor<2xi32>
. If the shape operand is constant, this can be verified statically. If the result shape is fully dynamic, then there cannot be a mismatch.