StableHLO es un conjunto de operaciones para operaciones de alto nivel (HLO) en modelos de aprendizaje automático (AA). StableHLO funciona como una capa de portabilidad entre diferentes compiladores y frameworks de AA: los frameworks de AA que producen programas de StableHLO son compatibles con los compiladores de AA que consumen programas de StableHLO.
Nuestro objetivo es simplificar y acelerar el desarrollo del AA mediante la creación de más interoperabilidad entre varios frameworks de AA (como TensorFlow, JAX y PyTorch) y compiladores de AA (como IREE y XLA). Para ello, en este documento, se proporciona una especificación para el lenguaje de programación StableHLO.
Esta especificación contiene tres secciones principales. En primer lugar, en la sección Programs, se describe la estructura de los programas de StableHLO, que consisten en funciones de StableHLO que, a su vez, consisten en operaciones de StableHLO. Dentro de esa estructura, la sección Ops especifica la semántica de las operaciones individuales. La sección Ejecución proporciona semántica para todas estas operaciones que se ejecutan juntas dentro de un programa. Por último, en la sección Notation, se analiza la notación que se usa en toda la especificación.
Para ver la especificación de una versión anterior de StableHLO, abre el repositorio en la versión etiquetada de interés. Por ejemplo, la especificación de StableHLO v0.19.0. Para ver los cambios que se produjeron en cada aumento de versión menor de StableHLO, consulta el registro de versiones en VhloDialect.td.
Programas
Program ::= {Func}
Los programas de StableHLO consisten en una cantidad arbitraria de funciones de StableHLO.
A continuación, se muestra un programa de ejemplo con una función @main
que tiene 3 entradas (%image
, %weights
y %bias
) y 1 salida. El cuerpo de la función tiene 6 operaciones.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funciones
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Las funciones StableHLO (que también se denominan funciones con nombre) tienen un identificador, entradas/salidas y un cuerpo. En el futuro, planeamos introducir metadatos adicionales para las funciones para lograr una mejor compatibilidad con HLO (#425, #626, #740 y #744).
Identificadores
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Los identificadores de StableHLO son similares a los identificadores de muchos lenguajes de programación, con dos peculiaridades: 1) todos los identificadores tienen símbolos que distinguen diferentes tipos de identificadores y 2) los identificadores de valor pueden ser completamente numéricos para simplificar la generación de programas de StableHLO.
Tipos
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Los tipos de StableHLO se clasifican en tipos de valor (que también se denominan tipos de primera clase) que representan valores de StableHLO y tipos que no son de valor que describen otros elementos del programa. Los tipos de StableHLO son similares a los tipos de muchos lenguajes de programación, y la principal peculiaridad es la naturaleza específica del dominio de StableHLO, que genera algunos resultados inusuales (p. ej., los tipos escalares no son tipos de valor).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Los tipos de tensores representan tensores, es decir, arreglos multidimensionales. Tienen una forma y un tipo de elemento, en el que una forma representa tamaños de dimensión no negativos o desconocidos en el orden ascendente de las dimensiones correspondientes (que también se denominan ejes) numeradas de 0
a R-1
. La cantidad de dimensiones R
se denomina rango. Por ejemplo, tensor<2x3xf32>
es un tipo de tensor con la forma 2x3
y el tipo de elemento f32
. Tiene dos dimensiones (o, en otras palabras, dos ejes): la dimensión 0 y la dimensión 1, cuyos tamaños son 2 y 3. Su clasificación es 2.
Las formas pueden ser parcialmente o completamente desconocidas (dinámicas), p. ej., tensor<?x2xf64>
es parcialmente desconocida y tensor<?x?xf64>
es completamente desconocida. Los tamaños de las dimensiones dinámicas se representan con un ?
. Las formas no pueden tener una clasificación.
En el futuro, planeamos explorar la extensión de los tipos de tensores más allá de los tamaños de dimensión y los tipos de elementos, por ejemplo, para incluir diseños (#629) y dispersión (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
Nombre | Tipo | Limitaciones |
---|---|---|
storage_type |
tipo de número entero | (C1-C3), (C8) |
storage_min |
constante de número entero | (C1), (C3) y (C7) |
storage_max |
constante de número entero | (C2), (C3), (C7) |
expressed_type |
tipo de punto flotante | (C4) |
quantization_dimension |
constante de número entero opcional | (C10-C12) |
scales |
número variádico de constantes de punto flotante | (C4-C6), (C9), (C10), (C13) |
zero_points |
cantidad variádica de constantes de número entero | (C7-C9) |
Los tipos de elementos cuantificados representan valores enteros de un tipo de almacenamiento en el rango de storage_min
a storage_max
(inclusive) que corresponden a valores de punto flotante de un tipo expresado. Para un valor de número entero i
determinado, el valor de punto flotante correspondiente f
se puede calcular como f = (i - zero_point) * scale
, en el que scale
y zero_point
se denominan parámetros de cuantización. storage_min
y storage_max
son opcionales en la gramática, pero tienen valores predeterminados de min_value(storage_type)
y max_value(storage_type)
, respectivamente. Los tipos de elementos cuantificados tienen las
siguientes restricciones:
- (C1)
type(storage_min) = storage_type
- (C2)
type(storage_max) = storage_type
. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C4)
type(scales...) = expressed_type
. - (C5)
0 < scales
. - (C6)
is_finite(scales...)
. - (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
. - (C9)
size(scales) = size(zero_points)
. - (C10) Si
is_empty(quantization_dimension)
, entoncessize(scales) = 1
. - (C11)
0 <= quantization_dimension
.
Por el momento, QuantizationScale
es una constante de punto flotante, pero hay un gran interés en las escalas basadas en números enteros, representadas con multiplicadores y cambios. Planeamos explorar esto en un futuro cercano (#1404).
Hay un debate en curso sobre la semántica de QuantizationZeroPoint
,
incluidos el tipo, los valores y si puede haber uno o
potencialmente varios puntos cero en un tipo de tensor cuantizado. En función de los resultados de este debate, la especificación en torno a cero puntos podría cambiar en el futuro (#1405).
Otro análisis en curso involucra la semántica de QuantizationStorageMin
y QuantizationStorageMax
para determinar si se deben imponer restricciones en estos valores y en los valores de los tensores cuantificados (#1406).
Por último, planeamos explorar la representación de escalas y puntos cero desconocidos, de manera similar a como planeamos explorar la representación de tamaños de dimensión desconocidos (#1407).
Los tipos de tensores cuantificados representan tensores con elementos cuantificados. Estos tensores son exactamente los mismos que los regulares, con la excepción de que sus elementos tienen tipos de elementos cuantificados, en lugar de tipos de elementos regulares.
En los tensores cuantificados, la cuantización puede ser por tensor, es decir, tener un scale
y un zero_point
para todo el tensor o puede ser por eje, es decir, tener varios scales
y zero_points
, un par por fragmento de una dimensión quantization_dimension
particular. De forma más formal, en un tensor t
con cuantificación por eje, hay dim(t, quantization_dimension)
rebanadas de quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
, etc. Todos los elementos de la i
ª rebanada usan scales[i]
y zero_points[i]
como sus parámetros de cuantificación. Los tipos de tensores cuantificados tienen las siguientes restricciones:
- Para la cuantificación por tensor:
- Sin restricciones adicionales.
- Para la cuantificación por eje, haz lo siguiente:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Los tipos de tokens representan tokens, es decir, valores opacos que producen y consumen algunas operaciones. Los tokens se usan para imponer el orden de ejecución en las operaciones, como se describe en la sección Ejecución.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Los tipos de tupla representan tuplas, es decir, listas heterogéneas. Las tuplas son una función heredada que solo existe para la compatibilidad con HLO. En HLO, se usan tuplas para representar entradas y salidas variadic. En StableHLO, las entradas y salidas variadic se admiten de forma nativa, y el único uso de tuplas en StableHLO es representar de forma integral la ABI de HLO, en la que, p. ej., T
, tuple<T>
y tuple<tuple<T>>
pueden ser muy diferentes según una implementación en particular. En el futuro, planeamos realizar cambios en la ABI de HLO que podrían permitirnos quitar los tipos de tupla de StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Los tipos de elementos representan elementos de tipos de tensores. A diferencia de muchos lenguajes de programación, estos tipos no son de primera clase en StableHLO. Esto significa que los programas de StableHLO no pueden representar directamente valores de estos tipos (como resultado, es idiomático representar valores escalares de tipo T
con valores de tensor de 0 dimensiones de tipo tensor<T>
).
- El tipo booleano representa los valores booleanos
true
yfalse
. - Los tipos de números enteros pueden tener firma (
si
) o no (ui
) y tener uno de los anchos de bits admitidos (2
,4
,8
,16
,32
o64
). Los tipossiN
con firma representan valores enteros del-2^(N-1)
al2^(N-1)-1
inclusive, y los tiposuiN
sin firma representan valores enteros del0
al2^N-1
inclusive. - Los tipos de números de punto flotante pueden ser uno de los siguientes:
f8E3M4
,f8E4M3
yf8E5M2
son números de punto flotante de 8 bits que siguen las convenciones de IEEE-754.- Tipos
f8E4M3FN
yf8E5M2
que corresponden, respectivamente, a las codificacionesE4M3
yE5M2
del formato FP8 que se describe en Formatos FP8 para el aprendizaje profundo. - Los tipos
f8E4M3FNUZ
yf8E5M2FNUZ
correspondientes a las codificacionesE4M3
yE5M2
de los formatos FP8 descritos en Formatos numéricos de 8 bits para redes neuronales profundas. - Tipo
f8E4M3B11FNUZ
que corresponde a la codificaciónE4M3
de los formatos FP8 que se describen en Inferencia y entrenamiento híbridos de punto flotante de 8 bits (HFP8) para redes neuronales profundas. - Tipo
bf16
que corresponde al formatobfloat16
que se describe en BFloat16: El secreto del alto rendimiento en las Cloud TPU. - Los tipos
f16
,f32
yf64
corresponden, respectivamente, a los formatosbinary16
("precisión media"),binary32
("precisión simple") ybinary64
("precisión doble") descritos en el estándar IEEE 754. - El tipo
tf32
corresponde al formato TensorFloat32 y tiene compatibilidad limitada en StableHLO. - Tipos de MX (microescalamiento)
f4E2M1FN
,f6E2M3FN
,f6E3M2FN
yf8E8M0FNU
que se describen en la Especificación de formatos de microescalamiento OCP.
- Los tipos complejos representan valores complejos que tienen una parte real y una parte imaginaria del mismo tipo de elemento. Los tipos complejos compatibles son
complex<f32>
(ambas partes son de tipof32
) ycomplex<f64>
(ambas partes son de tipof64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Los tipos de funciones representan funciones con nombre y anónimas. Tienen tipos de entrada (la lista de tipos en el lado izquierdo de ->
) y tipos de salida (la lista de tipos en el lado derecho de ->
). En muchos lenguajes de programación, los tipos de funciones son de primera clase, pero no en StableHLO.
StringType ::= 'string'
El tipo de cadena representa secuencias de bytes. A diferencia de muchos lenguajes de programación, el tipo de cadena no es la primera clase en StableHLO y solo se usa para especificar metadatos estáticos para elementos del programa.
Operaciones
Las operaciones de StableHLO (que también se denominan ops) representan un conjunto cerrado de operaciones de alto nivel en modelos de aprendizaje automático. Como se mencionó anteriormente, la sintaxis de StableHLO se inspira en gran medida en MLIR, que no es necesariamente la alternativa más ergonómica, pero es, sin duda, la más adecuada para el objetivo de StableHLO de crear más interoperabilidad entre los frameworks y los compiladores de AA.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Las operaciones StableHLO (que también se llaman ops) tienen un nombre, entradas y salidas, y una firma. El nombre consta del prefijo stablehlo.
y una mnemotecnia que identifica de forma exclusiva una de las operaciones admitidas. Consulta a continuación una lista completa de todas las operaciones compatibles.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Las operaciones consumen entradas y producen salidas. Las entradas se clasifican en valores de entrada (calculados durante la ejecución), funciones de entrada (proporcionadas de forma estática porque, en StableHLO, las funciones no son valores de primera clase) y atributos de entrada (también se proporcionan de forma estática). El tipo de entradas y salidas que consume y produce una operación depende de su mnemónico. Por ejemplo, la operación add
consume 2 valores de entrada y produce 1 valor de salida. En comparación, la operación select_and_scatter
consume 3 valores de entrada, 2 funciones de entrada y 3 atributos de entrada.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Las funciones de entrada (también llamadas funciones anónimas) son muy similares a las funciones con nombre, excepto que: 1) no tienen un identificador (de ahí el nombre “anónimo”) y 2) no declaran tipos de salida (los tipos de salida se infieren de la operación return
dentro de la función).
La sintaxis de las funciones de entrada incluye una parte que no se usa actualmente (consulta la producción de Unused
anterior) que está allí para la compatibilidad con MLIR. En MLIR, hay un concepto más general de "regiones" que pueden tener varios "bloques" de operaciones conectados entre sí a través de JumpOps. Estos bloques tienen IDs que corresponden a la producción de Unused
, de modo que se puedan distinguir entre sí.
StableHLO no tiene operaciones de salto, por lo que la parte correspondiente de la sintaxis de MLIR no se usa (pero todavía está allí).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Los atributos de entrada tienen un nombre y un valor que es una de las constantes admitidas. Son la forma principal de especificar metadatos estáticos para los elementos de programa. Por ejemplo, la operación concatenate
usa el atributo dimension
para especificar la dimensión a lo largo de la cual se concatenan sus valores de entrada. De manera similar, la operación slice
usa varios atributos, como start_indices
y limit_indices
, para especificar los límites que se usan para cortar el valor de entrada.
Por el momento, los programas de StableHLO en el entorno real a veces contienen atributos que no se describen en este documento. En el futuro, planeamos absorber estos atributos en el conjunto de operaciones de StableHLO o prohibir que aparezcan en los programas de StableHLO. Mientras tanto, aquí tienes la lista de estos atributos:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- Metadatos de ubicación (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La firma de la operación consta de los tipos de todos los valores de entrada (la lista de tipos en el lado izquierdo de ->
) y los tipos de todos los valores de salida (la lista de tipos en el lado derecho de ->
). Estrictamente hablando, los tipos de entrada son redundantes y los tipos de salida también suelen serlo (porque, para la mayoría de las operaciones de StableHLO, los tipos de salida se pueden inferir a partir de las entradas). Sin embargo, la firma de op forma parte de la sintaxis de StableHLO de forma deliberada para garantizar la compatibilidad con MLIR.
A continuación, se muestra un ejemplo de operación cuyo mnemónico es select_and_scatter
. Consume 3 valores de entrada (%operand
, %source
y %init_value
), 2 funciones de entrada y 3 atributos de entrada (window_dimensions
, window_strides
y padding
). Observa cómo la firma de la operación solo incluye los tipos de sus valores de entrada (pero no los tipos de funciones y atributos de entrada que se proporcionan intercalados).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Las constantes de StableHLO tienen un literal y un tipo que, en conjunto, representan un valor de StableHLO. Por lo general, el tipo forma parte de la sintaxis de la constante, excepto cuando es inequívoco (p.ej., una constante booleana tiene el tipo i1
de forma inequívoca, mientras que una constante de número entero puede tener varios tipos posibles).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Las constantes booleanas representan los valores booleanos true
y false
. Las constantes booleanas tienen el tipo i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Las constantes de números enteros representan valores de números enteros a través de cadenas que usan notación decimal o hexadecimal. No se admiten otras bases, como las binarias o las octal. Las constantes de números enteros tienen las siguientes restricciones:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Las constantes de punto flotante representan valores de punto flotante a través de cadenas que usan notación decimal o científica. Además, la notación hexadecimal se puede usar para especificar directamente los bits subyacentes en el formato de punto flotante del tipo correspondiente. Las constantes de punto flotante tienen las siguientes restricciones:
- (C1) Si se usa una notación no hexadecimal,
is_wellformed(float_literal, float_type)
. - (C2) Si se usa la notación hexadecimal,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Las constantes complejas representan valores complejos con listas de una parte real (primero) y una parte imaginaria (segundo). Por ejemplo,
(1.0, 0.0) : complex<f32>
representa 1.0 + 0.0i
y
(0.0, 1.0) : complex<f32>
representa 0.0 + 1.0i
. El orden en que estas partes se almacenan en la memoria se define en la implementación. Las constantes complejas tienen las siguientes restricciones:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Las constantes de tensor representan valores de tensores con listas anidadas especificadas a través de la notación NumPy. Por ejemplo, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
representa un valor de tensor con la siguiente asignación de índices a elementos: {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
y {1, 2} => 6
. El orden en el que se almacenan estos elementos en la memoria se define en la implementación. Las constantes de tensores tienen las siguientes restricciones:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, en el que:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, donde:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- de lo contrario,
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Las constantes de tensores cuantificados representan valores de tensores cuantificados con la misma notación que las constantes de tensores, con elementos especificados como constantes de su tipo de almacenamiento. Las constantes de tensores cuantificados tienen las siguientes restricciones:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Los literales de cadena consisten en bytes especificados con caracteres ASCII y secuencias de escape. Son independientes de la codificación, por lo que la interpretación de estos
bytes está definida por la implementación. Los literales de cadena tienen el tipo string
.
Ops
abs
Semántica
Realiza la operación abs a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros firmados: módulo de números enteros
- Para números de punto flotante:
abs
de IEEE-754. - Para números complejos: módulo complejo.
- Para tipos cuantificados:
dequantize_op_quantize(abs, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de número entero firmado, punto flotante o tipo complejo, o tensor cuantificado por tensor | (C1-C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de número entero o punto flotante firmado, o tensor cuantificado por tensor | (C1-C2) |
Limitaciones
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
siis_complex(operand)
.baseline_element_type(operand)
en caso contrario.
Ejemplos
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Semántica
Realiza la adición a nivel de elementos de dos tensores lhs
y rhs
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para valores booleanos: OR lógico.
- Para números enteros: suma de números enteros.
- Para números de punto flotante:
addition
de IEEE-754. - Para números complejos: suma compleja.
- Para tipos cuantificados:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado | (C1-C6) |
(I2) | rhs |
tensor o tensor cuantificado | (C1-C5), (C7) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C7) |
Limitaciones
- Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C1)
type(lhs) = type(rhs) = type(result)
.
- (C1)
- Si la operación usa tensores cuantificados, haz lo siguiente:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
. - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
. - (C6) Si es
is_per_axis_quantized(lhs)
, entoncesquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Si
is_per_axis_quantized(rhs)
, entoncesquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semántica
Garantiza que las operaciones que producen inputs
se ejecuten antes que cualquier operación que dependa de result
. La ejecución de esta operación no hace nada, solo existe para establecer dependencias de datos de result
a inputs
.
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | inputs |
Cantidad variadica de token |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Ejemplos
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, concatena los valores de los tensores operands
de cada proceso a lo largo de all_gather_dim
y produce tensores results
.
La operación divide la cuadrícula del proceso de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
operands...@receiver = [operand@sender for sender in process_group]
para todos losreceiver
enprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
para todos losprocess
enprocess_group
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operands |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1), (C6) |
(I2) | all_gather_dim |
constante de tipo si64 |
(C1), (C6) |
(I3) | replica_groups |
Constante de tensor de 2 dimensiones de tipo si64 |
(C2-C4) |
(I4) | channel_id |
constante de tipo si64 |
(C5) |
(I5) | use_global_device_ids |
constante de tipo i1 |
(C5) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C6) |
Limitaciones
- (C1)
0 <= all_gather_dim < rank(operands...)
. - (C2)
is_unique(replica_groups)
- (C3)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Si
use_global_device_ids = true
, entonceschannel_id > 0
. - (C6)
type(results...) = type(operands...)
, excepto:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, aplica una función de reducción computation
a los valores de los tensores operands
de cada proceso y produce tensores results
.
La operación divide la cuadrícula del proceso StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
results...@process[result_index] = exec(schedule)
para algún árbol binarioschedule
en el que:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
es un árbol binario definido por la implementación cuyo recorrido en orden esto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operands |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C5), (C6) |
(I2) | replica_groups |
Número variádico de constantes tensores unidimensionales del tipo si64 |
(C1-C3) |
(I3) | channel_id |
constante de tipo si64 |
(C4) |
(I4) | use_global_device_ids |
constante de tipo i1 |
(C4) |
(I5) | computation |
función | C5 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C6-C7) |
Limitaciones
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Si
use_global_device_ids = true
, entonceschannel_id > 0
. - (C5)
computation
tiene el tipo(tensor<E>, tensor<E>) -> (tensor<E>)
en el queis_promotable(element_type(operand), E)
. - (C6)
shape(results...) = shape(operands...)
. - (C7)
element_type(results...) = E
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, divide los valores de los tensores operands
a lo largo de split_dimension
en partes, dispersa las partes divididas entre los procesos, concatena las partes dispersas a lo largo de concat_dimension
y produce tensores results
.
La operación divide la cuadrícula del proceso de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0
.cross_partition(replica_groups)
sichannel_id > 0
.
Luego, dentro de cada process_group
, haz lo siguiente:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
para todos lossender
enprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
en el quereceiver_index = process_group.index(receiver)
.results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operands |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C1-C3), (C9) |
(I2) | split_dimension |
constante de tipo si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
constante de tipo si64 |
(C3), (C9) |
(I4) | split_count |
constante de tipo si64 |
(C2), (C4), (C8) y (C9) |
(I5) | replica_groups |
Constante de tensor bidimensional de tipo si64 |
(C5-C8) |
(I6) | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C9) |
Limitaciones
- (C1)
0 <= split_dimension < rank(operands...)
. - (C2)
dim(operands..., split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operands...)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(results...) = type(operands...)
, excepto sisplit_dimension != concat_dimension
:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
y
Semántica
Realiza el operador AND a nivel de elementos de dos tensores, lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para valores booleanos: AND lógico.
- Para números enteros: AND binario.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo booleano o de número entero | (C1) |
(I2) | rhs |
tensor de tipo booleano o número entero | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo booleano o número entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semántica
Realiza la operación atan2 a nivel de elementos en el tensor lhs
y rhs
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
atan2
de IEEE-754. - Para números complejos: atan2 complejo.
- Para tipos cuantificados:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
(I2) | rhs |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semántica
Calcula las gradientes de varias entradas de la propagación hacia atrás de batch_norm_training
desde grad_output
y produce tensores grad_operand
, grad_scale
y grad_offset
. De manera más formal, esta operación se puede expresar como una descomposición de las operaciones StableHLO existentes mediante la sintaxis de Python de la siguiente manera:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Para los tipos cuantizados, realiza dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1-C3), (C5) |
(I2) | scale |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2), (C4), (C5) |
(I3) | mean |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
(I4) | variance |
Tensor de 1 dimensión de tipo cuantificado por tensor o de punto flotante | (C2) y (C4) |
(I5) | grad_output |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2), (C3) |
(I6) | epsilon |
constante de tipo f32 |
|
(I7) | feature_index |
constante de tipo si64 |
(C1), (C5) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
grad_operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2), (C3) |
grad_scale |
Tensor de 1 dimensión de tipo cuantificado por tensor o de punto flotante | (C2) y (C4) |
grad_offset |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2), (C4) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
ygrad_offset
tienen el mismobaseline_element_type
. - (C3)
operand
,grad_output
ygrad_operand
tienen la misma forma. - (C4)
scale
,mean
,variance
,grad_scale
ygrad_offset
tienen la misma forma. - (C5)
size(scale) = dim(operand, feature_index)
.
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semántica
Normaliza el tensor operand
en todas las dimensiones, excepto en la dimensión feature_index
, y produce un tensor result
. De manera más formal, esta operación se puede expresar como una descomposición de las operaciones de StableHLO existentes con la sintaxis de Python de la siguiente manera:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Para los tipos cuantificados, realiza dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1-C7) |
(I2) | scale |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C3) |
(I3) | offset |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
(I4) | mean |
Tensor de 1 dimensión de tipo cuantificado por tensor o de punto flotante | (C5) |
(I5) | variance |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2), (C6) |
(I6) | epsilon |
constante de tipo f32 |
|
(I7) | feature_index |
constante de tipo si64 |
(C1), (C3-C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2) y (C7) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
yresult
tienen el mismobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
- (C5)
size(mean) = dim(operand, feature_index)
- (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semántica
Calcula la media y la varianza en todas las dimensiones, excepto en la dimensión feature_index
, y normaliza el tensor operand
que produce los tensores output
, batch_mean
y batch_var
. De forma más formal, esta operación se puede expresar como una descomposición de las operaciones de StableHLO existentes con la sintaxis de Python de la siguiente manera:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Para los tipos cuantificados, realiza dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
(I2) | scale |
Tensor de 1 dimensión de punto flotante o cuantificado por tensor | (C2) y (C3) |
(I3) | offset |
Tensor de 1 dimensión de punto flotante o cuantificado por tensor | (C2), (C4) |
(I4) | epsilon |
constante de tipo f32 |
(C1), (C3-C6) |
(I5) | feature_index |
constante de tipo si64 |
(C1), (C3-C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C7) |
batch_mean |
Tensor de 1 dimensión de punto flotante o cuantificado por tensor | (C2) y (C5) |
batch_var |
Tensor de 1 dimensión de punto flotante o cuantificado por tensor | (C2), (C6) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand)
- (C2)
operand
,scale
,offset
,batch_mean
,batch_var
youtput
tienen el mismobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semántica
Realiza una operación de transmisión de bits en el tensor operand
y produce un tensor result
en el que los bits de todo el tensor operand
se reinterpretan con el tipo del tensor result
.
De manera más formal, teniendo en cuenta E = element_type(operand)
, E' = element_type(result)
y R = rank(operand)
:
- Si es
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Si es
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Si es
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
muestra la representación en memoria de un valor determinado, y su comportamiento se define en la implementación porque la representación exacta de los tensores y la representación exacta de los tipos de elementos también se definen en la implementación.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C2) |
Limitaciones
- (C1) Dados
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
yR = rank(operand)
:- Si es
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Si
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
para todos los0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Si
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
para todos los0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Si es
- (C2) Si
is_complex(operand) or is_complex(result)
, entoncesis_complex(operand) and is_complex(result)
.
Ejemplos
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semántica
Expande las dimensiones o el rango de un tensor de entrada mediante la duplicación de los datos en el tensor operand
y produce un tensor result
. De forma más formal, result[result_index] = operand[operand_index]
, donde para todos los d
en axes(operand)
:
operand_index[d] = 0
sidim(operand, d) = 1
.- De lo contrario,
operand_index[d] = result_index[broadcast_dimensions[d]]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C2-C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3), (C5-C6) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
,scales(operand)
yzero_points(operand)
pueden diferir dequantization_dimension(result)
,scales(result)
yzero_points(result)
respectivamente, de lo contrario.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
- (C5) Para todos los
d
enaxes(operand)
:dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si
dim(operand, quantization_dimension(operand)) = 1
, entoncesscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Ejemplos
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
caso
Semántica
Produce el resultado de la ejecución de exactamente una función de branches
, según el valor de index
. De manera más formal, result = selected_branch()
, que implica lo siguiente:
selected_branch = branches[index]
si0 <= index < size(branches)
.selected_branch = branches[-1]
en caso contrario.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | index |
Tensor de dimensión 0 de tipo si32 |
|
(I2) | branches |
cantidad de funciones variadic | (C1-C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores, tensores cuantificados o tokens | (C4) |
Limitaciones
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Ejemplos
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Semántica
Realiza la operación de raíz cúbica a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
rootn(x, 3)
de IEEE-754. - Para números complejos: raíz cúbica compleja.
- Para tipos cuantizados:
dequantize_op_quantize(cbrt, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semántica
Realiza el techo de elementos del tensor operand
y produce un tensor result
.
Implementa la operación roundToIntegralTowardPositive
de la especificación IEEE-754. Para los tipos cuantificados, realiza dequantize_op_quantize(ceil, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Semántica
Calcula la descomposición de Cholesky de un lote de matrices.
De forma más formal, para todos los i
en index_space(result)
, result[i0, ..., iR-3, :, :]
es una descomposición de Cholesky de a[i0, ..., iR-3, :, :]
, en forma de matriz triangular inferior (si lower
es true
) o triangular superior (si lower
es false
).
Los valores de salida en el triángulo opuesto, es decir, el triángulo superior estricto o el triángulo inferior estricto, según corresponda, se definen según la implementación.
Si existe i
en la que la matriz de entrada no es una matriz hermítica positiva definida, el comportamiento no está definido.
Para los tipos cuantizados, realiza dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | a |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1-C3) |
(I2) | lower |
Constante de tensor de 0 dimensiones de tipo i1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(a) = baseline_type(result)
- (C2)
2 <= rank(a)
- (C3)
dim(a, -2) = dim(a, -1)
.
Ejemplos
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
restringir
Semántica
Limita cada elemento del tensor operand
entre un valor mínimo y máximo, y produce un tensor result
. De forma más formal, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
, donde min_element = rank(min) = 0 ? min[] : min[result_index]
, max_element = rank(max) = 0 ? max[] : max[result_index]
. Para los tipos cuantificados, realiza dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Imponer un orden en números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | min |
tensor o tensor cuantificado por tensor | (C1) y (C3) |
(I2) | operand |
tensor o tensor cuantificado por tensor | (C1-C4) |
(I3) | max |
tensor o tensor cuantificado por tensor | (C2), (C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C4) |
Limitaciones
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, envía el valor del tensor operand
desde el proceso de origen a los procesos de destino y produce un tensor result
.
La operación divide la cuadrícula del proceso de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
sichannel_id <= 0
.cross_partition(replica_groups)
sichannel_id > 0
.
Luego, result@process
se obtiene de la siguiente manera:
operand@process_groups[i, 0]
si existe uni
de modo que el proceso esté enprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
en caso contrario.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C3) |
(I2) | replica_groups |
Cantidad variacional de constantes de tensor de 1 dimensión de tipo si64 |
(C1), (C2) |
(I3) | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3) |
Limitaciones
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, en el queN
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C3)
type(result) = type(operand)
.
Ejemplos
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, envía el valor del tensor operand
del proceso de origen al proceso de destino y produce un tensor result
.
La operación divide la cuadrícula del proceso de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(source_target_pairs)
sichannel_id <= 0
.cross_partition(source_target_pairs)
sichannel_id > 0
.
Luego, result@process
se obtiene de la siguiente manera:
operand@process_groups[i, 0]
, si existe uni
tal queprocess_groups[i, 1] = process
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
en caso contrario.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C5) |
(I2) | source_target_pairs |
Constante de tensor de 2 dimensiones de tipo si64 |
(C1-C4) |
(I3) | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
- (C3)
is_unique(source_target_pairs[:, 1])
- (C4)
0 <= source_target_pairs < N
, dondeN
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C5)
type(result) = type(operand)
.
Ejemplos
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
comparar
Semántica
Realiza una comparación a nivel de elementos de los tensores lhs
y rhs
según comparison_direction
y compare_type
, y produce un tensor result
.
Los valores de comparison_direction
y compare_type
tienen la siguiente semántica:
Para los tipos de elementos booleanos y enteros, haz lo siguiente:
EQ
:lhs = rhs
.NE
:lhs != rhs
.GE
:lhs >= rhs
.GT
:lhs > rhs
.LE
:lhs <= rhs
.LT
:lhs < rhs
.
Para los tipos de elementos de punto flotante con compare_type = FLOAT
, la operación implementa las siguientes operaciones IEEE-754:
EQ
:compareQuietEqual
.NE
:compareQuietNotEqual
.GE
:compareQuietGreaterEqual
.GT
:compareQuietGreater
.LE
:compareQuietLessEqual
.LT
:compareQuietLess
.
Para los elementos de punto flotante con compare_type = TOTALORDER
, la op usa la combinación de operaciones totalOrder
y compareQuietEqual
de IEEE-754.
Para los tipos de elementos complejos, la comparación lexicográfica de los pares (real, imag)
se realiza con los comparison_direction
y compare_type
proporcionados.
Imponer un orden en los números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos cuando comparison_direction
sea GE
, GT
, LE
o LT
(#560).
Para tipos cuantizados, realiza dequantize_compare(lhs, rhs,
comparison_direction)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1-C3) |
(I2) | rhs |
tensor o tensor cuantificado por tensor | (C1-C2) |
(I3) | comparison_direction |
enum de EQ , NE , GE , GT , LE y LT |
|
(I4) | compare_type |
enum de FLOAT , TOTALORDER , SIGNED y UNSIGNED |
(C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo booleano | (C2) |
Limitaciones
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
se define de la siguiente manera:SIGNED
siis_signed_integer(element_type(lhs))
.UNSIGNED
siis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
oTOTALORDER
si esis_float(element_type(lhs))
.FLOAT
siis_complex(element_type(lhs))
.
Ejemplos
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
emergencia compleja,
Semántica
Realiza la conversión de elementos a un valor complejo a partir de un par de valores reales e imaginarios, lhs
y rhs
, y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo f32 o f64 |
(C1-C3) |
(I2) | rhs |
tensor de tipo f32 o f64 |
(C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo complejo | (C2), (C3) |
Limitaciones
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
tiene el tipocomplex<E>
, en el queE = element_type(lhs)
.
Ejemplos
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
compuesto
Semántica
Encapsula una operación compuesta por otras operaciones de StableHLO, que toma inputs
y composite_attributes
y produce results
. La semántica de la operación se implementa con el atributo decomposition
. La operación composite
se puede reemplazar por su descomposición sin cambiar la semántica del programa. En los casos en que la incorporación de la descomposición no proporciona la misma semántica de operación, prefiere usar custom_call
.
El campo version
(el valor predeterminado es 0
) se usa para indicar cuándo cambia la semántica de un compuesto.
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | inputs |
Cantidad de valores variadic |
(I2) | name |
constante de tipo string |
(I3) | composite_attributes |
diccionario de atributos |
(I4) | decomposition |
constante de tipo string |
(I5) | version |
constante de tipo si32 |
Salidas
Nombre | Tipo |
---|---|
results |
cantidad variádica de valores |
Limitaciones
- (C1)
is_namespaced_op_name(name)
- C2:
is_defined_in_parent_scope(decomposition)
- (C3)
types(inputs...) == input_types(decomposition)
- (C4)
types(results...) == output_types(decomposition)
Ejemplos
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Semántica
Concatena inputs
junto con la dimensión dimension
en el mismo orden que los argumentos dados y produce un tensor result
. De forma más formal, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, en la que:
id = d0 + ... + dk-1 + kd
.d
es igual adimension
, yd0
, ... son los tamaños ded
a dimensión deinputs
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C6) |
(I2) | dimension |
constante de tipo si64 |
(C2), (C4), (C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5-C6) |
Limitaciones
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
, excepto pordim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
, excepto en los siguientes casos:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Ejemplos
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Semántica
Produce un tensor output
a partir de una constante value
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | value |
constante | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output |
tensor o tensor cuantificado | (C1) |
Limitaciones
- (C1)
type(value) = type(output)
.
Ejemplos
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
generar una conversión
Semántica
Realiza una conversión por elemento de un tipo de elemento a otro en el tensor operand
y produce un tensor result
.
Para las conversiones de boolean-to-any-supported-type, el valor false
se convierte en cero y el valor true
se convierte en uno. En el caso de las conversiones de any-supported-type-to-boolean, un valor cero se convierte en false
y los valores distintos de cero se convierten en true
. A continuación, se explica cómo funciona esto para los tipos complejos.
En el caso de las conversiones que involucran número entero a número entero, número entero a número de punto flotante o número de punto flotante a número de punto flotante, si el valor de origen se puede representar exactamente en el tipo de destino, el valor del resultado es esa representación exacta. De lo contrario, el comportamiento está por definir (#180).
En el caso de las conversiones que involucran números de punto flotante a números enteros, la parte fraccionaria se trunca. Si el valor truncado no se puede representar en el tipo de destino, el comportamiento es por definir (#180).
Las conversiones que involucran números complejos a números complejos siguen el mismo comportamiento de las conversiones de números de punto flotante a números de punto flotante para convertir partes reales e imaginarias.
Para las conversiones complex-to-any-other-type y any-other-type-to-complex, se ignora el valor imaginario de origen o se establece en cero el valor imaginario de destino, respectivamente. La conversión de la parte real sigue las conversiones de punto flotante.
En principio, esta operación podría expresar la decuantización (conversión de tensores cuantificados a tensores normales), la cuantificación (conversión de tensores normales a tensores cuantificados) y la recantidad (conversión entre tensores cuantificados), pero, por el momento, tenemos operaciones dedicadas para eso: uniform_dequantize
para el primer caso de uso y uniform_quantize
para el segundo y el tercer caso de uso. En el futuro, es posible que estas dos operaciones se combinen en convert
(#1576).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor | (C1) |
Limitaciones
- (C1)
shape(operand) = shape(result)
Ejemplos
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolución
Semántica
Calcula los productos escalares entre ventanas de lhs
y rebanadas de rhs
y produce result
. En el siguiente diagrama, se muestra cómo se calculan los elementos de result
a partir de lhs
y rhs
con un ejemplo concreto.
De manera más formal, considera el siguiente replanteamiento de las entradas en términos de lhs
para poder expresar ventanas de lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Este replanteamiento usa las siguientes funciones auxiliares:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
en el quej[d] = i[permutation[d]]
.
Si es feature_group_count = 1
y batch_group_count = 1
, entonces para todos los output_spatial_index
en index_space(dim(result, output_spatial_dimensions...))
, result[result_shape(:, output_spatial_index, :)] = dot_product
, donde:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Parece que esta función no se usa, por lo que planeamos quitarla en el futuro (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Si es feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Si batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Para los tipos cuantificados, realiza dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Para los tipos cuantificados híbridos, realiza hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32) y (C34) |
(I2) | rhs |
tensor o tensor cuantificado | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
(I3) | window_strides |
Constante de tensor de 1 dimensión de tipo si64 |
(C2-C3), (C25) |
(I4) | padding |
Constante de tensor de 2 dimensiones de tipo si64 |
(C4), (C25) |
(I5) | lhs_dilation |
Constante de tensor de 1 dimensión de tipo si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Constante de tensor de 1 dimensión de tipo si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Constante de tensor unidimensional de tipo i1 |
(C9) |
(I8) | input_batch_dimension |
constante de tipo si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
constante de tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C12), (C13), (C25) |
(I11) | kernel_input_feature_dimension |
constante de tipo si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
constante de tipo si64 |
(C15-C16), (C18), (C25), (C29) |
(I13) | kernel_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
constante de tipo si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
constante de tipo si64 |
(C20), (C25), (C30) |
(I16) | output_spatial_dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
constante de tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
constante de tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
Cantidad variádica de enumeraciones de DEFAULT , HIGH y HIGHEST |
(C24) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C25-C28), (C30), (C32-34) |
Limitaciones
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
- (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
- (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
se define de la siguiente manera:dim(lhs, input_batch_dimension) / batch_group_count
siresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
siresult_dim = output_feature_dimension
.num_windows
de lo contrario, donde:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si la operación usa tensores cuantificados, haz lo siguiente:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Si es
is_per_axis_quantized(rhs)
, entoncesquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Si
is_per_axis_quantized(result)
, entoncesquantization_dimension(result) = output_feature_dimension
. - Si
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Si
is_per_tensor_quantized(rhs)
, entoncesis_per_tensor_quantized(result)
. - Si
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Ejemplos
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
coseno
Semántica
Realiza una operación de coseno en términos de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
cos
de IEEE-754. - Para números complejos: coseno complejo.
- Para tipos cuantificados:
dequantize_op_quantize(cosine, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semántica
Realiza un recuento por elemento de la cantidad de bits de ceros iniciales en el tensor operand
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo número entero | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo número entero | (C1) |
Limitaciones
- (C1)
type(operand) = type(result)
.
Ejemplos
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semántica
Encapsula una operación call_target_name
definida por la implementación que toma inputs
y called_computations
y produce results
. Se pueden usar has_side_effect
, backend_config
y api_version
para proporcionar metadatos adicionales definidos por la implementación.
Por el momento, esta operación contiene una colección bastante desorganizada de metadatos que refleja la evolución orgánica de su operación equivalente en el compilador de XLA. En el futuro, planeamos unificar estos metadatos (#741).
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | inputs |
Cantidad de valores variadic |
(I2) | call_target_name |
constante de tipo string |
(I3) | has_side_effect |
constante de tipo i1 |
(I4) | backend_config |
constante de tipo string o diccionario de atributos |
(I5) | api_version |
constante de tipo si32 |
(I6) | called_computations |
Cantidad variacional de constantes de tipo string |
Salidas
Nombre | Tipo |
---|---|
results |
Cantidad de valores variadic |
Ejemplos
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
dividir
Semántica
Realiza la división por elementos de los tensores lhs
del dividendo y rhs
del divisor, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: División de números enteros que produce el cociente algebraico con cualquier parte fraccionaria descartada.
- Para números de punto flotante:
division
de IEEE-754. - Para números complejos: división compleja.
- Para tipos cuantizados:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2) | rhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semántica
Calcula los productos punto entre las porciones de lhs
y las porciones de rhs
, y produce un tensor result
.
Más formalmente, result[result_index] = dot_product
, donde:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
en el quesize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
ysize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Para los tipos cuantificados, realiza dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Para tipos híbridos cuantizados, realiza hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
controla el equilibrio entre la velocidad y la precisión para los procesamientos en los backends del acelerador. Puede ser uno de los siguientes (por el momento, la semántica de estos valores de enum no está especificada, pero planeamos abordar esto en #755):
DEFAULT
: Es el cálculo más rápido, pero la aproximación menos precisa al número original.HIGH
: Cálculo más lento, pero aproximación más precisa al número original.HIGHEST
: Es el cálculo más lento, pero la aproximación más precisa al número original.
Un DotAlgorithm
define las propiedades principales del algoritmo que se usa para implementar la operación punto, que también define la precisión. Si se establecen los campos de atributos del algoritmo, precision_config
debe ser DEFAULT
. DotAlgorithms
no tiene un valor predeterminado, ya que los parámetros predeterminados se definen en la implementación. Por lo tanto, todos los campos del algoritmo de punto se pueden establecer en None
para especificar un algoritmo de punto vacío, que usará el valor precision_config
.
Los campos DotAlgorithm
incluyen lo siguiente:
lhs_precision_type
yrhs_precision_type
, las precisiones a las que se redondean el LHS y el RHS de la operación. Los tipos de precisión son independientes de los tipos de almacenamiento de las entradas y la salida.accumulation_type
es la precisión que se usa para la acumulación.lhs_component_count
,rhs_component_count
ynum_primitive_operations
se aplican cuando realizamos un algoritmo que descompone el lado izquierdo o derecho en varios componentes y realiza varias operaciones de punto "primitivo" en esos valores, por lo general, para emular una precisión más alta (p. ej., Aprovecha el tipo de datos de inteligencia artificial bfloat16 para realizar cálculos de mayor precisión: bf16_6x tf32_3x, etc.). Para los algoritmos sin descomposición, estos valores deben establecerse en1
.allow_imprecise_accumulation
para especificar si se permite la acumulación en una precisión más baja para algunos pasos (p. ej.,CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Atributos DotAlgorithm
de ejemplo:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Depende de las implementaciones decidir qué combinaciones son compatibles. En general, no se garantiza que el consumidor de StableHLO admita cada algoritmo en cada tipo de acelerador. Si no se admite un algoritmo determinado, se debe generar un error en lugar de recurrir a una alternativa. La verificación de StableHLO proporcionará una verificación del mejor esfuerzo, lo que evitará que se usen algoritmos que no se sepan compatibles con ningún hardware.
Consulta xla_data.proto > Algorithm
para ver algunos valores de algoritmos compatibles. En el ticket n° 2483, se captura el plan para crear un doc
centralizado sobre los algoritmos compatibles por backend.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
tensor o tensor cuantificado | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C1), (C4), (C7) y (C9) |
(I5) | lhs_contracting_dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
Cantidad variacional de enums de DEFAULT , HIGH y HIGHEST |
(C11), (C21) |
(I8) | lhs_precision_type |
FloatType o TensorFloat32 | C21 |
(I9) | rhs_precision_type |
FloatType o TensorFloat32 | C21 |
(I10) | accumulation_type |
FloatType o TensorFloat32 | (C21) |
(I11) | lhs_component_count |
constante de tipo si32 |
(C21), (C22) |
(I12) | rhs_component_count |
constante de tipo si32 |
(C21), (C23) |
(I13) | num_primitive_operations |
constante de tipo si32 |
(C21), (C24) |
(I14) | allow_imprecise_accumulation |
constante de tipo bool |
C21 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C12), (C14), (C18-C20) |
Limitaciones
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
- (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
- (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Si la operación usa tensores cuantificados, haz lo siguiente:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Si
is_per_axis_quantized(rhs)
, entoncesquantization_dimension(rhs)
no está enrhs_contracting_dimensions
. - Si
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Si
is_per_tensor_quantized(rhs)
, entoncesis_per_tensor_quantized(result)
. - Si
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Si
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Ejemplos
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semántica
Esta operación es funcionalmente idéntica a la operación broadcast_in_dim, pero la forma del resultado se especifica de forma dinámica a través de output_dimensions
.
La operación también acepta atributos opcionales known_expanding_dimensions
y known_nonexpanding_dimensions
para expresar el conocimiento estático sobre el comportamiento de expansión de las dimensiones.
Si no se especifica, se supone que todas las dimensiones pueden expandirse.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Tensor de 1 dimensión de tipo número entero | (C7) |
(I3) | broadcast_dimensions |
Tensor constante de 1 dimensión de tipo número entero | (C2-C6) |
(I4) | known_expanding_dimensions |
Tensor constante de 1 dimensión de tipo número entero | (C8-C9) |
(I5) | known_nonexpanding_dimensions |
Tensor constante de 1 dimensión de tipo número entero | (C8-C9) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3), (C5-C7) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
,scales(operand)
yzero_points(operand)
pueden diferir dequantization_dimension(result)
,scales(result)
yzero_points(result)
respectivamente, de lo contrario.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
- (C5) Para todos los
d
enaxes(operand)
:dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si es
dim(operand, quantization_dimension(operand)) = 1
, entoncesscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions)
. - (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_nonexpanding_dimensions < rank(operand)
.
Ejemplos
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_nonexpanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semántica
Esta operación es funcionalmente idéntica a la operación de convolución, pero el padding se especifica de forma dinámica a través de padding
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31) y (C33) |
(I2) | rhs |
tensor o tensor cuantificado | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
Tensor de 2 dimensiones de tipo número entero | (C4) |
(I4) | window_strides |
Constante de tensor de 1 dimensión de tipo si64 |
(C2-C3) |
(I5) | lhs_dilation |
Constante de tensor de 1 dimensión de tipo si64 |
(C5-C6) |
(I6) | rhs_dilation |
Constante de tensor de 1 dimensión de tipo si64 |
(C7-C8) |
(I7) | window_reversal |
Constante de tensor unidimensional de tipo i1 |
(C9) |
(I8) | input_batch_dimension |
constante de tipo si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
constante de tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C12), (C13) |
(I11) | kernel_input_feature_dimension |
constante de tipo si64 |
(C14), (C18) |
(I12) | kernel_output_feature_dimension |
constante de tipo si64 |
(C15-C16), (C18), (C28) |
(I13) | kernel_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C17-C18) |
(I14) | output_batch_dimension |
constante de tipo si64 |
C20 |
(I15) | output_feature_dimension |
constante de tipo si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C19-C20) |
(I17) | feature_group_count |
constante de tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
constante de tipo si64 |
(C10), (C15), (C22), (C23) |
(I19) | precision_config |
Cantidad variádica de enumeraciones de DEFAULT , HIGH y HIGHEST |
(C24) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C25-C27), (C29), (C31-C33) |
Limitaciones
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
- (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
- (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
se define de la siguiente manera:dim(lhs, input_batch_dimension) / batch_group_count
siresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
siresult_dim = output_feature_dimension
.num_windows
de lo contrario, donde:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si la operación usa tensores cuantificados, haz lo siguiente:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Si es
is_per_axis_quantized(rhs)
, entoncesquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Si
is_per_axis_quantized(result)
, entoncesquantization_dimension(result) = output_feature_dimension
. - Si
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
. - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Si
is_per_tensor_quantized(rhs)
, entoncesis_per_tensor_quantized(result)
. - Si
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Ejemplos
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semántica
Esta operación es funcionalmente idéntica a la operación gather, con el slice_sizes
especificado de forma dinámica como un valor.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tensor de tipo de número entero | (C2), (C3), (C13) |
(I3) | slice_sizes |
Tensor unidimensional de tipo de número entero | (C8), (C11-C13) |
(I4) | offset_dims |
Constante de tensor de 1 dimensión de tipo si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Constante de tensor de 1 dimensión de tipo si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Constante de tensor de 1 dimensión de tipo si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
constante de tipo si64 |
(C2), (C3), (C13) |
(I8) | indices_are_sorted |
constante de tipo i1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5), (C13-C14) |
Limitaciones
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
- (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
- (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
- (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
en el que:batch_dim_sizes = shape(start_indices)
, excepto que no se incluye el tamaño de la dimensión destart_indices
que corresponde aindex_vector_dim
.offset_dim_sizes = shape(slice_sizes)
, excepto que no se incluyen los tamaños de las dimensiones enslice_sizes
correspondientes acollapsed_slice_dims
.combine
colocabatch_dim_sizes
en los ejes correspondientes abatch_dims
yoffset_dim_sizes
en los ejes correspondientes aoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Ejemplos
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semántica
Esta operación es funcionalmente idéntica a la operación iota, pero la forma del resultado se especifica de forma dinámica a través de output_shape
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | output_shape |
Tensor de 1 dimensión de tipo número entero | (C1) y (C2) |
(I2) | iota_dimension |
si64 |
(C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C2) |
Limitaciones
- (C1)
0 <= iota_dimension < size(output_shape)
- (C2)
rank(result) = size(output_shape)
.
Ejemplos
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semántica
Esta operación es funcionalmente idéntica a la operación pad, pero con edge_padding_low
, edge_padding_high
y interior_padding
especificados de forma dinámica como valores.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor de 0 dimensiones o tensor cuantificado por tensor | (C1) |
(I3) | edge_padding_low |
Tensor unidimensional de tipo de número entero | (C1), (C4) |
(I4) | edge_padding_high |
Tensor de 1 dimensión de tipo número entero | (C1), (C4) |
(I5) | interior_padding |
Tensor unidimensional de tipo de número entero | (C2-C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3-C6) |
Limitaciones
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Ejemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semántica
Esta operación es funcionalmente idéntica a la operación reshape, pero la forma del resultado se especifica de forma dinámica a través de output_shape
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C3) |
(I2) | output_shape |
Tensor de 1 dimensión de tipo número entero | (C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C4) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
yquantization_dimension(result)
pueden diferir.
- (C2)
size(operand) = size(result)
. - (C3) Si
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semántica
Extrae una porción de operand
con índices de inicio calculados de forma dinámica y produce un tensor result
. start_indices
contiene los índices iniciales de la porción para cada dimensión sujeta a un posible ajuste, y slice_sizes
contiene los tamaños de la porción para cada dimensión. De forma más formal, result[result_index] = operand[operand_index]
, en la que:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
(I2) | start_indices |
Número variadico de tensores de 0 dimensiones de tipo entero | (C2), (C3) |
(I3) | slice_sizes |
Constante de tensor de 1 dimensión de tipo si64 |
(C2), (C4), (C5) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1), (C5) |
Limitaciones
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Ejemplos
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semántica
Produce un tensor result
que es igual al tensor operand
, excepto que la porción que comienza en start_indices
se actualiza con los valores en update
.
De forma más formal, result[result_index]
se define de la siguiente manera:
update[update_index]
si0 <= update_index < shape(update)
en el que:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
operand[result_index]
en caso contrario.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1-C4), (C6) |
(I2) | update |
tensor o tensor cuantificado por tensor | (C2), (C3), (C6) |
(I3) | start_indices |
Número variadico de tensores de 0 dimensiones de tipo entero | (C4), (C5) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Ejemplos
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponencial
Semántica
Realiza una operación exponencial a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
exp
de IEEE-754. - Para números complejos: exponencial compleja.
- Para tipos cuantificados:
dequantize_op_quantize(exponential, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponencial_menos_uno
Semántica
Realiza una operación exponencial menos uno a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
expm1
de IEEE-754. - Para números complejos: exponencial compleja menos uno.
- Para tipos cuantificados:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
fft
Semántica
Realiza las transformaciones de Fourier directas e inversas para entradas y salidas reales y complejas.
fft_type
es una de las siguientes opciones:
FFT
: Reenvía FFT complejo a complejo.IFFT
: FFT inversa de complejo a complejo.RFFT
: FFT real a complejo directo.IRFFT
: FFT inversa de real a complejo (es decir, toma complejo y muestra real).
De manera más formal, dada la función fft
que toma tensores de 1 dimensión de tipos complejos como entrada, produce tensores de 1 dimensión del mismo tipo como salida y calcula la transformada de Fourier discreta:
En fft_type = FFT
, result
se define como el resultado final de una serie de cálculos de L en los que L = size(fft_length)
. Por ejemplo, para L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Además, dada la función ifft
que tiene la misma firma de tipo y calcula el inverso de fft
:
En fft_type = IFFT
, result
se define como la inversa de los cálculos para fft_type = FFT
. Por ejemplo, para L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Además, dada la función rfft
, que toma tensores de 1 dimensión de tipos de números de punto flotante, produce tensores de 1 dimensión de tipos complejos de la misma semántica de números de punto flotante y funciona de la siguiente manera:
rfft(real_operand) = truncated_result
dondecomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Cuando se calcula la transformada de Fourier discreta para operandos reales, los primeros elementos N/2 + 1
del resultado definen de forma inequívoca el resto del resultado, por lo que el resultado de rfft
se trunca para evitar el cálculo de elementos redundantes).
En fft_type = RFFT
, result
se define como el resultado final de una serie de cálculos de L en los que L = size(fft_length)
. Por ejemplo, para L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Por último, dada la función irfft
, que tiene el mismo tipo de firma y calcula la inversa de rfft
:
En fft_type = IRFFT
, result
se define como la inversa de los cálculos para fft_type = RFFT
. Por ejemplo, para L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante | (C1), (C2), (C4), (C5) |
(I2) | fft_type |
enum de FFT , IFFT , RFFT y IRFFT |
(C2), (C5) |
(I3) | fft_length |
Constante de tensor de 1 dimensión de tipo si64 |
(C1), (C3), (C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo complejo o de punto flotante | (C2), (C4), (C5) |
Limitaciones
- (C1)
size(fft_length) <= rank(operand)
. - (C2) La relación entre los tipos de elementos
operand
yresult
varía:- Si
fft_type = FFT
,element_type(operand)
yelement_type(result)
tienen el mismo tipo complejo. - Si
fft_type = IFFT
,element_type(operand)
yelement_type(result)
tienen el mismo tipo complejo. - Si
fft_type = RFFT
,element_type(operand)
es un tipo de punto flotante yelement_type(result)
es un tipo complejo de la misma semántica de punto flotante. - Si
fft_type = IRFFT
,element_type(operand)
es un tipo complejo yelement_type(result)
es un tipo de punto flotante de la misma semántica de punto flotante.
- Si
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Si está entre
operand
yresult
, hay un tensorreal
de tipo de punto flotante, entonces,shape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
, excepto en los siguientes casos:- Si es
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Si es
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Si es
Ejemplos
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
piso
Semántica
Realiza el precio mínimo de elementos del tensor operand
y produce un tensor result
.
Implementa la operación roundToIntegralTowardNegative
de la especificación IEEE-754. Para los tipos cuantizados, realiza dequantize_op_quantize(floor, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
reunir
Semántica
Recopila rebanadas del tensor operand
de los desplazamientos especificados en start_indices
y produce un tensor result
.
En el siguiente diagrama, se muestra cómo los elementos de result
se asignan a elementos de operand
con un ejemplo concreto. El diagrama elige algunos índices result
de ejemplo y explica en detalle a qué índices operand
corresponden.
De forma más formal, result[result_index] = operand[operand_index]
, en la que:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
se define de la siguiente manera:start_indices[bi0, ..., :, ..., biN]
, en el quebi
son elementos individuales enbatch_index
y:
se inserta en el índiceindex_vector_dim
, siindex_vector_dim
<rank(start_indices)
.[start_indices[batch_index]]
en caso contrario.
- Para
d_operand
enaxes(operand)
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
sid_operand = start_index_map[d_start]
.full_start_index[d_operand] = 0
en caso contrario.
- Para
d_operand
enaxes(operand)
,full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sid_operand = operand_batching_dims[i_batching]
yd_start = start_indices_batching_dims[i_batching]
.- De lo contrario,
full_batching_index[d_operand] = 0
.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, en el queoi
son elementos individuales enoffset_index
y0
se inserta en los índices decollapsed_slice_dims
yoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
.
Si indices_are_sorted
es true
, la implementación puede suponer que start_indices
está ordenado en función de start_index_map
; de lo contrario, el comportamiento no está definido. De forma más formal, para todos los i1 < i2
de indices(result)
, full_start_index(i1) <= full_start_index(i2)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
tensor de tipo de número entero | (C2-C3), (C14), (C17) y (C22) |
(I3) | offset_dims |
Constante de tensor unidimensional de tipo si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Constante de tensor de 1 dimensión de tipo si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
Constante de tensor de 1 dimensión de tipo si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Constante de tensor unidimensional de tipo si64 |
(C13-C17) |
(I7) | start_index_map |
Constante de tensor de 1 dimensión de tipo si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
constante de tipo si64 |
(C2-C3), (C15), (C22) |
(I9) | slice_sizes |
Constante de tensor de 1 dimensión de tipo si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
constante de tipo i1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5), (C22-C23) |
Limitaciones
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
- (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
- (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
. - (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
- (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
en el que:batch_dim_sizes = shape(start_indices)
, excepto que no se incluye el tamaño de la dimensión destart_indices
que corresponde aindex_vector_dim
.offset_dim_sizes = slice_sizes
, excepto que no se incluyen los tamaños de las dimensiones enslice_sizes
correspondientes acollapsed_slice_dims
yoperand_batching_dims
.combine
colocabatch_dim_sizes
en los ejes correspondientes abatch_dims
yoffset_dim_sizes
en los ejes correspondientes aoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Ejemplos
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semántica
Produce el tamaño de la dimension
determinada de operand
. Más formalmente, result = dim(operand, dimension)
. La semántica solo se relaciona con el componente de forma del tipo. El elemento-tipo puede ser cualquier cosa.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1) |
(I2) | dimension |
constante de tipo si64 |
(C1) |
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de 0 dimensiones de tipo si32 |
Limitaciones
- (C1)
0 <= dimension < rank(operand)
.
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semántica
Extrae el elemento en la posición index
de la tupla operand
y produce un result
. Más formalmente, result = operand[index]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tuple | (C1), (C2) |
(I2) | index |
constante de tipo si32 |
(C1), (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
cualquier tipo compatible | (C2) |
Limitaciones
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Ejemplos
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
si
Semántica
Produce el resultado de ejecutar exactamente una función de true_branch
o false_branch
según el valor de pred
. De forma más formal, result =
pred ? true_branch() : false_branch()
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | pred |
Tensor de 0 dimensiones de tipo i1 |
|
(I2) | true_branch |
función | (C1-C3) |
(I3) | false_branch |
función | (C1), (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores, tensores cuantificados o tokens | (C3) |
Limitaciones
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
- (C3)
type(results...) = output_types(true_branch)
.
Ejemplos
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imag
Semántica
Extrae la parte imaginaria, por elemento, de operand
y produce un tensor result
. De manera más formal, para cada elemento x
: imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante | (C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1), (C2) |
Limitaciones
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
si esis_complex(operand)
.element_type(operand)
en caso contrario.
Ejemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
In-feed
Semántica
Lee datos del feed in-feed y produce results
.
La semántica de infeed_config
se define según la implementación.
results
consta de valores de carga útil que aparecen primero y un token que aparece
al final. En el futuro, planeamos dividir la carga útil y el token en dos resultados separados para mejorar la claridad (#670).
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
constante de tipo string |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores, tensores cuantificados o tokens | (C1-C3) |
Limitaciones
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Ejemplos
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semántica
Llena un tensor output
con valores en orden creciente a partir de cero a lo largo de la dimensión iota_dimension
. Más formalmente,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | iota_dimension |
si64 |
(C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
0 <= iota_dimension < rank(output)
Ejemplos
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semántica
Realiza una verificación por elemento para determinar si el valor en x
es finito (es decir, no es +Inf, -Inf ni NaN) y produce un tensor y
. Implementa la operación isFinite
de la especificación IEEE-754. Para los tipos cuantificados, el resultado siempre es true
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | x |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
y |
tensor de tipo booleano | (C1) |
Limitaciones
- (C1)
shape(x) = shape(y)
.
Ejemplos
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semántica
Realiza la operación de logaritmo a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
log
de IEEE-754. - Para números complejos: logaritmo complejo.
- Para tipos cuantificados:
dequantize_op_quantize(log, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semántica
Realiza una operación de logaritmo a nivel de elementos más uno en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
logp1
de IEEE-754. - Para números complejos: logaritmo complejo más uno.
- Para tipos cuantificados:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logística
Semántica
Realiza una operación logística en términos de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
division(1, addition(1, exp(-x)))
de IEEE-754. - Para números complejos: logística compleja.
- Para tipos cuantizados:
dequantize_op_quantize(logistic, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mapa
Semántica
Aplica una función de asignación computation
a inputs
junto con dimensions
y produce un tensor result
.
De forma más formal, result[result_index] = computation(inputs...[result_index])
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C1-C4) |
(I2) | dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C3) |
(I3) | computation |
función | (C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1), (C4) |
Limitaciones
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
tiene el tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, en el queEi = element_type(inputs[i])
yE' = element_type(result)
.
Ejemplos
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
máxima
Semántica
Realiza la operación de máximo a nivel de elementos en los tensores lhs
y rhs
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para valores booleanos: OR lógico.
- Para números enteros: número entero máximo.
- Para números de punto flotante:
maximum
de IEEE-754. - Para números complejos: Máximo lexicográfico para el par
(real, imaginary)
. Imponer un orden en los números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560). - Para tipos cuantificados:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1) |
(I2) | rhs |
tensor o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
mínima
Semántica
Realiza la operación min a nivel de elementos en los tensores lhs
y rhs
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para valores booleanos: AND lógico.
- Para números enteros: Número entero mínimo.
- Para números de punto flotante:
minimum
de IEEE-754. - Para números complejos: Mínimo lexicográfico para el par
(real, imaginary)
. Imponer un orden en los números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560). - Para tipos cuantificados:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1) |
(I2) | rhs |
tensor o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
multiplicar
Semántica
Realiza el producto a nivel de elementos de dos tensores lhs
y rhs
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: lógico AND.
- Para números enteros: multiplicación de números enteros
- Para números de punto flotante:
multiplication
de IEEE-754. - Para números complejos: multiplicación compleja.
- Para tipos cuantificados:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1) |
(I2) | rhs |
tensor o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negativo
Semántica
Realiza la negación por elemento del tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros firmados: negación de números enteros
- Para números enteros sin firma: conversión de bits a número entero firmado, negación de números enteros, conversión de bits a número entero sin firma.
- Para números de punto flotante:
negate
de IEEE-754. - Para números complejos: Negación compleja.
- Para tipos cuantizados:
dequantize_op_quantize(negate, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
no
Semántica
Realiza la operación NOT a nivel de elementos del tensor operand
y produce un tensor result
.
Según el tipo de elemento, realiza las siguientes acciones:
- Para valores booleanos: NOT lógico.
- Para números enteros: NOT a nivel de bits.
Argumentos
Nombre | Tipo | Limitaciones |
---|---|---|
operand |
tensor de tipo booleano o número entero | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo booleano o número entero | (C1) |
Limitaciones
- (C1)
type(operand) = type(result)
.
Ejemplos
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semántica
Garantiza que las operaciones que producen operand
se ejecuten antes que cualquier operación que dependa de result
y evita que las transformaciones del compilador muevan operaciones a través de la barrera. Aparte de eso, la operación es una identidad, es decir, result = operand
.
Argumentos
Nombre | Tipo | Limitaciones |
---|---|---|
operand |
Cantidad variacional de tensores, tensores o tokens cuantificados por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
Cantidad variacional de tensores, tensores o tokens cuantificados por tensor | (C1) |
Limitaciones
- (C1)
type(operand...) = type(result...)
.
Ejemplos
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
o
Semántica
Realiza la operación OR a nivel de elementos de dos tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para valores booleanos: OR lógico.
- Para números enteros: OR a nivel de bits.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo número entero o booleano | (C1) |
(I2) | rhs |
tensor de tipo número entero o booleano | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de número entero o booleano | C1 |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
salida
Semántica
Escribe inputs
en el feed de salida y produce un token result
.
La semántica de outfeed_config
se define según la implementación.
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | inputs |
Cantidad variacional de tensores o tensores cuantificados |
(I2) | token |
token |
(I3) | outfeed_config |
constante de tipo string |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Ejemplos
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
almohadilla
Semántica
Expande operand
con el relleno alrededor del tensor y entre los elementos del tensor con el padding_value
determinado.
edge_padding_low
y edge_padding_high
especifican la cantidad de padding que se agrega en el extremo bajo (junto al índice 0) y en el extremo superior (junto al índice más alto) de cada dimensión, respectivamente. La cantidad de relleno puede ser negativa, en cuyo caso el valor absoluto del relleno negativo indica la cantidad de elementos que se quitarán de la dimensión especificada.
interior_padding
especifica la cantidad de padding que se agrega entre cualquier par de elementos en cada dimensión, que no puede ser negativa. El padding interior ocurre antes del padding de borde, de modo que el padding de borde negativo quitará elementos del operando con padding interior.
De forma más formal, result[result_index]
se define de la siguiente manera:
operand[operand_index]
si esresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.- De lo contrario,
padding_value
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor de 0 dimensiones o tensor cuantificado por tensor | (C1) |
(I3) | edge_padding_low |
Constante de tensor de 1 dimensión de tipo si64 |
(C1), (C4) |
(I4) | edge_padding_high |
Constante de tensor de 1 dimensión de tipo si64 |
(C1), (C4) |
(I5) | interior_padding |
Constante de tensor de 1 dimensión de tipo si64 |
(C2-C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3-C6) |
Limitaciones
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Ejemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semántica
Produce partition_id
del proceso actual.
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de 0 dimensiones de tipo ui32 |
Ejemplos
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
Semántica
Realiza un recuento por elemento de la cantidad de bits establecidos en el tensor operand
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo número entero | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo número entero | (C1) |
Limitaciones
- (C1)
type(operand) = type(result)
.
Ejemplos
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
energía
Semántica
Realiza la exponenciación a nivel de elementos del tensor lhs
por el tensor rhs
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: exponente de números enteros.
- Para números de punto flotante:
pow
de IEEE-754. - Para números complejos: exponenciación compleja.
- Para tipos cuantificados:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2) | rhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semántica
Extrae la parte real, por elemento, de operand
y produce un tensor result
. De forma más formal, para cada elemento x
: real(x) = is_complex(x) ? real_part(x) : x
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante | (C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1), (C2) |
Limitaciones
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
si esis_complex(operand)
.element_type(operand)
en caso contrario.
Ejemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semántica
Recibe datos de un canal con channel_id
y produce results
.
Si is_host_transfer
es true
, la operación transfiere datos del
host. De lo contrario, transfiere datos desde otro dispositivo. Esto significa que se define en la implementación. Esta marca duplica la información proporcionada en channel_type
, por lo que, en el futuro, planeamos conservar solo una de ellas (#666).
results
consta de valores de carga útil que aparecen primero y un token que aparece
al final. En el futuro, planeamos dividir la carga útil y el token en dos resultados separados para mejorar la claridad (#670).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
constante de tipo si64 |
|
(I3) | channel_type |
enum de DEVICE_TO_DEVICE y HOST_TO_DEVICE |
(C1) |
(I4) | is_host_transfer |
constante de tipo i1 |
(C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores, tensores cuantificados o tokens | (C2-C4) |
Limitaciones
- (C1)
channel_type
se define de la siguiente manera:HOST_TO_DEVICE
siis_host_transfer = true
,- De lo contrario,
DEVICE_TO_DEVICE
.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
Ejemplos
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
Reducir
Semántica
Aplica una función de reducción body
a inputs
y init_values
a lo largo de dimensions
y produce tensores results
.
El orden de las reducciones se define en la implementación, lo que significa que body
y
init_values
deben formar un monoide para garantizar que la operación produzca los
mismos resultados para todas las entradas en todas las implementaciones. Sin embargo, esta condición
no se cumple para muchas reducciones populares. Por ejemplo, la suma de punto flotante para body
y cero para init_values
no forman un monoide porque la suma de punto flotante no es asociativa.
De forma más formal, results...[j0, ..., jR-1] = reduce(input_slices_converted)
, en la que:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, en el que se insertan:
endimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
para algún árbol binarioschedule
en el que:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
es un árbol binario completo definido por la implementación cuyo recorrido en orden consiste en lo siguiente:- Valores
input_slices_converted...[index]
para todos losindex
enindex_space(input_slices_converted)
en el orden lexicográfico ascendente deindex
. - Se intercala con una cantidad de
init_values_converted
definida por la implementación en posiciones definidas por la implementación.
- Valores
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C4), (C6) y (C7) |
(I2) | init_values |
número variádico de tensores de 0 dimensiones o tensores cuantificados por tensor | (C2), (C3) |
(I3) | dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C4), (C5), (C7) |
(I4) | body |
función | (C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C3), (C7), (C8) |
Limitaciones
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
dondeis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, excepto que no se incluyen los tamaños de dimensión deinputs...
correspondientes adimensions
. - (C8)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semántica
Realiza la conversión de operand
por elemento a otro tipo de número de punto flotante que usa exponent_bits
y mantissa_bits
y vuelve al tipo de número de punto flotante original y produce un tensor output
.
De forma más formal:
- Los bits de mantisa del valor original se actualizan para redondear el valor original al valor más cercano que se puede representar con
mantissa_bits
usando la semántica deroundToIntegralTiesToEven
. - Luego, si
mantissa_bits
es menor que la cantidad de bits de mantisa del valor original, los bits de mantisa se truncan amantissa_bits
. - Luego, si los bits exponentes del resultado intermedio no se ajustan al rango proporcionado por
exponent_bits
, el resultado intermedio se desborda hasta el infinito con el signo original o se subdesborda a cero con el signo original. - Para los tipos cuantificados, realiza
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
(I2) | exponent_bits |
constante de tipo si32 |
(C2) |
(I3) | mantissa_bits |
constante de tipo si32 |
(C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
Ejemplos
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, realiza la reducción, con computations
, sobre los valores del tensor operand
de cada proceso, divide el resultado de la reducción en partes a lo largo de scatter_dimension
y dispersa las partes divididas entre los procesos para producir result
.
La operación divide la cuadrícula del proceso de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
sichannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
sichannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
sichannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
para todos lossender
enprocess_group
, dondereceiver_index = process_group.index(receiver)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C7), (C8) |
(I2) | scatter_dimension |
constante de tipo si64 |
(C1), (C2), (C8) |
(I3) | replica_groups |
Constante de tensor de 2 dimensiones de tipo si64 |
(C3-C5) |
(I4) | channel_id |
constante de tipo si64 |
(C6) |
(I5) | use_global_device_ids |
constante de tipo i1 |
(C6) |
(I6) | computation |
función | (C7) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C8-C9) |
Limitaciones
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
- (C4)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Si
use_global_device_ids = true
, entonceschannel_id > 0
. - (C7)
computation
tiene el tipo(tensor<E>, tensor<E>) -> (tensor<E>)
en el queis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
, excepto:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semántica
Aplica una función de reducción body
a las ventanas de inputs
y init_values
, y produce results
.
En el siguiente diagrama, se muestra cómo se calculan los elementos de results...
a partir de inputs...
con un ejemplo concreto.
De forma más formal, results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(consulta reduce), en la que se cumple lo siguiente:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13) y (C15) |
(I2) | init_values |
Cantidad variacional de tensores de 0 dimensiones o tensores cuantificados por tensor | (C1), (C13) |
(I3) | window_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Constante de tensor de 1 dimensión de tipo si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Constante de tensor de 1 dimensión de tipo si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Constante de tensor unidimensional de tipo si64 |
(C10), (C11), (C15) |
(I7) | padding |
Constante de tensor de 2 dimensiones de tipo si64 |
(C12), (C15) |
(I8) | body |
función | (C13) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C1), (C14-C16) |
Limitaciones
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
dondeis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
en el que:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
resto
Semántica
Realiza el resto de los tensores lhs
y rhs
del dividendo y el divisor por elemento y produce un tensor result
.
De manera más formal, el signo del resultado se toma del dividendo, y el valor absoluto del resultado siempre es menor que el valor absoluto del divisor.
El resto se calcula como lhs - d * rhs
, donde d
se obtiene de la siguiente manera:
- Para números enteros:
stablehlo.divide(lhs, rhs)
. - Para números de punto flotante:
division(lhs, rhs)
de IEEE-754 con el atributo de redondeoroundTowardZero
. - Para números complejos: por definir (#997).
- Para tipos cuantizados:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
En el caso de los tipos de elementos de punto flotante, esta operación contrasta con la operación remainder
de la especificación IEEE-754, en la que d
es un valor integral más cercano al valor exacto de lhs/rhs
con empates para números pares.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2) | rhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semántica
Produce replica_id
del proceso actual.
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de 0 dimensiones de tipo ui32 |
Ejemplos
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
reshape
Semántica
Cambia la forma del tensor operand
a un tensor result
. Conceptualmente, equivale a mantener la misma representación canónica, pero posiblemente a cambiar la forma, p.ej., de tensor<2x3xf32>
a tensor<3x2xf32>
o tensor<6xf32>
.
De forma más formal, result[result_index] = operand[operand_index]
, en el que result_index
y operand_index
tienen la misma posición en el orden lexicográfico de index_space(result)
y index_space(operand)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C3) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
yquantization_dimension(result)
pueden diferir.
- (C2)
size(operand) = size(result)
. - (C3) Si
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
revertir
Semántica
Revierte el orden de los elementos en operand
a lo largo del dimensions
especificado y produce un tensor result
. De forma más formal, result[result_index] = operand[operand_index]
, en la que:
operand_index[d] = dim(result, d) - result_index[d] - 1
sid
endimensions
.- De lo contrario,
operand_index[d] = result_index[d]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1) y (C3) |
(I2) | dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C2) y (C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1), (C3) |
Limitaciones
- (C1)
type(operand) = type(result)
- (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Ejemplos
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semántica
Genera números aleatorios con el algoritmo rng_distribution
y produce un tensor result
de una forma shape
determinada.
Si es rng_distribution = UNIFORM
, los números aleatorios se generan según la distribución uniforme en el intervalo [a, b)
. Si es a >= b
, el comportamiento no está definido.
Si es rng_distribution = NORMAL
, los números aleatorios se generan siguiendo la distribución normal con una media = a
y una desviación estándar = b
.
Si es b < 0
, el comportamiento no está definido.
La forma exacta en que se generan los números aleatorios se define en la implementación. Por ejemplo, pueden ser deterministas o no, y pueden usar o no un estado oculto.
En conversaciones con muchas partes interesadas, esta operación se consideró obsoleta, por lo que, en el futuro, planeamos quitarla (#597).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | a |
Tensor de 0 dimensiones de tipo entero, booleano o de punto flotante | (C1), (C2) |
(I2) | b |
Tensor de 0 dimensiones de tipo entero, booleano o de punto flotante | (C1), (C2) |
(I3) | shape |
Constante de tensor de 1 dimensión de tipo si64 |
(C3) |
(I4) | rng_distribution |
enum de UNIFORM y NORMAL |
(C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo número entero, booleano o de punto flotante | (C1-C3) |
Limitaciones
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Si
rng_distribution = NORMAL
, entoncesis_float(a)
. - (C3)
shape(result) = shape
.
Ejemplos
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semántica
Muestra un output
lleno de bits aleatorios uniformes y un estado de salida actualizado output_state
con el algoritmo de generador de números pseudoaleatorios rng_algorithm
dado un initial_state
de estado inicial. Se garantiza que el resultado sea una función determinística de initial_state
, pero no se garantiza que sea determinística entre implementaciones.
rng_algorithm
es una de las siguientes opciones:
DEFAULT
: Algoritmo definido por la implementación.THREE_FRY
: Variante definida por la implementación del algoritmo de Threefry.*PHILOX
: Es la variante definida por la implementación del algoritmo Philox.*
* Consulta: Salmon et al. SC 2011. Números aleatorios paralelos: es tan fácil como 1, 2, 3.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | rng_algorithm |
enum de DEFAULT , THREE_FRY y PHILOX |
(C2) |
(I2) | initial_state |
Tensor de 1 dimensión de tipo ui64 |
(C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output_state |
Tensor de 1 dimensión de tipo ui64 |
(C1) |
output |
tensor de tipo número entero o de punto flotante |
Limitaciones
- (C1)
type(initial_state) = type(output_state)
- (C2)
size(initial_state)
se define de la siguiente manera:- Se define en la implementación si es
rng_algorithm = DEFAULT
. 2
sirng_algorithm = THREE_FRY
.2
o3
sirng_algorithm = PHILOX
- Se define en la implementación si es
Ejemplos
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semántica
Realiza un redondeo por elemento hacia el número entero más cercano, rompiendo los empates lejos de cero, en el tensor operand
y produce un tensor result
. Implementa la operación roundToIntegralTiesToAway
de la especificación IEEE-754. Para los tipos cuantificados, realiza dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semántica
Realiza un redondeo por elemento hacia el número entero más cercano, rompiendo los empates hacia el número entero par, en el tensor operand
y produce un tensor result
. Implementa la operación roundToIntegralTiesToEven
de la especificación IEEE-754. Para los tipos cuantificados, realiza dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semántica
Realiza una operación de raíz cuadrada recíproca a nivel de los elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
rSqrt
de IEEE-754. - Para números complejos: raíz cuadrada recíproca compleja.
- Para tipos cuantificados:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
dispersión
Semántica
Produce tensores results
que son iguales a los tensores inputs
, excepto que varias porciones especificadas por scatter_indices
se actualizan con los valores updates
mediante update_computation
.
En el siguiente diagrama, se muestra cómo los elementos de updates...
se asignan a elementos de results...
con un ejemplo concreto. El diagrama elige algunos índices updates...
de ejemplo y explica en detalle a qué índices results...
corresponden.
De forma más formal, para todos los update_index
en index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
se define de la siguiente manera:scatter_indices[si0, ..., :, ..., siN]
, en el quesi
son elementos individuales enupdate_scatter_index
y:
se inserta en el índiceindex_vector_dim
, siindex_vector_dim
<rank(scatter_indices)
.[scatter_indices[update_scatter_index]]
en caso contrario.
- Para
d_input
enaxes(inputs[0])
,full_start_index[d_input] = start_index[d_start]
si esd_input = scatter_dims_to_operand_dims[d_start]
.full_start_index[d_input] = 0
en caso contrario.
- Para
d_input
enaxes(inputs[0])
,full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sid_input = input_batching_dims[i_batching]
yd_start = scatter_indices_batching_dims[i_batching]
.full_batching_index[d_input] = 0
en caso contrario.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, en el quewi
son elementos individuales enupdate_window_index
y0
se inserta en los índices deinserted_window_dims
yinput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Dado eso, results = exec(schedule, inputs)
, donde:
schedule
es una permutación definida por la implementación deindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
en el que:- Si
result_index
está dentro de los límites deshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
es una copia deresults
conresults...[result_index]
configurado enupdated_values...
.- De lo contrario
updated_results = results
.
- Si
exec([], results) = results
.
Si indices_are_sorted
es true
, la implementación puede suponer que scatter_indices
están ordenados con respecto a scatter_dims_to_operand_dims
. De lo contrario, el comportamiento será indefinido. De forma más formal, para todos los i1 < i2
de indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Si unique_indices
es true
, la implementación puede suponer que todos los índices result_index
a los que se dispersan son únicos. Si unique_indices
es true
, pero los índices a los que se dispersan no son únicos, el comportamiento no está definido.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
tensor de tipo número entero | (C4), (C15), (C19), (C22) |
(I3) | updates |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C3-C6), (C8) |
(I4) | update_window_dims |
Constante de tensor de 1 dimensión de tipo si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Constante de tensor unidimensional de tipo si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Constante de tensor unidimensional de tipo si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Constante de tensor de 1 dimensión de tipo si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
Constante de tensor de 1 dimensión de tipo si64 |
(C19-C21) |
(I9) | index_vector_dim |
constante de tipo si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
constante de tipo i1 |
|
(I11) | unique_indices |
constante de tipo i1 |
|
(I12) | update_computation |
función | (C23) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C24-C25) |
Limitaciones
- (C1)
same(shape(inputs...))
- (C2) "rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
en el que:update_scatter_dim_sizes = shape(scatter_indices)
, excepto que no se incluye el tamaño de la dimensión descatter_indices
que corresponde aindex_vector_dim
.update_window_dim_sizes <= shape(inputs[0])
, excepto que no se incluyen los tamaños de las dimensiones eninputs[0]
correspondientes ainserted_window_dims
yinput_batching_dims
.combine
colocaupdate_scatter_dim_sizes
en los ejes correspondientes aupdate_scatter_dims
yupdate_window_dim_sizes
en los ejes correspondientes aupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
- (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims))
- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
. - (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, en el queis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
seleccionar
Semántica
Produce un tensor result
en el que cada elemento se selecciona del tensor on_true
o on_false
según el valor del elemento correspondiente de pred
.
De forma más formal, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, donde pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Para los tipos cuantificados, realiza dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | pred |
tensor de tipo i1 |
(C1) |
(I2) | on_true |
tensor o tensor cuantificado por tensor | (C1-C2) |
(I3) | on_false |
tensor o tensor cuantificado por tensor | (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C2) |
Limitaciones
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Ejemplos
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semántica
Dispersa los valores del tensor source
con scatter
según el resultado de reduce_window
del tensor input
con select
y produce un tensor result
.
En el siguiente diagrama, se muestra cómo se calculan los elementos de result
a partir de operand
y source
con un ejemplo concreto.
De forma más formal:
selected_values = reduce_window_without_init(...)
con las siguientes entradas:inputs = [operand].
window_dimensions
,window_strides
ypadding
, que se usan tal como estánbase_dilations = windows_dilations = 1
.body
se define de la siguiente manera:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
en el que
E = element_type(operand)
yreduce_window_without_init
funcionan exactamente comoreduce_window
, excepto que elschedule
de lareduce
subyacente (consulta reduce) no incluye valores de inicialización. Actualmente, no se especifica lo que sucede si la ventana correspondiente no tiene valores (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
donde:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
siselected_values[source_index]
tiene el elementooperand
deoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1-C4), (C6), (C8-C11) |
(I2) | source |
tensor o tensor cuantificado por tensor | (C1), (C2) |
(I3) | init_value |
Tensor de 0 dimensiones o tensor cuantificado por tensor | (C3) |
(I4) | window_dimensions |
Constante de tensor de 1 dimensión de tipo si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Constante de tensor de 1 dimensión de tipo si64 |
(C2), (C6), (C7) |
(I6) | padding |
Constante de tensor de 2 dimensiones de tipo si64 |
(C2), (C8) |
(I7) | select |
función | (C9) |
(I8) | scatter |
función | (C10) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C11-C12) |
Limitaciones
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
donde:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
tiene el tipo(tensor<E>, tensor<E>) -> tensor<i1>
en el queE = element_type(operand)
. - (C10)
scatter
tiene el tipo(tensor<E>, tensor<E>) -> tensor<E>
, en el queis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Ejemplos
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
enviar
Semántica
Envía inputs
a un canal channel_id
y produce un token result
.
Si is_host_transfer
es true
, la operación transfiere datos al host. De lo contrario, transfiere datos a otro dispositivo. Esto significa que se define en la implementación. Esta marca duplica la información proporcionada en channel_type
, por lo que, en el futuro, planeamos conservar solo una de ellas (#666).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
Cantidad variacional de tensores o tensores cuantificados | |
(I2) | token |
token |
|
(I3) | channel_id |
constante de tipo si64 |
|
(I4) | channel_type |
enum de DEVICE_TO_DEVICE y DEVICE_TO_HOST |
(C1) |
(I5) | is_host_transfer |
constante de tipo i1 |
(C1) |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Limitaciones
- (C1)
channel_type
se define de la siguiente manera:DEVICE_TO_HOST
siis_host_transfer = true
,- De lo contrario,
DEVICE_TO_DEVICE
.
Ejemplos
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semántica
Realiza la operación de desplazamiento a la izquierda en los elementos del tensor lhs
según el número de bits rhs
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo número entero | (C1) |
(I2) | rhs |
tensor de tipo número entero | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo número entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semántica
Realiza una operación de desplazamiento a la derecha aritmética a nivel de elementos en el tensor lhs
por rhs
cantidad de bits y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo número entero | (C1) |
(I2) | rhs |
tensor de tipo número entero | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo número entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semántica
Realiza una operación lógica de desplazamiento a la derecha en relación con los elementos en el tensor lhs
con un número de bits rhs
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo número entero | (C1) |
(I2) | rhs |
tensor de tipo número entero | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo número entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
firmar
Semántica
Muestra el signo del operand
a nivel de elementos y produce un tensor result
.
De forma más formal, para cada elemento x
, la semántica se puede expresar con la sintaxis de Python de la siguiente manera:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Para los tipos cuantizados, realiza dequantize_op_quantize(sign, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de número entero firmado, punto flotante o tipo complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero con firma, de punto flotante o de tipo complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semántica
Realiza una operación de seno a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
sin
de IEEE-754. - Para números complejos: seno complejo.
- Para tipos cuantificados:
dequantize_op_quantize(sine, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
porción
Semántica
Extrae una porción de operand
con índices de inicio calculados de forma estática y produce un tensor result
. start_indices
contiene los índices iniciales de la porción de cada dimensión, limit_indices
contiene los índices finales (exclusivos) de la porción de cada dimensión y strides
contiene los segmentos de cada dimensión.
Más formalmente, result[result_index] = operand[operand_index]
donde operand_index = start_indices + result_index * strides
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1-C3), (C5) |
(I2) | start_indices |
Constante de tensor de 1 dimensión de tipo si64 |
(C2), (C3), (C5) |
(I3) | limit_indices |
Constante de tensor de 1 dimensión de tipo si64 |
(C2), (C3), (C5) |
(I4) | strides |
Constante de tensor de 1 dimensión de tipo si64 |
(C2) y (C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1), (C5) |
Limitaciones
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Ejemplos
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
ordenar
Semántica
Ordena las secciones 1D de inputs
a lo largo de la dimensión dimension
, según una comparator
y produce results
.
A diferencia de las entradas similares en otras operaciones, dimension
permite valores negativos, con la semántica que se describe a continuación. En el futuro, es posible que esto no se permita por motivos de coherencia (#1377).
Si is_stable
es verdadero, el orden es estable, es decir, se conserva el orden relativo de los elementos que el comparador considera iguales. En el caso
de que haya una sola entrada, el comparador considera que dos elementos e1
y e2
son
iguales solo si
comparator(e1, e2) = comparator(e2, e1) = false
. Consulta la formalización a continuación para ver cómo se generaliza a varias entradas.
De forma más formal, para todos los result_index
en index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, en el queriN
son elementos individuales enresult_index
y:
se inserta enadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- donde
sort
ordena una porción de 1 dimensión en orden no descendente y espera quecomparator_together
muestretrue
si el argumento de la izquierda es menor que el segundo argumento de la derecha. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C5) |
(I2) | dimension |
constante de tipo si64 |
(C4) |
(I3) | is_stable |
constante de tipo i1 |
|
(I4) | comparator |
función | C5 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores o tensores cuantificados por tensor | (C2), (C3) |
Limitaciones
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, dondeR = rank(inputs[0])
. - (C5)
comparator
tiene el tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, dondeEi = element_type(inputs[i])
.
Ejemplos
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semántica
Realiza la operación raíz cuadrada a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
squareRoot
de IEEE-754. - Para números complejos: raíz cuadrada compleja
- Para tipos cuantificados:
dequantize_op_quantize(sqrt, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semántica
Realiza la resta a nivel de elementos de dos tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números enteros: resta de números enteros.
- Para números de punto flotante:
subtraction
de IEEE-754. - Para números complejos: resta compleja.
- Para tipos cuantizados:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2) | rhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semántica
Realiza una operación tangente a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para números de punto flotante:
tan
de IEEE-754. - Para números complejos: tangente compleja.
- Para tipos cuantizados:
dequantize_op_quantize(tan, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semántica
Realiza la operación de tangente hiperbólica a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
tanh
de IEEE-754. - Para números complejos: tangente hiperbólica compleja
- Para tipos cuantizados:
dequantize_op_quantize(tanh, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transponer
Semántica
Permuta las dimensiones del tensor operand
con permutation
y produce un tensor result
. De forma más formal, result[result_index] = operand[operand_index]
, en el que result_index[d] = operand_index[permutation[d]]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C4) |
(I2) | permutation |
Constante de tensor de 1 dimensión de tipo si64 |
(C2-C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3-C4) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
yquantization_dimension(result)
pueden diferir.
- (C2)
permutation
es una permutación derange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Si
is_per_axis_quantized(result)
, entoncesquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Ejemplos
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semántica
Resuelve lotes de sistemas de ecuaciones lineales con matrices de coeficientes triangulares inferior o superior.
De manera más formal, dados a
y b
, result[i0, ..., iR-3, :, :]
es la solución a op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
cuando left_side
es true
o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
cuando left_side
es false
, lo que resuelve la variable x
en la que op(a)
se determina mediante transpose_a
, que puede ser una de las siguientes opciones:
NO_TRANSPOSE
: Realiza la operación cona
tal como está.TRANSPOSE
: Realiza la operación en la transposición dea
.ADJOINT
: Realiza la operación en la transposición conjugada dea
.
Los datos de entrada se leen solo del triángulo inferior de a
, si lower
es true
o del triángulo superior de a
, de lo contrario. Los datos de salida se muestran en el mismo triángulo. Los valores del otro triángulo se definen según la implementación.
Si unit_diagonal
es verdadero, la implementación puede suponer que los elementos diagonales de a
son iguales a 1; de lo contrario, el comportamiento no está definido.
Para los tipos cuantizados, realiza dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | a |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1-C3) |
(I2) | b |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1-C4) |
(I3) | left_side |
constante de tipo i1 |
(C3) |
(I4) | lower |
constante de tipo i1 |
|
(I5) | unit_diagonal |
constante de tipo i1 |
|
(I6) | transpose_a |
enum de NO_TRANSPOSE , TRANSPOSE y ADJOINT |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) La relación entre
shape(a)
yshape(b)
se define de la siguiente manera:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Ejemplos
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semántica
Produce una tupla result
a partir de los valores val
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | val |
Cantidad de valores variadic | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tuple | C1 |
Limitaciones
- (C1)
result
tiene el tipotuple<E0, ..., EN-1>
en el queEi = type(val[i])
.
Ejemplos
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semántica
Realiza la conversión por elemento del tensor cuantizado operand
a un tensor de punto flotante result
según los parámetros de cuantización definidos por el tipo operand
.
De forma más formal, result = dequantize(operand)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor cuantificado | (C1), (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1), (C2) |
Limitaciones
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Ejemplos
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semántica
Realiza la conversión de elementos del tensor de punto flotante o del tensor cuantizado operand
a un tensor cuantizado result
según los parámetros de cuantificación definidos por el tipo result
.
Más formalmente,
- Si
is_float(operand)
:result = quantize(operand, type(result))
.
- Si
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo punto flotante o cuantificado | (C1), (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor cuantificado | (C1), (C2) |
Limitaciones
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Ejemplos
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
mientras
Semántica
Produce el resultado de ejecutar la función body
0 o más veces mientras la función cond
genera true
. De manera más formal, la semántica se puede expresar con la sintaxis de Python de la siguiente manera:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
El comportamiento de un bucle infinito está por definir (#383).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
Cantidad variacional de tensores, tensores cuantificados o tokens | (C1-C3) |
(I2) | cond |
función | C1 |
(I3) | body |
función | (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
Cantidad variacional de tensores, tensores cuantificados o tokens | (C3) |
Limitaciones
- (C1)
cond
tiene el tipo(T0, ..., TN-1) -> tensor<i1>
, en el queTi = type(operand[i])
. - (C2)
body
tiene el tipo(T0, ..., TN-1) -> (T0, ..., TN-1)
, en el queTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
Ejemplos
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semántica
Realiza XOR a nivel de elementos de dos tensores, lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, realiza las siguientes acciones:
- Para valores booleanos: XOR lógico.
- Para números enteros: XOR a nivel de bits.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo booleano o de número entero | (C1) |
(I2) | rhs |
tensor de tipo booleano o número entero | (C1) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo booleano o número entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interoperabilidad de dialectos
Por el momento, los programas de StableHLO en el entorno a veces contienen operaciones que no está definida por StableHLO.
Módulo, función, llamada y devolución
StableHLO usa operaciones de MLIR upstream para ModuleOp, FuncOp, CallOp y ReturnOp. Esto se hizo para lograr una mejor interoperabilidad con el mecanismo de MLIR existente, ya que muchos pases útiles se escriben para FuncOp y ModuleOp, y muchas canalización de compilación esperan que estas operaciones estén presentes. Se aplican garantías de compatibilidad completas a estas operaciones. Si se produce algún cambio en estas operaciones de manera incompatible (es decir, eliminación), se agregarán equivalentes de StableHLO para preservar la compatibilidad.
CHLO
El conjunto de operaciones de CHLO contiene operaciones de nivel superior que se descomponen en StableHLO. Actualmente, no hay garantías de compatibilidad para CHLO. Para garantizar la compatibilidad, se debe usar el pase chlo-legalize-to-stablehlo antes de la serialización.
Operaciones de formas
Es un caso de uso común en la comunidad usar ciertas operaciones de dialectos principales del MLIR en programas dinámicos StableHLO para realizar cálculos de forma.
Por lo general, estas incluyen operaciones de dialecto shape
, como shape_of
o num_elements
, operaciones de dialecto tensor
, como dim
o from_elements
, y el tipo index
integrado.
La RFC de dinamismo > O2 indica que están fuera del alcance, sin embargo, se incluye cierta compatibilidad con los tipos index
para la interoperabilidad. No hay garantías de compatibilidad para estas operaciones o tipos. El pase shape-legalize-to-stablehlo se puede usar para convertir estas operaciones en operaciones de StableHLO totalmente compatibles.
Operaciones obsoletas
Hay varias operaciones de StableHLO que se heredaron de MHLO, que están obsoletas y están saliendo de StableHLO. Puedes encontrar todos los detalles sobre estas eliminaciones en Limpieza de StableHLO v1.0 #2283. El problema del seguimiento de estas bajas es #2340.
Estas operaciones se dividen en algunas categorías:
- Categoría "No en HLO" de las operaciones de StableHLO: Inicialmente, formaban parte del conjunto de operaciones de StableHLO, pero más tarde se consideró que no se ajustaban bien:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(#3). - Operaciones sin usar: Es posible que estas operaciones hayan sido útiles en algún momento, pero no se desarrollaron por completo o las canalizaciones que las usan se refactorizaron para que ya no las requieran. Esto incluye
map
,tuple
(#598),get_tuple_element
,rng
, comparacionescomplex
#560 y convoluciónwindow_reversal
(#1181).
Algunas de estas operaciones se pueden quitar con facilidad, ya que se pueden expresar mediante las operaciones existentes (broadcast
, create_token
, cross-replica-sum
, dot
, unary_einsum
) y se quitarán después de que pase el período de compatibilidad existente (6 meses). Aún se están explorando otras opciones para quitarlas (comparaciones einsum
, get_tuple_element
, map
, rng
, torch_index_select
, tuple
y complex
, window_reversal
). Según los comentarios de la comunidad, estas operaciones se quitarán o se agregarán a las especificaciones con compatibilidad total. Hasta que se conozcan estos futuros de operaciones, solo se garantiza 6 meses de compatibilidad.
Ejecución
Ejecución secuencial
Para ejecutar un programa de StableHLO, se proporcionan valores de entrada a la función main
y se calculan los valores de salida. Para calcular los valores de salida de una función, se ejecuta el gráfico de operaciones con raíz en la operación return
correspondiente.
El orden de ejecución se define en la implementación, siempre y cuando esté alineado con el flujo de datos, es decir, si las operaciones se ejecutan antes de sus usos. En StableHLO, todas las operaciones con efectos secundarios consumen un token y producen uno (se pueden multiplexar varios tokens en uno a través de after_all
), por lo que el orden de ejecución de los efectos secundarios también está alineado con el flujo de datos. Por ejemplo, en el siguiente programa, hay dos órdenes de ejecución posibles: %0
→ %1
→ %2
→ return
y %1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
De forma más formal, un proceso de StableHLO es una combinación de lo siguiente: 1) un programa de StableHLO, 2) estados de operación (aún no se ejecuta, ya se ejecutó) y 3) valores intermedios en los que el proceso está trabajando.
El proceso comienza con los valores de entrada de la función main
, avanza a través del gráfico de operaciones que actualizan los estados de operación y los valores intermedios, y finaliza con los valores de salida. La formalización adicional está pendiente (#484).
Ejecución paralela
Los programas StableHLO se pueden ejecutar en paralelo y se organizan en una cuadrícula de procesos 2D de num_replicas
por num_partitions
, que ambos tienen el tipo ui32
.
En la cuadrícula de procesos de StableHLO, num_replicas * num_partitions
de los procesos de StableHLO se ejecutan al mismo tiempo. Cada proceso tiene un process_id = (replica_id, partition_id)
único, en el que replica_id
en replica_ids = range(num_replicas)
y partition_id
en partition_ids = range(num_partitions)
, que tienen el tipo ui32
.
El tamaño de la cuadrícula de procesos se conoce de forma estática para cada programa (en el futuro, planeamos que sea una parte explícita de los programas de StableHLO #650) y la posición dentro de la cuadrícula de procesos se conoce de forma estática para cada proceso. Cada proceso tiene acceso a su posición dentro de la cuadrícula de procesos a través de las operaciones replica_id
y partition_id
.
Dentro de la cuadrícula de procesos, todos los programas pueden ser iguales (en el estilo "Programa único, varios datos"), pueden ser diferentes (en el estilo "Varios programas, varios datos") o algo intermedio. En el futuro, planeamos admitir otros lenguajes para definir programas paralelos de StableHLO, incluido GSPMD (#619).
Dentro de la cuadrícula de procesos, los procesos son en su mayoría independientes entre sí: tienen estados de operación, valores de entrada, intermedios y de salida separados, y la mayoría de las operaciones se ejecutan por separado entre procesos, a excepción de una pequeña cantidad de operaciones colectivas que se describen a continuación.
Dado que la ejecución de la mayoría de las operaciones solo usa valores del mismo proceso, por lo general, no es ambiguo hacer referencia a estos valores por sus nombres.
Sin embargo, cuando se describe la semántica de las operaciones colectivas, eso es insuficiente y da lugar a la notación name@process_id
para referirse al valor name
dentro de un proceso en particular. (Desde esa perspectiva, name
no calificado se puede ver como una versión abreviada de name@(replica_id(), partition_id())
).
El orden de ejecución entre procesos se define en la implementación, excepto por la sincronización que introducen las operaciones colectivas y la comunicación punto a punto, como se describe a continuación.
Comunicación punto a punto
Los procesos de StableHLO pueden comunicarse entre sí a través de los canales de StableHLO. Un canal se representa con un ID positivo de tipo si64
. A través de varias operaciones, es posible enviar valores a los canales y recibirlos de ellos.
Aún se está definiendo la formalización adicional, p. ej., de dónde provienen estos IDs de canal, cómo los programas de procesos los detectan y qué tipo de sincronización introducen (#484).
Comunicación de transmisión
Cada proceso de StableHLO tiene acceso a dos interfaces de transmisión:
- Infeed que se puede leer.
- Salida en la que se puede escribir.
A diferencia de los canales, que se usan para comunicarse entre procesos y, por lo tanto, tienen procesos en ambos extremos, los feeds de entrada y salida tienen el otro extremo definido por la implementación.
Aún se está definiendo la formalización adicional, p. ej., cómo la comunicación de transmisión influye en el orden de ejecución y qué tipo de sincronización introduce (#484).
Operaciones colectivas
Hay seis operaciones colectivas en StableHLO: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
y reduce_scatter
. Todas estas operaciones dividen los procesos en la cuadrícula de procesos de StableHLO en grupos de procesos de StableHLO y ejecutan un cálculo conjunto dentro de cada grupo de procesos, independientemente de otros grupos de procesos.
Dentro de cada grupo de procesos, las operaciones colectivas pueden introducir una barrera de sincronización. Aún se está definiendo la formalización adicional, p. ej., explicar cuándo se produce exactamente esta sincronización, cómo llegan exactamente los procesos a esta barrera y qué sucede si no lo hacen (#484).
Si el grupo de procesos incluye una comunicación entre particiones, es decir, hay procesos en el grupo de procesos cuyos IDs de partición son diferentes, la ejecución de la operación colectiva necesita un canal, y la operación colectiva debe proporcionar un channel_id
positivo de tipo si64
. La comunicación entre réplicas no necesita canales.
Los cálculos que realizan las operaciones colectivas son específicos de las operaciones individuales y se describen en las secciones de operaciones individuales anteriores. Sin embargo, las estrategias con las que la cuadrícula de procesos se divide en grupos de procesos se comparten entre estas operaciones y se describen en esta sección. De forma más formal, StableHLO admite las siguientes cuatro estrategias.
cross_replica
Solo se producen comunicaciones entre réplicas dentro de cada grupo de procesos. Esta estrategia toma replica_groups
, una lista de listas de IDs de réplicas, y calcula un producto cartesiano de replica_groups
por partition_ids
. replica_groups
debe tener elementos únicos y abarcar todos los replica_ids
. De forma más formal, con la sintaxis de Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para replica_groups = [[0, 1], [2, 3]]
y num_partitions = 2
, cross_replica
producirá [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Solo se realizan comunicaciones entre particiones dentro de cada grupo de procesos. Esta estrategia toma partition_groups
, una lista de listas de IDs de partición, y calcula un producto cartesiano de partition_groups
por replica_ids
.
partition_groups
debe tener elementos únicos y abarcar todas las partition_ids
.
De forma más formal, con la sintaxis de Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para partition_groups = [[0, 1]]
y num_replicas = 4
, cross_partition
producirá [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Las comunicaciones entre réplicas y entre particiones pueden ocurrir dentro de cada
grupo de procesos. Esta estrategia toma replica_groups
, una lista de listas de IDs de réplicas, y calcula los productos cartesianos de cada replica_group
por partition_ids
. replica_groups
debe tener elementos únicos y abarcar todos los replica_ids
. De manera más formal, usando la sintaxis de Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para replica_groups = [[0, 1], [2, 3]]
y num_partitions = 2
, cross_replica_and_partition
producirá [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Esta estrategia toma flattened_id_groups
, una lista de listas de IDs de procesos “aplanados” en forma de replica_id * num_partitions + partition_id
, y los convierte en IDs de procesos. flattened_id_groups
debe tener elementos únicos y abarcar todas las process_ids
. De forma más formal, con la sintaxis de Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
y num_partitions = 2
, flattened_ids
producirá [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Exactitud
Por el momento, StableHLO no ofrece garantías sobre la precisión numérica, pero esto puede cambiar en el futuro (#1156).
Semántica de ejecución de la operación cuantificada
La interpretación de las operaciones quantized StableHLO puede variar según los requisitos y las capacidades del hardware. Por ejemplo, parte del hardware puede optar por interpretar operaciones cuantizadas a través de una estrategia de "decuantizar, realizar operaciones de punto flotante y, por último, cuantizar". Otros pueden realizar todo el cálculo con aritmética de números enteros. Por lo tanto, la interpretación de las operaciones de StableHLO cuantificadas se determina exclusivamente por la implementación específica. La interpretación de la cuantificación híbrida (#1575) debe basarse en su semántica como se prescribe en la especificación (a través de 1792).
Errores
Los programas de StableHLO se validan a través de un amplio conjunto de restricciones para operaciones individuales, lo que descarta muchas clases de errores antes del tiempo de ejecución. Sin embargo, las condiciones de error aún son posibles, p.ej., a través de desbordamientos de números enteros, accesos fuera de los límites, etc. A menos que se mencionen explícitamente, todos estos errores generan un comportamiento definido por la implementación, pero esto puede cambiar en el futuro (#1157).
Excepciones de punto flotante
Como excepción a esta regla, las excepciones de punto flotante en los programas StableHLO tienen un comportamiento bien definido. Las operaciones que generan excepciones definidas por el estándar IEEE-754 (operación no válida, división por cero, desbordamiento, subflujo o excepciones inexactas) producen resultados predeterminados (como se define en el estándar) y continúan la ejecución sin generar la marca de estado correspondiente, similar al control de excepciones raiseNoFlag
del estándar. Las excepciones para operaciones no estándar (p.ej., aritmética compleja y ciertas funciones trascendentales) se definen en la implementación.
Discrepancias de forma
StableHLO admite tensores de forma dinámica. Sin embargo, las formas deben coincidir en el tiempo de ejecución; de lo contrario, el comportamiento es indefinido. StableHLO no proporciona de forma explícita una operación que pueda afirmar que un tensor tiene una forma determinada durante el tiempo de ejecución. Generar el código correcto es responsabilidad del productor.
Como ejemplo específico, el siguiente programa es válido. Sin embargo, durante el tiempo de ejecución, las formas exactas de %arg0
y %arg1
deberán ser iguales. De lo contrario, el comportamiento del programa no estará definido:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Para describir la sintaxis, en este documento se usa la variante ISO modificada de la sintaxis EBNF (ISO/IEC 14977:1996, Wikipedia), con dos modificaciones: 1) las reglas se definen con ::=
en lugar de =
.
2) La concatenación se expresa con yuxtaposición en lugar de ,
.
Para describir la semántica (es decir, en las secciones "Tipos", "Constantes" y "Operaciones"), usamos fórmulas que se basan en la sintaxis de Python extendida con compatibilidad para expresar de forma concisa las operaciones de array, como se describe a continuación. Esto funciona bien para fragmentos de código pequeños, pero en casos excepcionales en los que se necesitan fragmentos de código más grandes, usamos la sintaxis de Python estándar, que siempre se presenta de forma explícita.
Fórmulas
Exploremos cómo funcionan las fórmulas en función de un ejemplo de la especificación de dot_general
. Una de las restricciones de esta operación se ve de la siguiente manera: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Los nombres que se usan en esta fórmula provienen de dos fuentes: 1) funciones globales, es decir, dim
, 2) definiciones de miembros del elemento del programa correspondiente, es decir, entradas lhs
, lhs_batching_dimensions
, rhs
y rhs_batching_dimensions
definidas en la sección "Entradas" de dot_general
.
Como se mencionó anteriormente, la sintaxis de esta fórmula se basa en Python con algunas extensiones orientadas a la concisión. Para comprender la fórmula, transfórmala en sintaxis de Python estándar.
A) En estas fórmulas, usamos =
para representar la igualdad, por lo que el primer paso para obtener la sintaxis de Python es reemplazar =
por ==
, de la siguiente manera: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Además, estas fórmulas admiten puntos suspensivos (...
) que convierten expresiones escalares en expresiones de tensor. En pocas palabras, f(xs...)
significa aproximadamente "para cada x
escalar en el tensor xs
, calcula un f(x)
escalar y, luego, muestra todos estos resultados escalares juntos como un resultado de tensor". En la sintaxis normal de Python, nuestra fórmula de ejemplo se convierte en: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Gracias a los puntos suspensivos, a menudo es posible evitar trabajar a nivel de escalares individuales. Sin embargo, en algunos casos complicados, se puede usar la sintaxis semiinformal de nivel inferior como en la fórmula start_indices[bi0, ..., :, ..., biN]
de la especificación gather
. En el caso de la concisión, no proporcionamos un formalismo exacto para traducir esa sintaxis a Python normal, con la esperanza de que siga siendo comprensible de forma intuitiva caso por caso.
Avísanos si algunas fórmulas específicas parecen opacas, y trataremos de mejorarlas.
Además, notarás que las fórmulas usan puntos suspensivos para expandir todo tipo de listas, incluidos tensores, listas de tensores (que, por ejemplo, pueden surgir de una cantidad variacional de tensores), etcétera. Esta es otra área en la que no proporcionamos un formalismo exacto (p. ej., las listas ni siquiera forman parte del sistema de tipos de StableHLO) y, en su lugar, nos basamos en la comprensión intuitiva.
C) El último medio notable que utilizamos es la transmisión implícita. Si bien el conjunto de operaciones StableHLO no admite transmisiones implícitas, las fórmulas también lo hacen al servicio de la concisión. En pocas palabras, si se usa un escalar en un contexto en el que se espera un tensor, el escalar se transmite a la forma esperada.
Para continuar con el ejemplo de dot_general
, aquí hay otra restricción: 0 <= lhs_batching_dimensions < rank(lhs)
. Como se define en la especificación dot_general
, lhs_batching_dimensions
es un tensor. Sin embargo, 0
y rank(lhs)
son escalares. Después de aplicar la transmisión implícita, la fórmula se convertirá en [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Cuando se aplique a una operación dot_general
en particular, esta fórmula se
evaluará como un tensor de valores booleanos. Cuando las fórmulas se usan como restricciones, la restricción se mantiene si la fórmula se evalúa como true
o como un tensor que solo tiene elementos true
.
Nombres
En las fórmulas, el alcance léxico incluye lo siguiente: 1) funciones globales, 2) definiciones de miembros,
3) Definiciones locales. A continuación, se proporciona la lista de funciones globales. La lista de definiciones de elementos depende del elemento del programa al que se aplica la notación:
- Para las operaciones, las definiciones de miembros incluyen nombres ingresados en las secciones “Entradas” y “Salidas”.
- Para todo lo demás, las definiciones de miembros incluyen partes estructurales del elemento del programa, que se nombran según los no terminales de EBNF correspondientes. La mayoría de las veces, los nombres de estas partes estructurales se obtienen convirtiendo los nombres de los no terminales a formato de guion bajo (p. ej.,
IntegerLiteral
=>integer_literal
), pero a veces los nombres se abrevian en el proceso (p. ej.,QuantizationStorageType
=>storage_type
), en cuyo caso los nombres se introducen de forma explícita de manera similar a las secciones "Entradas" o "Salidas" en las especificaciones de operación. - Además, las definiciones de miembros siempre incluyen
self
para hacer referencia al elemento del programa correspondiente.
Valores
Cuando se evalúan las fórmulas, funcionan con los siguientes tipos de valores: 1) Value
(valores reales, p.ej., dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
; siempre se conocen sus tipos), 2) Placeholder
(valores futuros, p.ej., lhs
, rhs
o result
; aún no se conocen sus valores reales, solo se conocen sus tipos), 3) Type
(tipos como se define en la sección "Tipos") y 4) Function
(funciones globales como se define en la sección "Funciones").
Según el contexto, los nombres pueden hacer referencia a diferentes valores. Más precisamente, la sección "Semántica" para las operaciones (y los equivalentes para otros elementos del programa) define la lógica del entorno de ejecución, de modo que todas las entradas estén disponibles como Value
.
En cambio, la sección "Constraints" para las operaciones (y los equivalentes) define la lógica de "tiempo de compilación", es decir, algo que se suele ejecutar antes del tiempo de ejecución, por lo que solo las entradas constantes están disponibles como Value
y otras entradas solo están disponibles como Placeholder
.
Nombres | En "Semántica" | En "Constraints" |
---|---|---|
Funciones globales | Function |
Function |
Entradas constantes | Value |
Value |
Entradas no constantes | Value |
Placeholder |
Salidas | Value |
Placeholder |
Definiciones locales | Depende de la definición | Depende de la definición |
Veamos un ejemplo de operación transpose
:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Para esta operación, permutation
es una constante, por lo que está disponible como Value
en semántica y restricciones. Por el contrario, operand
y result
están disponibles como Value
en semántica, pero solo como Placeholder
en restricciones.
Funciones
Construcción de tipos
No hay funciones que se puedan usar para construir tipos. En su lugar, usamos directamente la sintaxis de tipo porque suele ser más concisa. p. ej., (tensor<E>, tensor<E>) -> (tensor<E>)
en lugar de function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funciones en tipos
element_type
se define en tipos de tensores y tipos de tensores cuantizados, y muestra, respectivamente, la parteTensorElementType
oQuantizedTensorElementType
delTensorType
oQuantizedTensorType
correspondiente.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
es un atajo parais_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
es un acceso directo parais_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
verifica si el tipox
se puede promover al tipoy
. Cuandox
yy
sonQuantizedTensorElementType
, la promoción solo se aplica astorage_type
. Actualmente, esta versión específica de promoción se usa en el contexto del cálculo de reducción (consulta la RFC para obtener más detalles).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
es un acceso directo parais_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Disponible para todos los tipos. Por ejemplo,is_float(x)
muestratrue
six
es unFloatType
. Six
es un valor o un marcador de posición, esta función es un atajo parais_type_name(type(x))
.max_value(x: Type) -> Value
muestra el valor máximo de unTensorElementType
. Six
no es unTensorElementType
, muestraNone
.min_value(x: Type) -> Value
muestra el valor mínimo posible de unTensorElementType
. Six
no es unaTensorElementType
, muestraNone
.member_name(x: Value | Placeholder | Type) -> Any
. Disponible para todas las definiciones de miembrosmember_name
de todos los tipos. Por ejemplo,tensor_element_type(x)
muestra la parteTensorElementType
de unTensorType
correspondiente. Six
es un valor o un marcador de posición, esta función es un atajo paramember_name(type(x))
. Six
no es un tipo que tenga un miembro apropiado, un valor o un marcador de posición de ese tipo, muestraNone
.is_empty_algorithm(*args: Type)
verifica si todos los campos del algoritmo de punto se establecen enNone
. Esto es necesario porque los algoritmos de punto tienen comportamientos predeterminados definidos por la implementación, por lo que especificar un valor predeterminado sería incorrecto.
Construcción de valores
operation_name(*xs: Value | Type) -> Value
. Disponible para todas las operaciones. Por ejemplo,add(lhs, rhs)
toma dos valores de tensorlhs
yrhs
, y muestra el resultado de la evaluación de la operaciónadd
con estas entradas. Para algunas operaciones, p. ej.,broadcast_in_dim
, los tipos de sus resultados son “portantes”, es decir, necesarios para evaluar una operación. En este caso, la función toma estos tipos como argumentos.
Funciones en valores
Todos los operadores y funciones de Python están disponibles. Por ejemplo, las notaciones de suscripción y corte de Python están disponibles para indexar en tensores, tensores cuantificados y tuplas.
to_destination_type(x: Value, destination_type: Type) -> Value
se define en tensores y muestra el valor convertido dex
en función detype(x)
ydestination_type
de la siguiente manera:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Hay una discusión preliminar sobre la combinación de operaciones convert
, uniform_quantize
y uniform_dequantize
(#1576).
Después de la combinación, no necesitamos la función anterior y podemos usar el nombre de la operación para convert
.
is_nan(x: Value) -> Value
se define en tensores y muestratrue
si todos los elementos dex
sonNaN
ofalse
de lo contrario. Six
no es un tensor, muestraNone
.is_sorted(x: Value) -> Value
se define en tensores y muestratrue
si los elementos dex
están ordenados de forma ascendente en relación con el orden lexicográfico ascendente de sus índices ofalse
de lo contrario. Six
no es un tensor, se muestraNone
.is_unique(x: Value) -> Value
se define en tensores y muestratrue
six
no tiene elementos duplicados ofalse
de lo contrario. Six
no es un tensor, muestraNone
.member_name(x: Value) -> Any
se define para todas las definiciones de miembrosmember_name
de todos los valores. Por ejemplo,real_part(x)
muestra la parteRealPart
de unComplexConstant
correspondiente. Six
no es un valor que tenga un miembro apropiado, muestraNone
.same(x: Value) -> Value
se define en tensores y muestratrue
si los elementos dex
son iguales entre sí o, de lo contrario,false
. Si el tensor no tiene elementos, se cuenta como “todos iguales entre sí”, es decir, la función muestratrue
. Six
no es un tensor, muestraNone
.split(x: Value, num_results: Value, axis: Value) -> Value
se define en tensores y muestra segmentosnum_results
dex
a lo largo del ejeaxis
. Six
no es un tensor nidim(x, axis) % num_results != 0
, muestraNone
.is_defined_in_parent_scope(x: Value) -> Value
se define en cadenas y muestratrue
six
es el nombre de una función definida en el mismo alcance que la función superior de la operación relevante.is_namespaced_op_name(x: Value) -> Value
se define en las cadenas y muestratrue
six
es un nombre de operación válido, es decir, respeta la siguiente expresión regular:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Cálculos de formas
axes(x: Value | Placeholder | Type) -> Value
es la combinación de teclas pararange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
es la combinación de teclas parashape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
es un acceso directo paralist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
se define en tensores y muestra índicessize(x)
para elTensorType
correspondiente ordenado en orden alfabético ascendente, es decir,[0, ..., 0]
,[0, ..., 1]
, …,shape(x) - 1
. Six
no es un tipo de tensor, un tipo de tensor cuantificado, o un valor o un marcador de posición de uno de estos tipos, muestraNone
.rank(x: Value | Placeholder | Type) -> Value
es un acceso directo parasize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
se define en la sección "Funciones en tipos" a través demember_name
.size(x: Value | Placeholder | Type) -> Value
es un acceso directo parareduce(lambda x, y: x * y, shape(x))
.
Cálculos de cuantización
def baseline_element_type(x: Value | Placeholder | Type) -> Type
es un atajo paraelement_type(baseline_type(x))
.baseline_type
se define en tipos de tensores y tipos de tensores cuantizados, y los transforma en un "modelo de referencia", es decir, un tipo con la misma forma, pero con los parámetros de cuantización del tipo de elemento restablecidos a los valores predeterminados. Esto se usa como un truco útil para comparar los tipos de tensores y tensores cuantificados de forma uniforme, lo que es necesario con bastante frecuencia. En el caso de los tipos cuantificados, esto permite comparar tipos sin tener en cuenta los parámetros de cuantificación, es decir,shape
,storage_type
,expressed_type
,storage_min
,storage_max
yquantization_dimension
(para el tipo cuantificado por eje) deben coincidir, peroscales
yzero points
pueden diferir.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
se define en tipos de tensores cuantificados y los convierte en tipos de tensores de punto flotante. Esto se logra convirtiendo los elementos cuantificados que representan valores enteros del tipo de almacenamiento en valores de punto flotante correspondientes del tipo expresado con el punto cero y la escala asociados con el tipo de elemento cuantificado.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
se define en tipos de tensores de punto flotante y los convierte en tipos de tensores cuantizados. Esto se logra convirtiendo los valores de punto flotante del tipo expresado en valores enteros correspondientes del tipo de almacenamiento con el punto cero y la escala asociados con el tipo de elemento cuantificado.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
se usa para especificar cálculos por elemento en tensores cuantificados. Descuantiza, es decir, convierte elementos cuantificados en sus tipos expresados, luego realiza una operación y, luego, cuantiza, es decir, vuelve a convertir los resultados en sus tipos de almacenamiento. Por el momento, esta función solo funciona para la cuantificación por tensor. La cuantificación por eje está en curso (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
se usa para especificar la cuantificación solo de peso para la operación híbrida que acepta lhs en punto flotante y rhs en tipos cuantificados. Descuantiza las entradas cuantificadas en sus tipos expresados y realiza el procesamiento en números de punto flotante. El tipo de elemento del tensor de lhs de flotante y el tipo expresado del tensor de rhs quantizado deben ser idénticos.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Cálculos de cuadrícula
cross_partition(replica_groups: Value) -> Value
. Consulta la sección "cross_replica" más arriba.cross_replica(replica_groups: Value) -> Value
. Consulta la sección "cross_replica" más arriba.cross_replica_and_partition(replica_groups: Value) -> Value
. Consulta la sección "cross_replica_and_partition" anterior.flattened_ids(replica_groups: Value) -> Value
. Consulta la sección anterior "flattened_ids".
Dinamismo
Los valores de StableHLO pueden tener tamaños de dimensión dinámicos, p. ej., tensor<?xi64>
.
Sin embargo, los valores de StableHLO no pueden tener una cantidad dinámica de dimensiones (dinamismo no clasificado, p. ej., tensor<*xi64>
). Los operandos y los resultados pueden usar tamaños de dimensiones dinámicos, incluso si hay restricciones en los tamaños. Las restricciones se verificarán de forma estática si es posible; de lo contrario, se diferirán al entorno de ejecución y las discrepancias generarán un comportamiento no definido. Consulta los ejemplos a continuación.
Discrepancias de forma para operaciones unarias por elemento
Considera el siguiente programa de juguetes:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Un programa de este tipo es inusual, ya que no es común conocer la forma del resultado, pero no la forma de la entrada. Sin embargo, este es un programa StableHLO válido. No es posible validar de forma estática la operación abs
en este programa, ya que se desconoce la forma exacta del operando. Sin embargo, las formas son compatibles, y esto se puede verificar de forma estática: ?
podría ser 2
en el tiempo de ejecución, y no habría ningún problema. Sin embargo, ?
también podría ser algún otro número entero, en cuyo caso el comportamiento no está definido.
Ten en cuenta que, si el tamaño de una dimensión es dinámico en el resultado, no puede haber un comportamiento indefinido. De hecho, no hay un tamaño “esperado”, por lo que no puede haber una discrepancia.
Discrepancias de forma para operaciones binarias por elemento
Considera el siguiente programa de juguete:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Cuando se trata de operaciones binarias por elementos, las formas de las entradas y el resultado deben coincidir durante el tiempo de ejecución. En el tiempo de compilación, las dimensiones estáticas deben ser iguales, de lo contrario, solo deben ser compatibles. Si cualquier dimensión es dinámica en las entradas, podría haber un comportamiento indefinido en el tiempo de ejecución, porque el tamaño dinámico puede no coincidir con el tamaño correspondiente en el otro operando (ya sea estático o dinámico). Si todas las entradas son estáticas, no importa si el resultado es dinámico o no: las dimensiones conocidas de forma estática se verificarán de forma estática, y las dimensiones dinámicas no imponen ninguna restricción.
Discrepancias de forma para las operaciones que toman su forma de salida como operando
Considera el siguiente programa de juguete:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Los valores del operando de forma en el tiempo de ejecución deben coincidir con la forma del resultado; de lo contrario, el comportamiento será indefinido. Es decir, en el tiempo de ejecución, %arg0
debe tener un valor de dense<[3, 4]> : tensor<2xi32>
. Si el operando de forma es constante, esto se puede verificar de forma estática. Si la forma del resultado es completamente dinámica, no puede haber una discrepancia.