StableHLO es un conjunto de operaciones para operaciones de alto nivel (HLO) en modelos de aprendizaje automático (AA). StableHLO funciona como una capa de portabilidad entre diferentes frameworks y compiladores de AA: los frameworks de AA que producen programas de StableHLO son compatibles con los compiladores de AA que consumen programas de StableHLO.
Nuestro objetivo es simplificar y acelerar el desarrollo del AA creando más interoperabilidad entre varios frameworks de AA (como TensorFlow, JAX y PyTorch) y compiladores de AA (como XLA y IREE). Para ello, este documento proporciona una especificación del lenguaje de programación StableHLO.
Esta especificación contiene tres secciones principales. Primero, la sección Programas describe la estructura de los programas de StableHLO, que constan de funciones de StableHLO, que, a su vez, constan de operaciones de StableHLO. Dentro de esa estructura, la sección Ops especifica la semántica de las operaciones individuales. La sección Ejecución proporciona la semántica para todas estas operaciones que se ejecutan juntas dentro de un programa. Por último, en la sección Notación, se analiza la notación que se usa en toda la especificación.
Para ver la especificación de una versión anterior de StableHLO, abre el repo en la versión etiquetada que te interese. Por ejemplo, la especificación de StableHLO v0.19.0. Para ver los cambios que se produjeron en cada incremento de versión secundaria de StableHLO, consulta el registro de versiones en VhloDialect.td.
Programas
Program ::= {Func}
Los programas de StableHLO constan de una cantidad arbitraria de funciones de StableHLO.
A continuación, se muestra un ejemplo de programa con una función @main que tiene 3 entradas (%image, %weights y %bias) y 1 salida. El cuerpo de la función tiene 6 operaciones.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funciones
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Las funciones de StableHLO (también llamadas funciones con nombre) tienen un identificador, entradas y salidas, y un cuerpo. En el futuro, planeamos introducir metadatos adicionales para las funciones y lograr una mejor compatibilidad con HLO (#425, #626, #740, #744).
Identificadores
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Los identificadores de StableHLO son similares a los identificadores de muchos lenguajes de programación, con dos peculiaridades: 1) Todos los identificadores tienen sigilos que distinguen diferentes tipos de identificadores y 2) los identificadores de valores pueden ser completamente numéricos para simplificar la generación de programas de StableHLO.
Tipos
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Los tipos de StableHLO se clasifican en tipos de valor (también llamados tipos de primera clase), que representan valores de StableHLO, y tipos que no son de valor, que describen otros elementos del programa. Los tipos de StableHLO son similares a los de muchos lenguajes de programación, y su principal peculiaridad es la naturaleza específica del dominio de StableHLO, que genera algunos resultados inusuales (p.ej., los tipos escalares no son tipos de valor).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Los tipos de tensores representan tensores, es decir, arrays multidimensionales. Tienen una forma y un tipo de elemento, en los que una forma representa tamaños de dimensión no negativos o desconocidos en orden ascendente de las dimensiones correspondientes (que también se denominan ejes) numeradas del 0 al R-1. La cantidad de dimensiones R se denomina rango. Por ejemplo, tensor<2x3xf32> es un tipo de tensor con forma 2x3 y tipo de elemento f32. Tiene dos dimensiones (o, en otras palabras, dos ejes): la dimensión 0 y la dimensión 1, cuyos tamaños son 2 y 3. Su rango es 2.
Las formas pueden ser parcial o completamente desconocidas (dinámicas), p.ej., tensor<?x2xf64> es parcialmente desconocida y tensor<?x?xf64> es completamente desconocida. Los tamaños de dimensiones dinámicas se representan con un ?. No se pueden quitar las clasificaciones de las formas.
En el futuro, planeamos explorar la posibilidad de extender los tipos de tensores más allá de los tamaños de las dimensiones y los tipos de elementos, por ejemplo, para incluir diseños (#629) y dispersión (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerLiteral
QuantizationStorageMax ::= IntegerLiteral
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerLiteral
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale [':' QuantizationZeroPoint]
QuantizationScale ::= FloatLiteral
QuantizationZeroPoint ::= IntegerLiteral
| Nombre | Tipo | Limitaciones |
|---|---|---|
storage_type |
tipo de número entero | (C1 a C3) y (C8) |
storage_min |
Constante entera | (C1), (C3), (C7) |
storage_max |
Constante entera | (C2), (C3), (C7) |
expressed_type |
tipo de punto flotante | (C4) |
quantization_dimension |
constante entera opcional | (C10-C12) |
scales |
Cantidad variable de constantes de punto flotante | (C4 a C6), (C9), (C10) y (C13) |
zero_points |
cantidad variable de constantes de números enteros | (C7 a C9) |
Los tipos de elementos cuantificados representan valores enteros de un tipo de almacenamiento en el rango de storage_min a storage_max (inclusive) que corresponden a valores de punto flotante de un tipo expresado. Para un valor entero determinado i, el valor de punto flotante correspondiente f se puede calcular como f = (i - zero_point) * scale, donde scale y zero_point se denominan parámetros de cuantificación. Los parámetros storage_min y storage_max son opcionales en la gramática, pero tienen valores predeterminados de min_value(storage_type) y max_value(storage_type), respectivamente. Los tipos de elementos cuantificados tienen las siguientes restricciones:
- (C1)
type(storage_min) = storage_type. - (C2)
type(storage_max) = storage_type. - (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type). - (C4)
type(scales...) = expressed_type. - (C5)
0 < scales. - (C6)
is_finite(scales...). - (C7)
storage_min <= zero_points <= storage_max. - (C8)
type(zero_points...) = storage_type. - (C9)
size(scales) = size(zero_points). - (C10) Si
is_empty(quantization_dimension), entoncessize(scales) = 1. - (C11)
0 <= quantization_dimension.
Por el momento, QuantizationScale es una constante de punto flotante, pero existe un gran interés en las escalas basadas en números enteros, representadas con multiplicadores y desplazamientos. Tenemos previsto explorar esta opción en el futuro cercano (#1404).
Hay un debate en curso sobre la semántica de QuantizationZeroPoint, incluido el tipo, los valores y si puede haber solo un punto cero o potencialmente varios en un tipo de tensor cuantificado. Según los resultados de esta conversación, es posible que la especificación sobre los puntos nulos cambie en el futuro (#1405).
Otro debate en curso involucra la semántica de QuantizationStorageMin y QuantizationStorageMax para determinar si se deben imponer restricciones sobre estos valores y sobre los valores de los tensores cuantificados (#1406).
Por último, planeamos explorar la representación de escalas desconocidas y puntos cero, de manera similar a cómo planeamos explorar la representación de tamaños de dimensiones desconocidos (#1407).
Los tipos de tensores cuantificados representan tensores con elementos cuantificados. Estos tensores son exactamente iguales a los tensores regulares, excepto que sus elementos tienen tipos de elementos cuantificados, en lugar de tipos de elementos regulares.
En los tensores cuantificados, la cuantificación puede ser por tensor, es decir, tener un scale y un zero_point para todo el tensor, o bien puede ser por eje, es decir, tener varios scales y zero_points, un par por segmento de una dimensión particular quantization_dimension. De manera más formal, en un tensor t con cuantificación por eje, hay dim(t, quantization_dimension) segmentos del quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], etc. Todos los elementos del segmento i usan scales[i] y zero_points[i] como parámetros de cuantificación. Los tipos de tensores cuantificados tienen las siguientes restricciones:
- Para la cuantificación por tensor, haz lo siguiente:
- Sin restricciones adicionales.
- Para la cuantificación por eje, haz lo siguiente:
- (C12)
quantization_dimension < rank(self). - (C13)
dim(self, quantization_dimension) = size(scales).
- (C12)
TokenType ::= 'token'
Los tipos de tokens representan tokens, es decir, valores opacos que producen y consumen algunas operaciones. Los tokens se usan para imponer el orden de ejecución en las operaciones, como se describe en la sección Ejecución.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Los tipos de búfer representan búferes. Por ejemplo, en XLA, los búferes son arrays multidimensionales con almacenamiento coherente. Al igual que los tipos de tensor, los tipos de búfer tienen una forma y un tipo de elemento, en los que una forma representa tamaños de dimensión no negativos o desconocidos en orden ascendente de las dimensiones correspondientes (que también se denominan ejes) numeradas del 0 al R-1. La cantidad de dimensiones R se denomina rango. Por ejemplo, memref<2x3xf32> es un tipo de búfer con forma 2x3 y tipo de elemento f32. Tiene dos dimensiones (o, en otras palabras, dos ejes): la dimensión 0 y la dimensión 1, cuyos tamaños son 2 y 3. Su rango es 2.
Los búferes se pueden asignar con un custom_call a CreateBuffer o Pin, y se pueden liberar con un custom_call a Unpin. Solo las operaciones de custom_call pueden leer y escribir el contenido dentro de los búferes. Consulta custom_call para obtener más detalles.
Los tipos de tuplas representan tuplas, es decir, listas heterogéneas. Las tuplas son una función heredada que solo existe para la compatibilidad con HLO. En HLO, las tuplas se usan para representar entradas y salidas variádicas. En StableHLO, las entradas y salidas variádicas se admiten de forma nativa, y el único uso de tuplas en StableHLO es para representar de forma integral la ABI de HLO, en la que, p.ej., T, tuple<T> y tuple<tuple<T>> pueden ser muy diferentes según una implementación en particular. En el futuro, planeamos realizar cambios en la ABI de HLO, lo que podría permitirnos quitar los tipos de tuplas de StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f4E2M1FN' | 'f6E2M3FN' | 'f6E3M2FN' | 'f8E3M4' | 'f8E4M3'
| 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ' | 'f8E5M2'
| 'f8E5M2FNUZ' | 'f8E8M0FNU' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Los tipos de elementos representan elementos de tipos de tensores. A diferencia de muchos lenguajes de programación, estos tipos no son de primera clase en StableHLO. Esto significa que los programas de StableHLO no pueden representar directamente valores de estos tipos (por lo tanto, es idiomático representar valores escalares de tipo T con valores de tensor de 0 dimensiones de tipo tensor<T>).
- El tipo booleano representa los valores booleanos
trueyfalse. - Los tipos de números enteros pueden ser con signo (
si) o sin signo (ui) y tener uno de los anchos de bits admitidos (2,4,8,16,32o64). Los tipossiNcon signo representan valores enteros desde-2^(N-1)hasta2^(N-1)-1inclusive, y los tiposuiNsin signo representan valores enteros desde0hasta2^N-1inclusive. - Los tipos de punto flotante pueden ser uno de los siguientes:
f8E3M4,f8E4M3yf8E5M2son números de punto flotante de 8 bits que siguen las convenciones de IEEE-754.- Tipos
f8E4M3FNyf8E5M2que corresponden, respectivamente, a las codificacionesE4M3yE5M2del formato FP8 que se describe en FP8 Formats for Deep Learning. - Tipos
f8E4M3FNUZyf8E5M2FNUZque corresponden a las codificacionesE4M3yE5M2de los formatos FP8 que se describen en Formatos numéricos de 8 bits para redes neuronales profundas. - Tipo
f8E4M3B11FNUZque corresponde a la codificaciónE4M3de los formatos FP8 que se describen en Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks. - Tipo
bf16que corresponde al formatobfloat16descrito en BFloat16: El secreto del alto rendimiento en las Cloud TPU. - Tipos
f16,f32yf64que corresponden, respectivamente, a los formatosbinary16(“precisión media”),binary32(“precisión simple”) ybinary64(“precisión doble”) que se describen en el estándar IEEE 754. - El tipo
tf32corresponde al formato TensorFloat32 y tiene asistencia limitada en StableHLO. - Tipos de MX (microescalado)
f4E2M1FN,f6E2M3FN,f6E3M2FNyf8E8M0FNUque se describen en la Especificación de formatos de microescalado de OCP.
- Los tipos complejos representan valores complejos que tienen una parte real y una parte imaginaria del mismo tipo de elemento. Los tipos complejos admitidos son
complex<f32>(ambas partes son de tipof32) ycomplex<f64>(ambas partes son de tipof64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Los tipos de funciones representan funciones con nombre y anónimas. Tienen tipos de entrada (la lista de tipos a la izquierda de ->) y tipos de salida (la lista de tipos a la derecha de ->). En muchos lenguajes de programación, los tipos de funciones son de primera clase, pero no en StableHLO.
StringType ::= 'string'
El tipo de cadena representa secuencias de bytes. A diferencia de muchos lenguajes de programación, el tipo de cadena no es de primera clase en StableHLO y solo se usa para especificar metadatos estáticos para los elementos del programa.
Operaciones
Las operaciones de StableHLO (también llamadas ops) representan un conjunto cerrado de operaciones de alto nivel en los modelos de aprendizaje automático. Como se mencionó anteriormente, la sintaxis de StableHLO se inspira en gran medida en MLIR, que no es necesariamente la alternativa más ergonómica, pero podría decirse que es la que mejor se adapta al objetivo de StableHLO de crear más interoperabilidad entre los frameworks de AA y los compiladores de AA.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Las operaciones de StableHLO (también llamadas ops) tienen un nombre, entradas y salidas, y una firma. El nombre consta del prefijo stablehlo. y una mnemotecnia que identifica de forma única una de las operaciones admitidas. A continuación, se incluye una lista completa de todas las operaciones compatibles.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Las operaciones consumen entradas y producen salidas. Las entradas se clasifican en valores de entrada (calculados durante la ejecución), funciones de entrada (proporcionadas de forma estática, ya que en StableHLO las funciones no son valores de primera clase) y atributos de entrada (también proporcionados de forma estática). El tipo de entradas y salidas que consume y produce una operación depende de su mnemónico. Por ejemplo, la operación add consume 2 valores de entrada y produce 1 valor de salida. En comparación, la operación select_and_scatter consume 3 valores de entrada, 2 funciones de entrada y 3 atributos de entrada.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Las funciones de entrada (también llamadas funciones anónimas) son muy similares a las funciones con nombre, excepto por lo siguiente: 1) No tienen un identificador (de ahí el nombre “anónimas”). 2) No declaran tipos de salida (los tipos de salida se infieren a partir de la operación return dentro de la función).
La sintaxis de las funciones de entrada incluye una parte que no se usa actualmente (consulta la producción Unused anterior), que está allí para la compatibilidad con MLIR. En MLIR, existe un concepto más general de "regiones" que pueden tener varios "bloques" de operaciones conectados a través de operaciones de salto. Estos bloques tienen IDs que corresponden a la producción de Unused, de modo que se puedan distinguir entre sí.
StableHLO no tiene operaciones de salto, por lo que la parte correspondiente de la sintaxis de MLIR no se usa (pero sigue allí).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Los atributos de entrada tienen un nombre y un valor que es una de las constantes admitidas. Son la forma principal de especificar metadatos estáticos para los elementos del programa. Por ejemplo, la operación concatenate usa el atributo dimension para especificar la dimensión a lo largo de la cual se concatenan sus valores de entrada. Del mismo modo, la operación slice usa varios atributos, como start_indices y limit_indices, para especificar los límites que se usan para segmentar el valor de entrada.
Por el momento, los programas de StableHLO en la naturaleza a veces contienen atributos que no se describen en este documento. En el futuro, planeamos absorber estos atributos en el conjunto de operaciones de StableHLO o prohibir que aparezcan en los programas de StableHLO. Mientras tanto, aquí tienes la lista de estos atributos:
layout(#629).mhlo.frontend_attributes(#628).mhlo.sharding(#619).output_operand_aliases(#740).- Metadatos de ubicación (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La firma de la operación consta de los tipos de todos los valores de entrada (la lista de tipos en el lado izquierdo de ->) y los tipos de todos los valores de salida (la lista de tipos en el lado derecho de ->). En sentido estricto, los tipos de entrada son redundantes, y los tipos de salida también lo son casi siempre (porque, para la mayoría de las operaciones de StableHLO, los tipos de salida se pueden inferir a partir de las entradas). Sin embargo, la firma de la operación forma parte deliberadamente de la sintaxis de StableHLO para garantizar la compatibilidad con MLIR.
A continuación, se muestra un ejemplo de una operación cuyo mnemónico es select_and_scatter. Consume 3 valores de entrada (%operand, %source y %init_value), 2 funciones de entrada y 3 atributos de entrada (window_dimensions, window_strides y padding). Ten en cuenta que la firma de la operación solo incluye los tipos de sus valores de entrada (pero no los tipos de funciones y atributos de entrada que se proporcionan de forma intercalada).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Las constantes de StableHLO tienen un literal y un tipo que, juntos, representan un valor de StableHLO. En general, el tipo forma parte de la sintaxis de la constante, excepto cuando no es ambiguo (p.ej., una constante booleana tiene de forma no ambigua el tipo i1, mientras que una constante de número entero puede tener varios tipos posibles).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Las constantes booleanas representan los valores booleanos true y false. Las constantes booleanas tienen el tipo i1.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Las constantes de números enteros representan valores de números enteros a través de cadenas que usan notación decimal o hexadecimal. No se admiten otras bases, como la binaria o la octal. Las constantes de números enteros tienen las siguientes restricciones:
- (C1)
is_wellformed(integer_literal, integer_type).
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Las constantes de punto flotante representan valores de punto flotante a través de cadenas que usan notación decimal o científica. Además, se puede usar la notación hexadecimal para especificar directamente los bits subyacentes en el formato de punto flotante del tipo correspondiente. Las constantes de punto flotante tienen las siguientes restricciones:
- (C1) Si se usa una notación no hexadecimal,
is_wellformed(float_literal, float_type). - (C2) Si se usa la notación hexadecimal, se debe incluir
size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Las constantes complejas representan valores complejos con listas de una parte real (primero) y una parte imaginaria (segundo). Por ejemplo, (1.0, 0.0) : complex<f32> representa 1.0 + 0.0i y (0.0, 1.0) : complex<f32> representa 0.0 + 1.0i. El orden en que estas partes se almacenan en la memoria está definido por la implementación. Las constantes complejas tienen las siguientes restricciones:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type)). - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Las constantes de tensor representan valores de tensor con listas anidadas especificadas a través de la notación de NumPy. Por ejemplo, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> representa un valor de tensor con la siguiente asignación de índices a elementos: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. El orden en el que estos elementos se almacenan en la memoria está definido por la implementación. Las constantes de tensor tienen las siguientes restricciones:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type)), donde:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
- (C2)
has_shape(tensor_literal, shape(tensor_type)), donde:has_shape(element_literal: Syntax, []) = true.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).- De lo contrario,
false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Las constantes de tensor cuantificado representan valores de tensor cuantificado con la misma notación que las constantes de tensor, con elementos especificados como constantes de su tipo de almacenamiento. Las constantes de tensores cuantificados tienen las siguientes restricciones:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)). - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Los literales de cadena constan de bytes especificados con caracteres ASCII y secuencias de escape. No dependen de la codificación, por lo que la interpretación de estos bytes se define en la implementación. Los literales de cadena tienen el tipo string.
Operaciones
abs
Semántica
Realiza una operación abs a nivel de los elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números enteros con signo, es el módulo de números enteros.
- Para números de punto flotante:
absde IEEE-754. - Para números complejos: módulo complejo.
- Para tipos cuantificados:
dequantize_op_quantize(abs, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo complejo, de punto flotante o de número entero con signo, o tensor cuantificado por tensor | (C1-C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de número entero o punto flotante con signo, o tensor cuantificado por tensor | (C1-C2) |
Limitaciones
- (C1)
shape(result) = shape(operand). - (C2)
baseline_element_type(result)se define de la siguiente manera:complex_element_type(element_type(operand))siis_complex(operand).- En caso contrario,
baseline_element_type(operand).
Ejemplos
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand)< : (t>ens>or3xi32<) - t>ensor3xi32
// %result: [2, 0, 2]
add
Semántica
Realiza la suma de dos tensores lhs y rhs elemento por elemento, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: suma de números enteros.
- Para números de punto flotante:
additionde IEEE-754. - Para números complejos: suma compleja.
- Para tipos cuantificados:
dequantize_op_quantize(add, lhs, rhs, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor o tensor cuantificado | (C1 a C6) |
| (I2) | rhs |
tensor o tensor cuantificado | (C1 a C5) y (C7) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C1 a C7) |
Limitaciones
- Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C1)
type(lhs) = type(rhs) = type(result).
- (C1)
- Si la operación usa tensores cuantizados, se aplica lo siguiente:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result). - (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result). - (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result). - (C6) Si
is_per_axis_quantized(lhs), entoncesquantization_dimension(lhs) = quantization_dimension(result). - (C7) Si
is_per_axis_quantized(rhs), entoncesquantization_dimension(rhs) = quantization_dimension(result).
- (C2)
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[6, 8], [10, 12]]
after_all
Semántica
Garantiza que las operaciones que producen el inputs se ejecuten antes que cualquier operación que dependa de result. La ejecución de esta operación no hace nada, solo existe para establecer dependencias de datos de result a inputs.
Entradas
| Etiqueta | Nombre | Tipo |
|---|---|---|
| (I1) | inputs |
Cantidad variable de token |
Salidas
| Nombre | Tipo |
|---|---|
result |
token |
Ejemplos
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehl>o.token) - !stablehlo.token
all_gather
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, concatena los valores de los tensores operands de cada proceso a lo largo de all_gather_dim y produce tensores results.
La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:
cross_replica(replica_groups)sichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sichannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sichannel_id > 0 and use_global_device_ids = true.
Luego, dentro de cada process_group, haz lo siguiente:
operands...@receiver = [operand@sender for sender in process_group]para todos losreceiverenprocess_group.results...@process = concatenate(operands...@process, all_gather_dim)para todos losprocessenprocess_group.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operands |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1) y (C6) |
| (I2) | all_gather_dim |
constante de tipo si64 |
(C1) y (C6) |
| (I3) | replica_groups |
Tensor constante bidimensional de tipo si64 |
(C2-C4) |
| (I4) | channel_id |
constante de tipo si64 |
(C5) |
| (I5) | use_global_device_ids |
constante de tipo i1 |
(C5) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores o tensores cuantificados por tensor | (C6) |
Limitaciones
- (C1)
0 <= all_gather_dim < rank(operands...). - (C2)
is_unique(replica_groups). - (C3)
size(replica_groups)se define de la siguiente manera:num_replicassi se usacross_replica.num_replicassi se usacross_replica_and_partition.num_processessi se usaflattened_ids.
- (C4)
0 <= replica_groups < size(replica_groups). - (C5) Si
use_global_device_ids = true, entonceschannel_id > 0. - (C6)
type(results...) = type(operands...), excepto:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1).
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64,
// channel_id = 0
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
// use_global_device_ids = false
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64)< - (ten>sor2x4xi<64, ten>sor2x4xi64)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, aplica una función de reducción computation a los valores de los tensores operands de cada proceso y produce tensores results.
La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:
cross_replica(replica_groups)sichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sichannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sichannel_id > 0 and use_global_device_ids = true.
Luego, dentro de cada process_group, haz lo siguiente:
results...@process[result_index] = exec(schedule)para algún árbol binarioscheduledonde:exec(node)=computation(exec(node.left), exec(node.right)).exec(leaf)=leaf.value.
schedulees un árbol binario definido por la implementación cuyo recorrido en orden esto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0])).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operands |
Cantidad variable de tensores o tensores cuantificados por tensor | (C5), (C6) |
| (I2) | replica_groups |
Cantidad variable de constantes de tensor unidimensionales de tipo si64 |
(C1 a C3) |
| (I3) | channel_id |
constante de tipo si64 |
(C4) |
| (I4) | use_global_device_ids |
constante de tipo i1 |
(C4) |
| (I5) | computation |
función | (C5) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores o tensores cuantificados por tensor | (C6-C7) |
Limitaciones
- (C1)
is_unique(replica_groups). - (C2)
size(replica_groups)se define de la siguiente manera:num_replicassi se usacross_replica.num_replicassi se usacross_replica_and_partition.num_processessi se usaflattened_ids.
- (C3)
0 <= replica_groups < size(replica_groups). - (C4) Si
use_global_device_ids = true, entonceschannel_id > 0. - (C5)
computationtiene el tipo(tensor<E>, tensor<E>) -> (tensor<E>), dondeis_promotable(element_type(operand), E). - (C6)
shape(results...) = shape(operands...). - (C7)
element_type(results...) = E.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()<
}) {
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
// channel_id = 0
channel_hand<le = #stablehlo.chan>nel_handlehandle = 0, type = 0
// use_global_<devic>e_ids = <false>
} >: (tenso<r4xi6>4, tenso<r4xi6>4) - (tensor4xi64, tensor4xi64)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, divide los valores de los tensores operands a lo largo de split_dimension en partes, dispersa las partes divididas entre los procesos, concatena las partes dispersas a lo largo de concat_dimension y produce tensores results.
La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:
cross_replica(replica_groups)sichannel_id <= 0.cross_partition(replica_groups)sichannel_id > 0.
Luego, dentro de cada process_group, haz lo siguiente:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)para todosenderenprocess_group.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]dondereceiver_index = process_group.index(receiver).results...@process = concatenate(scattered_parts...@process, concat_dimension).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operands |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1 a C3) y (C9) |
| (I2) | split_dimension |
constante de tipo si64 |
(C1), (C2), (C9) |
| (I3) | concat_dimension |
constante de tipo si64 |
(C3), (C9) |
| (I4) | split_count |
constante de tipo si64 |
(C2), (C4), (C8) y (C9) |
| (I5) | replica_groups |
Tensor constante bidimensional de tipo si64 |
(C5-C8) |
| (I6) | channel_id |
constante de tipo si64 |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores o tensores cuantificados por tensor | (C9) |
Limitaciones
- (C1)
0 <= split_dimension < rank(operands...). - (C2)
dim(operands..., split_dimension) % split_count = 0. - (C3)
0 <= concat_dimension < rank(operands...). - (C4)
0 < split_count. - (C5)
is_unique(replica_groups). - (C6)
size(replica_groups)se define de la siguiente manera:num_replicassi se usacross_replica.num_partitionssi se usacross_partition.
- (C7)
0 <= replica_groups < size(replica_groups). - (C8)
dim(replica_groups, 1) = split_count. - (C9)
type(results...) = type(operands...), excepto sisplit_dimension != concat_dimension:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_grou<ps = den>se[[0, 1]<] : ten>sor1x2xi64
// channel_id = 0
}< : (ten>sor2x4xi<64, ten>sor>2x4xi64)< - (ten>sor4x2xi<64, ten>sor4x2xi64)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
y
Semántica
Realiza la operación AND a nivel del elemento de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: AND lógico.
- Para números enteros: AND bit a bit.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo booleano o entero | (C1) |
| (I2) | rhs |
tensor de tipo booleano o entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo booleano o entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result).
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 0]]
atan2
Semántica
Realiza una operación atan2 a nivel del elemento en los tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
atan2de IEEE-754. - Para números complejos: atan2 compleja.
- Para tipos cuantificados:
dequantize_op_quantize(atan2, lhs, rhs, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Ejemplos
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs)< : (t>ensor3xf<64, t>ens>or3xf64<) - t>ensor3xf64
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semántica
Calcula los gradientes de varias entradas de batch_norm_training a través de la retropropagación desde grad_output y produce tensores grad_operand, grad_scale y grad_offset. De manera más formal, esta operación se puede expresar como una descomposición en operaciones existentes de StableHLO con la siguiente sintaxis de Python:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Para los tipos cuantificados, realiza dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1 a C3) y (C5) |
| (I2) | scale |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C2), (C4) y (C5) |
| (I3) | mean |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C2), (C4) |
| (I4) | variance |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C2), (C4) |
| (I5) | grad_output |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2), (C3) |
| (I6) | epsilon |
constante de tipo f32 |
|
| (I7) | feature_index |
constante de tipo si64 |
(C1) y (C5) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
grad_operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2), (C3) |
grad_scale |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C2), (C4) |
grad_offset |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C2), (C4) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,mean,variance,grad_output,grad_operand,grad_scaleygrad_offsettienen el mismobaseline_element_type. - (C3)
operand,grad_outputygrad_operandtienen la misma forma. - (C4)
scale,mean,variance,grad_scaleygrad_offsettienen la misma forma. - (C5)
size(scale) = dim(operand, feature_index).
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf64,
< tenso>r2x>2x2xf64)< - (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semántica
Normaliza el tensor operand en todas las dimensiones, excepto en la dimensión feature_index, y produce un tensor result. De manera más formal, esta operación se puede expresar como una descomposición en operaciones existentes de StableHLO con la siguiente sintaxis de Python:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Para los tipos cuantificados, realiza dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1 a C7) |
| (I2) | scale |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C2), (C3) |
| (I3) | offset |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C2), (C4) |
| (I4) | mean |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C5) |
| (I5) | variance |
Tensor unidimensional de tipo cuantificado de punto flotante o por tensor | (C2) y (C6) |
| (I6) | epsilon |
constante de tipo f32 |
|
| (I7) | feature_index |
constante de tipo si64 |
(C1), (C3 a C6) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2), (C7) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,mean,varianceyresulttienen el mismobaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(mean) = dim(operand, feature_index). - (C6)
size(variance) = dim(operand, feature_index). - (C7)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ensor2xf<64, t>ens>or2xf64<) - tenso>r2x2x2xf64
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semántica
Calcula la media y la varianza en todas las dimensiones, excepto la dimensión feature_index, y normaliza el tensor operand para producir los tensores output, batch_mean y batch_var. De manera más formal, esta operación se puede expresar como una descomposición en operaciones existentes de StableHLO con la siguiente sintaxis de Python:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Para los tipos cuantificados, realiza dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
| (I2) | scale |
Tensor unidimensional de valores de punto flotante o cuantificados por tensor | (C2), (C3) |
| (I3) | offset |
Tensor unidimensional de valores de punto flotante o cuantificados por tensor | (C2), (C4) |
| (I4) | epsilon |
constante de tipo f32 |
(C1), (C3 a C6) |
| (I5) | feature_index |
constante de tipo si64 |
(C1), (C3 a C6) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
output |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C7) |
batch_mean |
Tensor unidimensional de valores de punto flotante o cuantificados por tensor | (C2), (C5) |
batch_var |
Tensor unidimensional de punto flotante o cuantificado por tensor | (C2) y (C6) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand). - (C2)
operand,scale,offset,batch_mean,batch_varyoutputtienen el mismobaseline_element_type. - (C3)
size(scale) = dim(operand, feature_index). - (C4)
size(offset) = dim(operand, feature_index). - (C5)
size(batch_mean) = dim(operand, feature_index). - (C6)
size(batch_var) = dim(operand, feature_index). - (C7)
baseline_type(output) = baseline_type(operand).
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
}< : (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ens>or2xf64) -
< (tenso>r2x2x2xf<64, t>ensor2xf<64, t>ensor2xf64)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semántica
Realiza una operación de bitcast en el tensor operand y produce un tensor result en el que los bits de todo el tensor operand se reinterpretan con el tipo del tensor result.
De manera más formal, dadas las condiciones E = element_type(operand), E' = element_type(result) y R = rank(operand):
- Si
num_bits(E') < num_bits(E),bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]). - Si
num_bits(E') > num_bits(E),bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]). - Si
num_bits(E') = num_bits(E),bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).
bits devuelve la representación en memoria de un valor determinado, y su comportamiento se define en la implementación, ya que la representación exacta de los tensores y los tipos de elementos también se definen en la implementación.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado | (C1-C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C1-C2) |
Limitaciones
- (C1) Dadas
E = is_quantized(operand) ? storage_type(operand) : element_type(operand),E' = is_quantized(result) ? storage_type(result) : element_type(result)yR = rank(operand):- Si es
num_bits(E') = num_bits(E),shape(result) = shape(operand). - Si
num_bits(E') < num_bits(E), haz lo siguiente: rank(result) = R + 1.dim(result, i) = dim(operand, i)para todos los0 <= i < Rdim(result, R) * num_bits(E') = num_bits(E).- Si
num_bits(E') > num_bits(E), haz lo siguiente: rank(result) = R - 1.dim(result, i) = dim(operand, i)para todos los0 <= i < Rdim(operand, R - 1) * num_bits(E) = num_bits(E').
- Si es
- (C2) Si
is_complex(operand) or is_complex(result), entoncesis_complex(operand) and is_complex(result).
Ejemplos
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand)< : >(te>nsorf64<) - t>ensor4xf16
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semántica
Expande las dimensiones o el rango de un tensor de entrada duplicando los datos en el tensor operand y produce un tensor result. De manera más formal, result[result_index] = operand[operand_index], donde para todo d en axes(operand), se cumple lo siguiente:
operand_index[d] = 0sidim(operand, d) = 1.- En caso contrario,
operand_index[d] = result_index[broadcast_dimensions[d]].
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado | (C1-C2), (C5-C6) |
| (I2) | broadcast_dimensions |
Tensor constante unidimensional de tipo si64 |
(C2-C6) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3), (C5-C6) |
Limitaciones
- (C1)
element_type(result)se calcula de la siguiente manera:element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), excepto quequantization_dimension(operand),scales(operand)yzero_points(operand)pueden diferir dequantization_dimension(result),scales(result)yzero_points(result), respectivamente, en otros casos.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Para todo
denaxes(operand):dim(operand, d) = 1odim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Si
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Si
dim(operand, quantization_dimension(operand)) = 1, entoncesscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
Ejemplos
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensio<ns = arra>yi64: 2, 1
}< : (ten>sor>1x3xi32<) - tenso>r2x3x2xi32
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
caso
Semántica
Produce el resultado de ejecutar exactamente una función de branches, según el valor de index. De manera más formal, result = selected_branch(), donde:
selected_branch = branches[index]si0 <= index < size(branches).- En caso contrario,
selected_branch = branches[-1].
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | index |
Tensor de 0 dimensiones del tipo si32 |
|
| (I2) | branches |
Cantidad variable de funciones | (C1-C4) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores, tensores cuantificados o tokens | (C4) |
Limitaciones
- (C1)
0 < size(branches). - (C2)
input_types(branches...) = []. - (C3)
same(output_types(branches...)). - (C4)
type(results...) = output_types(branches[0]).
Ejemplos
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %resul<t_bra>nch0) : <(tens>or2>xi64, tensor2xi64) - ()
}, {
"stablehlo.return"(%result_branc<h1, %>result_b<ranch>1) >: (tensor2xi64, <ten>sor>2xi64) -< ()
}>) : (ten<sori3>2) - (tensor2xi64, tensor2xi64)
// %result0: [1, 1]
// %result1: [1, 1]
cbrt
Semántica
Realiza una operación de raíz cúbica a nivel de los elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
rootn(x, 3)de IEEE-754. - Para números complejos: raíz cúbica compleja.
- Para los tipos cuantificados:
dequantize_op_quantize(cbrt, operand, type(result))
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand)< : (t>ens>or4xf64<) - t>ensor4xf64
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semántica
Realiza la operación ceil para cada elemento del tensor operand y produce un tensor result.
Implementa la operación roundToIntegralTowardPositive de la especificación IEEE-754. Para los tipos cuantificados, realiza dequantize_op_quantize(ceil, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
cholesky
Semántica
Calcula la descomposición de Cholesky de un lote de matrices.
De manera más formal, para todo i en index_space(result), result[i0, ..., iR-3, :, :] es una descomposición de Cholesky de a[i0, ..., iR-3, :, :], en forma de una matriz triangular inferior (si lower es true) o triangular superior (si lower es false).
Los valores de salida en el triángulo opuesto, es decir, el triángulo estrictamente superior o el triángulo estrictamente inferior, respectivamente, se definen según la implementación.
Si existe i donde la matriz de entrada no es una matriz hermitiana definida positiva, el comportamiento no está definido.
Para los tipos cuantificados, realiza dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | a |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1 a C3) |
| (I2) | lower |
constante de tipo i1 |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(a) = baseline_type(result). - (C2)
2 <= rank(a). - (C3)
dim(a, -2) = dim(a, -1).
Ejemplos
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
}< : (ten>sor>3x3xf32<) - ten>sor3x3xf64
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
restringir
Semántica
Ajusta cada elemento del tensor operand entre un valor mínimo y uno máximo, y produce un tensor result. De manera más formal, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element), donde min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Para los tipos cuantificados, realiza dequantize_op_quantize(clamp, min, operand, max, type(result)).
Imponer un orden en los números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | min |
tensor o tensor cuantificado por tensor | (C1), (C3) |
| (I2) | operand |
tensor o tensor cuantificado por tensor | (C1-C4) |
| (I3) | max |
tensor o tensor cuantificado por tensor | (C2), (C3) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C4) |
Limitaciones
- (C1)
rank(min) = 0 or shape(min) = shape(operand). - (C2)
rank(max) = 0 or shape(max) = shape(operand). - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max). - (C4)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max)< : (t>ensor3xi<32, t>ensor3xi<32, t>ens>or3xi32<) - t>ensor3xi32
// %result: [5, 13, 20]
collective_broadcast
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, envía el valor del tensor operand del proceso de origen a los procesos de destino y genera un tensor result.
La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:
cross_replica(replica_groups)sichannel_id <= 0.cross_partition(replica_groups)sichannel_id > 0.
Luego, result@process se calcula de la siguiente manera:
operand@process_groups[i, 0]si existe unital que el proceso se encuentra enprocess_groups[i]broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))en caso contrario.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C3) |
| (I2) | replica_groups |
Cantidad variable de constantes de tensor unidimensionales de tipo si64 |
(C1), (C2) |
| (I3) | channel_id |
constante de tipo si64 |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3) |
Limitaciones
- (C1)
is_unique(replica_groups). - (C2)
0 <= replica_groups < N, dondeNse define de la siguiente manera:num_replicassi se usacross_replica.num_partitionssi se usacross_partition.
- (C3)
type(result) = type(operand).
Ejemplos
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_grou<ps = den>se[[2, 1]<] : ten>sor1x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
} : (ten>sor>1x2xi64<) - ten>sor1x2xi64
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, envía el valor del tensor operand del proceso de origen al proceso de destino y produce un tensor result.
La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:
cross_replica(source_target_pairs)sichannel_id <= 0.cross_partition(source_target_pairs)sichannel_id > 0.
Luego, result@process se calcula de la siguiente manera:
operand@process_groups[i, 0], si existe unital queprocess_groups[i, 1] = processbroadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))en caso contrario.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C5) |
| (I2) | source_target_pairs |
Tensor constante bidimensional de tipo si64 |
(C1-C4) |
| (I3) | channel_id |
constante de tipo si64 |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, dondeNse define de la siguiente manera:num_replicassi se usacross_replica.num_partitionssi se usacross_partition.
- (C5)
type(result) = type(operand).
Ejemplos
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64,
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 0
}< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
comparar
Semántica
Realiza una comparación de cada elemento de los tensores lhs y rhs según comparison_direction y compare_type, y produce un tensor result.
Los valores de comparison_direction y compare_type tienen la siguiente semántica:
Para los tipos de elementos booleanos y de números enteros, haz lo siguiente:
EQ:lhs = rhs.NE:lhs != rhs.GE:lhs >= rhs.GT:lhs > rhs.LE:lhs <= rhs.LT:lhs < rhs.
Para los tipos de elementos de punto flotante con compare_type = FLOAT, la operación implementa las siguientes operaciones IEEE-754:
EQ:compareQuietEqual.NE:compareQuietNotEqual.GE:compareQuietGreaterEqual.GT:compareQuietGreater.LE:compareQuietLessEqual.LT:compareQuietLess.
Para los tipos de elementos de punto flotante con compare_type = TOTALORDER, la operación usa la combinación de operaciones totalOrder y compareQuietEqual de IEEE-754.
Para los tipos de elementos complejos, la comparación lexicográfica de los pares (real, imag) se realiza con los comparison_direction y compare_type proporcionados.
Imponer un orden en los números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos cuando comparison_direction sea GE, GT, LE o LT (#560).
Para tipos cuantificados, realiza dequantize_compare(lhs, rhs,
comparison_direction).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor o tensor cuantificado por tensor | (C1 a C3) |
| (I2) | rhs |
tensor o tensor cuantificado por tensor | (C1-C2) |
| (I3) | comparison_direction |
Enumeración de EQ, NE, GE, GT, LE y LT |
|
| (I4) | compare_type |
Enumeración de FLOAT, TOTALORDER, SIGNED y UNSIGNED |
(C3) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo booleano | (C2) |
Limitaciones
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs). - (C2)
shape(lhs) = shape(rhs) = shape(result). - (C3)
compare_typese define de la siguiente manera:SIGNEDsiis_signed_integer(element_type(lhs)).UNSIGNEDsiis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).FLOAToTOTALORDERsi esis_float(element_type(lhs)).FLOATsiis_complex(element_type(lhs)).
Ejemplos
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = <#stablehlocomparison_di>rection LT,
compare_type = <#stablehlocomparison_>type FLOAT
}< : (t>ensor2xf<32, t>ens>or2xf32<) - >tensor2xi1
// %result: [true, false]
emergencia compleja,
Semántica
Realiza una conversión de cada elemento en un valor complejo a partir de un par de valores reales e imaginarios, lhs y rhs, y produce un tensor result.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
Tensor de tipo f32 o f64 |
(C1 a C3) |
| (I2) | rhs |
Tensor de tipo f32 o f64 |
(C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo complejo | (C2), (C3) |
Limitaciones
- (C1)
type(lhs) = type(rhs). - (C2)
shape(result) = shape(lhs). - (C3)
element_type(result)tiene el tipocomplex<E>, dondeE = element_type(lhs).
Ejemplos
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs)< : (t>ensor2xf<64, t>ens>or2xf64<) - tenso<r2x>>complexf64
// %result: [(1.0, 2.0), (3.0, 4.0)]
compuesto
Semántica
Encapsula una operación compuesta por otras operaciones de StableHLO, que toma inputs y composite_attributes y produce results. La semántica de la operación se implementa con el atributo decomposition. La operación composite se puede reemplazar por su descomposición sin cambiar la semántica del programa. En los casos en que la incorporación de la descomposición no proporciona la misma semántica de la operación, se prefiere usar custom_call.
El campo version (el valor predeterminado es 0) se usa para indicar cuándo cambian las semánticas de un elemento compuesto.
Entradas
| Etiqueta | Nombre | Tipo |
|---|---|---|
| (I1) | inputs |
Cantidad variable de valores |
| (I2) | name |
constante de tipo string |
| (I3) | composite_attributes |
diccionario de atributos |
| (I4) | decomposition |
constante de tipo string |
| (I5) | version |
constante de tipo si32 |
Salidas
| Nombre | Tipo |
|---|---|
results |
Cantidad variable de valores |
Limitaciones
- (C1)
is_namespaced_op_name(name) - (C2)
is_defined_in_parent_scope(decomposition) - (C3)
types(inputs...) == input_types(decomposition) - (C4)
types(results...) == output_types(decomposition)
Ejemplos
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
< ve>rsion = <1 :> i3>2
} : (<ten>sorf32, tensorf32) - tensorf32
concatenate
Semántica
Concatena inputs a lo largo de la dimensión dimension en el mismo orden que los argumentos proporcionados y produce un tensor result. De manera más formal, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], donde:
id = d0 + ... + dk-1 + kd.des igual adimension, yd0… son los tamaños de la dimensiónddeinputs.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | inputs |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1 a C6) |
| (I2) | dimension |
constante de tipo si64 |
(C2), (C4), (C6) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5-C6) |
Limitaciones
- (C1)
same(element_type(inputs...)). - (C2)
same(shape(inputs...)), exceptodim(inputs..., dimension). - (C3)
0 < size(inputs). - (C4)
0 <= dimension < rank(inputs[0]). - (C5)
element_type(result) = element_type(inputs[0]). - (C6)
shape(result) = shape(inputs[0]), excepto en los siguientes casos:dim(result, dimension) = dim(inputs[0], dimension) + ....
Ejemplos
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
}< : (ten>sor3x2xi<64, ten>sor>1x2xi64<) - ten>sor4x2xi64
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Semántica
Produce un tensor output a partir de una constante value.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | value |
constante | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
output |
tensor o tensor cuantificado | (C1) |
Limitaciones
- (C1)
type(value) = type(output).
Ejemplos
%output = "stablehlo.constant"() {
val<ue = dense[[0.0, 1.0], [>2.0, 3.0]<] : ten>sor2x2xf3>2
} : (<) - ten>sor2x2xf32
// %output: [[0.0, 1.0], [2.0, 3.0]]
generar una conversión
Semántica
Realiza una conversión de un tipo de elemento a otro en el tensor operand y produce un tensor result.
En el caso de las conversiones de boolean-to-any-supported-type, el valor false se convierte en cero y el valor true se convierte en uno. En el caso de las conversiones de any-supported-type-to-boolean, un valor cero se convierte en false y los valores distintos de cero se convierten en true. Consulta a continuación cómo funciona para los tipos complejos.
En el caso de las conversiones que involucran números enteros a números enteros, números enteros a números de punto flotante o números de punto flotante a números de punto flotante, si el valor de origen se puede representar exactamente en el tipo de destino, el valor del resultado es esa representación exacta. De lo contrario, el comportamiento está por definirse (#180).
En el caso de las conversiones que involucran floating-point-to-integer, se trunca la parte fraccionaria. Si el valor truncado no se puede representar en el tipo de destino, el comportamiento está pendiente de definición (TBD) (#180).
Las conversiones que involucran números complejos a complejos siguen el mismo comportamiento que las conversiones de punto flotante a punto flotante para convertir partes reales e imaginarias.
En el caso de las conversiones de complex-to-any-other-type y any-other-type-to-complex, se ignora el valor imaginario de la fuente o se establece en cero el valor imaginario del destino, respectivamente. La conversión de la parte real sigue las conversiones de punto flotante.
En principio, esta operación podría expresar la desantificación (conversión de tensores cuantificados a tensores regulares), la cuantificación (conversión de tensores regulares a tensores cuantificados) y la recuantificación (conversión entre tensores cuantificados), pero, por el momento, tenemos operaciones dedicadas para eso: uniform_dequantize para el primer caso de uso y uniform_quantize para el segundo y el tercer caso de uso. En el futuro, es posible que estas dos operaciones se fusionen en convert (#1576).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor | (C1) |
Limitaciones
- (C1)
shape(operand) = shape(result).
Ejemplos
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand)< : (t>ens>or3xi64<) - tenso<r3x>>complexf64
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolución
Semántica
Calcula los productos punto entre las ventanas de lhs y los segmentos de rhs, y produce result. En el siguiente diagrama, se muestra cómo se calculan los elementos de result a partir de lhs y rhs con un ejemplo concreto.
De manera más formal, considera el siguiente replanteamiento de las entradas en términos de lhs para poder expresar ventanas de lhs:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).lhs_window_strides = lhs_shape(1, window_strides, 1).lhs_padding = lhs_shape([0, 0], padding, [0, 0]).lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).
Este replanteamiento usa las siguientes funciones auxiliares:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]dondej[d] = i[permutation[d]].
Si feature_group_count = 1 y batch_group_count = 1, entonces para todo output_spatial_index en index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product donde:
padding_value = constant(0, element_type(lhs)).padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Esta función parece no usarse, por lo que, en el futuro, planeamos quitarla (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).
Si feature_group_count > 1, haz lo siguiente:
lhses = split(lhs, feature_group_count, input_feature_dimension).rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Si batch_group_count > 1, haz lo siguiente:
lhses = split(lhs, batch_group_count, input_batch_dimension).rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).result = concatenate(results, output_feature_dimension).
Para los tipos cuantificados, realiza dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result)).
Para los tipos cuantificados híbridos, realiza hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor o tensor cuantificado por tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32), (C34) |
| (I2) | rhs |
tensor o tensor cuantificado | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
| (I3) | window_strides |
Tensor constante unidimensional de tipo si64 |
(C2-C3), (C25) |
| (I4) | padding |
Tensor constante bidimensional de tipo si64 |
(C4), (C25) |
| (I5) | lhs_dilation |
Tensor constante unidimensional de tipo si64 |
(C5-C6), (C25) |
| (I6) | rhs_dilation |
Tensor constante unidimensional de tipo si64 |
(C7-C8), (C25) |
| (I7) | window_reversal |
Tensor constante unidimensional de tipo i1 |
(C9) |
| (I8) | input_batch_dimension |
constante de tipo si64 |
(C10), (C13), (C25) |
| (I9) | input_feature_dimension |
constante de tipo si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Tensor constante unidimensional de tipo si64 |
(C12), (C13), (C25) |
| (I11) | kernel_input_feature_dimension |
constante de tipo si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
constante de tipo si64 |
(C15-C16), (C18), (C25), (C29) |
| (I13) | kernel_spatial_dimensions |
Tensor constante unidimensional de tipo si64 |
(C17-C18), (C25) |
| (I14) | output_batch_dimension |
constante de tipo si64 |
(C20), (C25) |
| (I15) | output_feature_dimension |
constante de tipo si64 |
(C20), (C25), (C30) |
| (I16) | output_spatial_dimensions |
Tensor constante unidimensional de tipo si64 |
(C19-C20), (C25) |
| (I17) | feature_group_count |
constante de tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
constante de tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
| (I19) | precision_config |
Cantidad variable de enumeraciones de DEFAULT, HIGH y HIGHEST |
(C24) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C25-C28), (C30), (C32-34) |
Limitaciones
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)se define de la siguiente manera:dim(lhs, input_batch_dimension) / batch_group_countsiresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)siresult_dim = output_feature_dimension.num_windowsen otros casos, donde:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Si la operación usa tensores cuantizados, se aplica lo siguiente:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Si
is_per_axis_quantized(rhs), entoncesquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Si
is_per_axis_quantized(result), entoncesquantization_dimension(result) = output_feature_dimension. - Si
is_quantized(lhs), haz lo siguiente: - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Si
is_per_tensor_quantized(rhs), entoncesis_per_tensor_quantized(result). - Si
!is_quantized(lhs), haz lo siguiente: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Ejemplos
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strid<es = arra>yi64: 4, 4,
paddi<n>g = dense<0 : ten>sor2x2xi64,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
// In the StableHLO dialect, dimension numbers are encoded vi<a:
// `[input >dim<ensions]x[kernel >di>mensions]-[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" a<re spatial dimensions.
d>imension_num>bers = #stablehlo.conv[b, 0, 1, f]x[0, 1, i, o]-[b, 0, 1, f],
batch_group_count = 1 : i64,
fea<ture_group_count >= 1 : i64,
< precision_config> = [#stablehl<oprecision >DEFAULT,< #stablehlo>pre>cision <DEFAULT]
} >: (tensor1x4x4x1xi64, tensor3x3x1x1xi64) - tensor1x2x2x1xi64
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
coseno
Semántica
Realiza una operación de coseno a nivel de los elementos en el tensor operand y genera un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
cosde IEEE-754. - Para números complejos, coseno complejo.
- Para tipos cuantificados:
dequantize_op_quantize(cosine, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semántica
Realiza un recuento por elemento de la cantidad de bits cero iniciales en el tensor operand y produce un tensor result.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo entero | (C1) |
Limitaciones
- (C1)
type(operand) = type(result).
Ejemplos
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand)< : (ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[64, 63], [56, 0]]
custom_call
Semántica
Encapsula una operación call_target_name definida por la implementación que toma inputs y called_computations, y produce results. Se pueden usar has_side_effect, backend_config y api_version para proporcionar metadatos adicionales definidos por la implementación.
Por el momento, esta operación contiene una colección de metadatos bastante desorganizada que refleja la evolución orgánica de su operación equivalente en el compilador de XLA. En el futuro, planeamos unificar estos metadatos (#741).
Entradas
| Etiqueta | Nombre | Tipo |
|---|---|---|
| (I1) | inputs |
Cantidad variable de valores |
| (I2) | call_target_name |
constante de tipo string |
| (I3) | has_side_effect |
constante de tipo i1 |
| (I4) | backend_config |
Constante de tipo string o diccionario de atributos |
| (I5) | api_version |
constante de tipo si32 |
| (I6) | called_computations |
Cantidad variable de constantes del tipo string |
| (I7) | output_operand_aliases |
especificar las partes de alias en las salidas y los operandos |
Salidas
| Nombre | Tipo |
|---|---|
results |
Cantidad variable de valores |
(Compatibilidad con GPU de XLA) Destinos de custom_call especiales
Existen tres call_target_name especiales relacionados con los tipos buffer:
CreateBuffer crea un buffer sin inicializar, Pin crea un buffer inicializado y Unpin desasigna un buffer y devuelve el contenido del buffer.
%uninitialized_buffer = "stablehlo.custom_call"() {
call_target_name = "CreateBuffer",
api_version> = 4 : <i32,
>} : () - memref4xf64
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
call_target_name = "Pin&quo<t;,
> ap>i_versi<on = >4 : i32,
} : (tensor4xf64) - memref4xf64
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
call_target_na<me = >&qu>ot;Unpi<n&quo>t;,
api_version = 4 : i32,
} : (memref4xf64) - tensor4xf64
Alias
Algunas operaciones custom_call pueden requerir que una parte de las salidas y una parte de los operandos compartan la misma memoria. Esto se puede expresar a través de output_operand_aliases. Una representación de par de alias consta de una lista de índices de tuplas de salida que representan la parte de salida y un operand_index junto con una lista de índices de tuplas de operandos que representan la parte del operando. La lista de índices de tuplas de salida o de operandos está vacía si el tipo correspondiente no es un tipo tuple y puede ser arbitrariamente larga para un tipo de tupla anidada arbitrariamente. Esto es similar a la representación del alias de XLA.
La parte de salida y la parte de entrada de un par de alias deben tener el mismo tipo. En el caso de las operaciones de custom_call que no son llamadas a CreateBuffer, Pin y Unpin, un operando buffer puede aparecer en, como máximo, un par de alias, y una salida buffer debe aparecer en un par de alias.
Ejemplos
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations <= [>@fo>o]
} : <(te>nsorf64) - tensorf64
%updated_buffer = "stablehlo.custom_call"(%buffer) {
call_target_name = "Update",
api_version = 4 : i32,
output_operand_aliases< = [
#stablehlo.output_operand_aliasoutput_tuple_indices = [],
operand_ind>ex = 0,
< oper>and>_tuple_<indic>es = []]
} : (memref4xf64) - memref4xf64
dividir
Semántica
Realiza la división de cada elemento de los tensores de dividendo lhs y divisor rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: división de números enteros que produce el cociente algebraico con cualquier parte fraccionaria descartada.
- Para números de punto flotante:
divisionde IEEE-754. - Para números complejos: división compleja.
- Para los tipos cuantificados:
dequantize_op_quantize(divide, lhs, rhs, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
Tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Ejemplos
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs)< : (t>ensor4xf<32, t>ens>or4xf32<) - t>ensor4xf32
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semántica
Calcula los productos punto entre las segmentaciones de lhs y las de rhs, y produce un tensor result.
De manera más formal, result[result_index] = dot_product, donde:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].result_batching_index + result_lhs_index + result_rhs_index = result_index, dondesize(result_batching_index) = size(lhs_batching_dimensions),size(result_lhs_index) = size(lhs_result_dimensions)ysize(result_rhs_index) = size(rhs_result_dimensions).transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).
Para los tipos cuantificados, realiza dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).
Para los tipos cuantificados híbridos, realiza hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs).
precision_config controla la compensación entre velocidad y precisión para los cálculos en los backends del acelerador. Puede ser uno de los siguientes (por el momento, la semántica de estos valores de enumeración no está especificada, pero planeamos abordar este problema en #755):
DEFAULT: Es el cálculo más rápido, pero la aproximación menos precisa al número original.HIGH: Cálculo más lento, pero aproximación más precisa al número original.HIGHEST: Es el cálculo más lento, pero la aproximación más precisa al número original.
Un DotAlgorithm define las propiedades principales del algoritmo que se usa para implementar la operación de punto, que también define la precisión. Si se configuran los campos de atributos del algoritmo, precision_config debe ser DEFAULT. DotAlgorithms
no tienen un valor predeterminado, ya que los parámetros predeterminados se definen en la implementación. Por lo tanto, todos los campos del algoritmo de puntos pueden establecerse en None para especificar un algoritmo de puntos vacío, que, en cambio, usará el valor precision_config.
Los campos de DotAlgorithm incluyen lo siguiente:
lhs_precision_typeyrhs_precision_type, las precisiones a las que se redondean el LHD y el RHD de la operación. Los tipos de precisión son independientes de los tipos de almacenamiento de las entradas y la salida.accumulation_typela precisión que se usa para la acumulación.lhs_component_count,rhs_component_countynum_primitive_operationsse aplican cuando usamos un algoritmo que descompone el LHD o el LHD en varios componentes y realiza varias operaciones de punto "primitivas" en esos valores, por lo general, para emular una mayor precisión (p. ej., Leveraging the bfloat16 Artificial Intelligence Datatype For Higher-Precision Computations: bf16_6x tf32_3x, etc.). En el caso de los algoritmos sin descomposición, estos valores deben establecerse en1.allow_imprecise_accumulationpara especificar si se permite la acumulación con una precisión más baja para algunos pasos (p.ej.,CUBLASLT_MATMUL_DESC_FAST_ACCUM).
Ejemplo de atributos DotAlgorithm:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Las implementaciones deciden qué combinaciones se admiten. En general, no se garantiza que cada algoritmo sea compatible con cada tipo de acelerador por el consumidor de StableHLO. Si no se admite un algoritmo determinado, se debe generar un error en lugar de recurrir a una alternativa. La verificación de StableHLO proporcionará la mejor verificación posible, lo que evitará algoritmos que no se sabe que sean compatibles con ningún hardware.
Consulta xla_data.proto > Algorithm para ver algunos valores de algoritmos admitidos. El ticket núm. 2483 registra el plan para crear un documento centralizado sobre los algoritmos admitidos por el backend.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor o tensor cuantificado por tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
| (I2) | rhs |
tensor o tensor cuantificado | (C7-C10), (C12-C20) |
| (I3) | lhs_batching_dimensions |
Tensor constante unidimensional de tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
| (I4) | rhs_batching_dimensions |
Tensor constante unidimensional de tipo si64 |
(C1), (C4), (C7) y (C9) |
| (I5) | lhs_contracting_dimensions |
Tensor constante unidimensional de tipo si64 |
(C2), (C3), (C6) y (C10) |
| (I6) | rhs_contracting_dimensions |
Tensor constante unidimensional de tipo si64 |
(C2), (C4), (C8), (C10), (C16) |
| (I7) | precision_config |
Cantidad variable de enumeraciones de DEFAULT, HIGH y HIGHEST |
(C11), (C21) |
| (I8) | lhs_precision_type |
FloatType o TensorFloat32 | (C21) |
| (I9) | rhs_precision_type |
FloatType o TensorFloat32 | (C21) |
| (I10) | accumulation_type |
FloatType o TensorFloat32 | (C21) |
| (I11) | lhs_component_count |
constante de tipo si32 |
(C21) y (C22) |
| (I12) | rhs_component_count |
constante de tipo si32 |
(C21), (C23) |
| (I13) | num_primitive_operations |
constante de tipo si32 |
(C21), (C24) |
| (I14) | allow_imprecise_accumulation |
constante de tipo bool |
(C21) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C12), (C14), (C18-C20) |
Limitaciones
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions). - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions). - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions). - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions). - (C5)
0 <= lhs_batching_dimensions < rank(lhs). - (C6)
0 <= lhs_contracting_dimensions < rank(lhs). - (C7)
0 <= rhs_batching_dimensions < rank(rhs). - (C8)
0 <= rhs_contracting_dimensions < rank(rhs). - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...). - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...). - (C11)
size(precision_config) = 2. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions). - Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C13)
element_type(lhs) = element_type(rhs).
- (C13)
- Si la operación usa tensores cuantizados, se aplica lo siguiente:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C15)
zero_points(rhs) = 0. - (C16) Si
is_per_axis_quantized(rhs), entoncesquantization_dimension(rhs)no está enrhs_contracting_dimensions. - Si
is_quantized(lhs), haz lo siguiente: - (C17)
storage_type(lhs) = storage_type(rhs). - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C19) Si
is_per_tensor_quantized(rhs), entoncesis_per_tensor_quantized(result). - Si
!is_quantized(lhs), haz lo siguiente: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C14)
- Si
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation), haz lo siguiente:- (C21)
precision_config... = DEFAULT. - (C22)
0 < lhs_component_count. - (C23)
0 < rhs_component_count. - (C24)
0 < num_primitive_operations.
- (C21)
Ejemplos
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #sta<blehlo.dot
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimension>s = [1]
,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT],
algorithm = #stablehlo.dot<_algorithm
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation >= false
}< : (tenso>r2x2x2xi<64, tenso>r2x>2x2xi64<) - tenso>r2x2x2xi64
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semántica
Esta operación es funcionalmente idéntica a la operación broadcast_in_dim, pero la forma del resultado se especifica de forma dinámica a través de output_dimensions.
La operación también acepta atributos opcionales known_expanding_dimensions y known_nonexpanding_dimensions para expresar el conocimiento estático sobre el comportamiento de expansión de las dimensiones.
Si no se especifica, se supone que todas las dimensiones se pueden expandir.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado | (C1-C2), (C5-C6), (C9) |
| (I2) | output_dimensions |
Tensor unidimensional de tipo entero | (C7) |
| (I3) | broadcast_dimensions |
Tensor constante unidimensional de tipo entero | (C2-C6) |
| (I4) | known_expanding_dimensions |
Tensor constante unidimensional de tipo entero | (C8-C9) |
| (I5) | known_nonexpanding_dimensions |
Tensor constante unidimensional de tipo entero | (C8-C9) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3), (C5-C7) |
Limitaciones
- (C1)
element_type(result)se calcula de la siguiente manera:element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), excepto quequantization_dimension(operand),scales(operand)yzero_points(operand)pueden diferir dequantization_dimension(result),scales(result)yzero_points(result), respectivamente, en otros casos.
- (C2)
size(broadcast_dimensions) = rank(operand). - (C3)
0 <= broadcast_dimensions < rank(result). - (C4)
is_unique(broadcast_dimensions). - (C5) Para todo
denaxes(operand):dim(operand, d) = 1odim(operand, d) = dim(result, broadcast_dimensions[d]).
- (C6) Si
is_per_axis_quantized(result):quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].- Si
dim(operand, quantization_dimension(operand)) = 1, entoncesscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).
- (C7)
size(output_dimensions) = rank(result). - (C8)
is_unique(known_expanding_dimensions + known_nonexpanding_dimensions). - (C9)
0 <= known_expanding_dimensions < rank(operand). - (C10)
0 <= known_nonexpanding_dimensions < rank(operand).
Ejemplos
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensio<ns = arra>yi64: 2, 1,
known_expanding_dimensio<ns = a>rrayi64: 0,
known_nonexpanding_dimensio<ns = a>rrayi64: 1
}< : (ten>sor1x3xi<64, t>ens>or3xi64<) - tenso>r2x3x2xi64
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semántica
Esta operación es funcionalmente idéntica a la operación de convolución, pero el padding se especifica de forma dinámica a través de padding.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor o tensor cuantificado por tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31), (C33) |
| (I2) | rhs |
tensor o tensor cuantificado | (C1), (C14-C16), (C26-C28), (C30-C33) |
| (I3) | padding |
Tensor bidimensional de tipo entero | (C4) |
| (I4) | window_strides |
Tensor constante unidimensional de tipo si64 |
(C2-C3) |
| (I5) | lhs_dilation |
Tensor constante unidimensional de tipo si64 |
(C5-C6) |
| (I6) | rhs_dilation |
Tensor constante unidimensional de tipo si64 |
(C7-C8) |
| (I7) | window_reversal |
Tensor constante unidimensional de tipo i1 |
(C9) |
| (I8) | input_batch_dimension |
constante de tipo si64 |
(C10), (C13) |
| (I9) | input_feature_dimension |
constante de tipo si64 |
(C11), (C13-C14) |
| (I10) | input_spatial_dimensions |
Tensor constante unidimensional de tipo si64 |
(C12), (C13) |
| (I11) | kernel_input_feature_dimension |
constante de tipo si64 |
(C14), (C18) |
| (I12) | kernel_output_feature_dimension |
constante de tipo si64 |
(C15-C16), (C18), (C28) |
| (I13) | kernel_spatial_dimensions |
Tensor constante unidimensional de tipo si64 |
(C17-C18) |
| (I14) | output_batch_dimension |
constante de tipo si64 |
(C20) |
| (I15) | output_feature_dimension |
constante de tipo si64 |
(C20), (C29) |
| (I16) | output_spatial_dimensions |
Tensor constante unidimensional de tipo si64 |
(C19-C20) |
| (I17) | feature_group_count |
constante de tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
| (I18) | batch_group_count |
constante de tipo si64 |
(C10), (C15), (C22) y (C23) |
| (I19) | precision_config |
Cantidad variable de enumeraciones de DEFAULT, HIGH y HIGHEST |
(C24) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C25 a C27), (C29), (C31 a C33) |
Limitaciones
- (C1)
N = rank(lhs) = rank(rhs). - (C2)
size(window_strides) = N - 2. - (C3)
0 < window_strides. - (C4)
shape(padding) = [N - 2, 2]. - (C5)
size(lhs_dilation) = N - 2. - (C6)
0 < lhs_dilation. - (C7)
size(rhs_dilation) = N - 2. - (C8)
0 < rhs_dilation. - (C9)
size(window_reversal) = N - 2. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0. - (C12)
size(input_spatial_dimensions) = N - 2. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:is_unique(input_dimensions).0 <= input_dimensions < N.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0. - (C17)
size(kernel_spatial_dimensions) = N - 2. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:is_unique(kernel_dimensions).0 <= kernel_dimensions < N.
- (C19)
size(output_spatial_dimensions) = N - 2. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:is_unique(output_dimensions).0 <= output_dimensions < N.
- (C21)
0 < feature_group_count. - (C22)
0 < batch_group_count. - (C23)
feature_group_count = 1 or batch_group_count = 1. - (C24)
size(precision_config) = 2. - (C25)
dim(result, result_dim)se define de la siguiente manera:dim(lhs, input_batch_dimension) / batch_group_countsiresult_dim = output_batch_dimension.dim(rhs, kernel_output_feature_dimension)siresult_dim = output_feature_dimension.num_windowsen otros casos, donde:output_spatial_dimensions[spatial_dim] = result_dim.lhs_dim = input_spatial_dimensions[spatial_dim].rhs_dim = kernel_spatial_dimensions[spatial_dim].dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
- (C26)
rank(result) = N. - Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result).
- (C27)
- Si la operación usa tensores cuantizados, se aplica lo siguiente:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs). - (C29) Si
is_per_axis_quantized(rhs), entoncesquantization_dimension(rhs) = kernel_output_feature_dimension. - (C30) Si
is_per_axis_quantized(result), entoncesquantization_dimension(result) = output_feature_dimension. - Si
is_quantized(lhs), haz lo siguiente: - (C31)
storage_type(lhs) = storage_type(rhs). - (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result). - (C33) Si
is_per_tensor_quantized(rhs), entoncesis_per_tensor_quantized(result). - Si
!is_quantized(lhs), haz lo siguiente: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result).
- (C28)
Ejemplos
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strid<es = arra>yi64: 4, 4,
lhs_dilati<on = arra>yi64: 2, 2,
rhs_dilati<on = arra>yi64: 1, 1,
window_revers<al = arrayi1: fa>lse, false,
dimension_numbers = #stab<lehlo.convraw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions => [1, 2]
,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [<#stablehloprecisi>on DEFAULT, <#stablehloprecisi>on DEFAULT]
}< : (tensor1>x4x4x1xi<64, tensor3>x3x1x1xi<64, ten>sor>2x2xi64<) - tensor1>x2x2x1xi64
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semántica
Esta operación es funcionalmente idéntica a la operación gather, con el slice_sizes especificado de forma dinámica como un valor.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C7), (C10-C12), (C14) |
| (I2) | start_indices |
Tensor de tipo entero | (C2), (C3), (C13) |
| (I3) | slice_sizes |
Tensor unidimensional de tipo entero | (C8), (C11 a C13) |
| (I4) | offset_dims |
Tensor constante unidimensional de tipo si64 |
(C1), (C4-C5), (C13) |
| (I5) | collapsed_slice_dims |
Tensor constante unidimensional de tipo si64 |
(C1), (C6 a C8), (C13) |
| (I6) | start_index_map |
Tensor constante unidimensional de tipo si64 |
(C3), (C9) y (C10) |
| (I7) | index_vector_dim |
constante de tipo si64 |
(C2), (C3), (C13) |
| (I8) | indices_are_sorted |
constante de tipo i1 |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5), (C13-C14) |
Limitaciones
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims). - (C7)
0 <= collapsed_slice_dims < rank(operand). - (C8)
slice_sizes[collapsed_slice_dims...] <= 1. - (C9)
is_unique(start_index_map). - (C10)
0 <= start_index_map < rank(operand). - (C11)
size(slice_sizes) = rank(operand). - (C12)
0 <= slice_sizes <= shape(operand). - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)donde:batch_dim_sizes = shape(start_indices), excepto que no se incluye el tamaño de la dimensión destart_indicescorrespondiente aindex_vector_dim.offset_dim_sizes = shape(slice_sizes), excepto que no se incluyen los tamaños de dimensión enslice_sizescorrespondientes acollapsed_slice_dims.combinecolocabatch_dim_sizesen los ejes correspondientes abatch_dimsyoffset_dim_sizesen los ejes correspondientes aoffset_dims.
- (C14)
element_type(operand) = element_type(result).
Ejemplos
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stable<hlo.gather
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vect>or_dim = 2,
indices_are_sorted = false
}< : (tenso>r3x4x2xi<64, tenso>r2x3x2xi<64, t>ens>or3xi64<) - tensor2>x3x2x2xi64
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semántica
Esta operación es funcionalmente idéntica a la operación iota, pero la forma del resultado se especifica de forma dinámica a través de output_shape.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | output_shape |
Tensor unidimensional de tipo entero | (C1), (C2) |
| (I2) | iota_dimension |
si64 |
(C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C2) |
Limitaciones
- (C1)
0 <= iota_dimension < size(output_shape). - (C2)
rank(result) = size(output_shape).
Ejemplos
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
}< : (t>ens>or2xi64<) - ten>sor4x5xi64
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semántica
Esta operación es funcionalmente idéntica a la operación pad, pero con edge_padding_low, edge_padding_high y interior_padding especificados de forma dinámica como valores.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensor de 0 dimensiones o tensor cuantificado por tensor | (C1) |
| (I3) | edge_padding_low |
Tensor unidimensional de tipo entero | (C1), (C4) |
| (I4) | edge_padding_high |
Tensor unidimensional de tipo entero | (C1), (C4) |
| (I5) | interior_padding |
Tensor unidimensional de tipo entero | (C2-C4) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3 a C6) |
Limitaciones
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Ejemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
)< : (ten>sor2x3xi<64,> tensori<64, t>ensor2xi<64, t>ensor2xi<64, t>ens>or2xi64<) - ten>sor5x9xi64
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semántica
Esta operación es funcionalmente idéntica a la operación reshape, pero la forma del resultado se especifica de forma dinámica a través de output_shape.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado | (C1 a C3) |
| (I2) | output_shape |
Tensor unidimensional de tipo entero | (C4) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C1-C4) |
Limitaciones
- (C1)
element_type(result)se calcula de la siguiente manera:element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), excepto quequantization_dimension(operand)yquantization_dimension(result)pueden diferir.
- (C2)
size(operand) = size(result). - (C3) Si
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
- (C4)
size(output_shape) = rank(result).
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape)< : (ten>sor2x3xi<64, t>ens>or2xi64<) - ten>sor3x2xi64
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semántica
Extrae un segmento de operand usando índices iniciales calculados de forma dinámica y produce un tensor result. start_indices contiene los índices iniciales del segmento para cada dimensión sujeta a un posible ajuste, y slice_sizes contiene los tamaños del segmento para cada dimensión. De manera más formal, result[result_index] = operand[operand_index], en el que:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).operand_index = adjusted_start_indices + result_index.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
| (I2) | start_indices |
Cantidad variable de tensores de 0 dimensiones de tipo entero | (C2), (C3) |
| (I3) | slice_sizes |
Tensor constante unidimensional de tipo si64 |
(C2), (C4) y (C5) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) y (C5) |
Limitaciones
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(slice_sizes) = rank(operand). - (C3)
same(type(start_indices...)). - (C4)
0 <= slice_sizes <= shape(operand). - (C5)
shape(result) = slice_sizes.
Ejemplos
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_siz<es = arra>yi64: 2, 2
}< : (ten>sor4x4xi<32,> tensori<64,> te>nsori64<) - ten>sor2x2xi32
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semántica
Produce un tensor result que es igual al tensor operand, excepto que la porción que comienza en start_indices se actualiza con los valores de update.
De manera más formal, result[result_index] se define de la siguiente manera:
update[update_index]si0 <= update_index < shape(update), donde:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).update_index = result_index - adjusted_start_indices.
- En caso contrario,
operand[result_index].
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1-C4), (C6) |
| (I2) | update |
tensor o tensor cuantificado por tensor | (C2), (C3), (C6) |
| (I3) | start_indices |
Cantidad variable de tensores de 0 dimensiones de tipo entero | (C4), (C5) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
type(operand) = type(result). - (C2)
element_type(update) = element_type(operand). - (C3)
rank(update) = rank(operand). - (C4)
size(start_indices) = rank(operand). - (C5)
same(type(start_indices...)). - (C6)
0 <= shape(update) <= shape(operand).
Ejemplos
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
< : (ten>sor4x4xi<32, ten>sor2x2xi<32,> tensori<64,> te>nsori64<) - ten>sor4x4xi32
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponencial
Semántica
Realiza una operación exponencial a nivel de los elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
expde IEEE-754. - Para números complejos: exponencial compleja.
- Para los tipos cuantificados, se usa
dequantize_op_quantize(exponential, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semántica
Realiza una operación exponencial menos uno a nivel de los elementos en el tensor operand y genera un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
expm1de IEEE-754. - Para números complejos: exponencial compleja menos uno.
- Para los tipos cuantificados, se usa
dequantize_op_quantize(exponential_minus_one, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand)< : (t>ens>or2xf64<) - t>ensor2xf64
// %result: [0.0, 1.71828187]
fft
Semántica
Realiza las transformaciones de Fourier directa e inversa para entradas y salidas reales y complejas.
fft_type es una de las siguientes opciones:
FFT: Es la FFT de complejo a complejo directa.IFFT: FFT inversa de complejo a complejo.RFFT: FFT de real a complejo directa.IRFFT: FFT inversa de real a complejo (es decir, toma números complejos y devuelve números reales).
De manera más formal, dada la función fft que toma tensores unidimensionales de tipos complejos como entrada, produce tensores unidimensionales de los mismos tipos como salida y calcula la transformada discreta de Fourier:
En fft_type = FFT, result se define como el resultado final de una serie de L cálculos en los que L = size(fft_length). Por ejemplo, para L = 3:
result1[i0, ..., :] = fft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Además, dada la función ifft, que tiene la misma firma de tipo y calcula la inversa de fft:
En fft_type = IFFT, result se define como la inversa de los cálculos para fft_type = FFT. Por ejemplo, para L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = ifft(result2[i0, ..., :]).
Además, dada la función rfft, que toma tensores unidimensionales de tipos de punto flotante, produce tensores unidimensionales de tipos complejos con la misma semántica de punto flotante y funciona de la siguiente manera:
rfft(real_operand) = truncated_resultdondecomplex_operand... = (real_operand..., 0.0).complex_result = fft(complex_operand).truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].
(Cuando se calcula la transformada discreta de Fourier para operandos reales, los primeros N/2 + 1 elementos del resultado definen de forma inequívoca el resto del resultado, por lo que el resultado de rfft se trunca para evitar el cálculo de elementos redundantes).
En fft_type = RFFT, result se define como el resultado final de una serie de L cálculos en los que L = size(fft_length). Por ejemplo, para L = 3:
result1[i0, ..., :] = rfft(operand[i0, ..., :]).result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).
Por último, dada la función irfft, que tiene la misma firma de tipo y calcula la inversa de rfft:
En fft_type = IRFFT, result se define como la inversa de los cálculos para fft_type = RFFT. Por ejemplo, para L = 3:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).result[i0, ..., :] = irfft(result2[i0, ..., :]).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante | (C1), (C2), (C4), (C5) |
| (I2) | fft_type |
Enumeración de FFT, IFFT, RFFT y IRFFT |
(C2), (C5) |
| (I3) | fft_length |
Tensor constante unidimensional de tipo si64 |
(C1), (C3), (C4) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante | (C2), (C4) y (C5) |
Limitaciones
- (C1)
size(fft_length) <= rank(operand). - (C2) La relación entre los tipos de elementos
operandyresultvaría de la siguiente manera:- Si
fft_type = FFT,element_type(operand)yelement_type(result)tienen el mismo tipo complejo - Si
fft_type = IFFT,element_type(operand)yelement_type(result)tienen el mismo tipo complejo - Si
fft_type = RFFT,element_type(operand)es un tipo de punto flotante yelement_type(result)es un tipo complejo con la misma semántica de punto flotante. - Si es
fft_type = IRFFT,element_type(operand)es un tipo complejo yelement_type(result)es un tipo de punto flotante con la misma semántica de punto flotante.
- Si
- (C3)
1 <= size(fft_length) <= 3. - (C4) Si entre
operandyresult, hay un tensorrealde un tipo de punto flotante, entoncesshape(real)[-size(fft_length):] = fft_length. - (C5)
shape(result) = shape(operand), excepto en los siguientes casos:- Si
fft_type = RFFT,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1. - Si
fft_type = IRFFT,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.
- Si
Ejemplos
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = <#stablehloff>t_type FFT,
fft_leng<th = a>rrayi64: 4
}< : (tenso<r4x>>com>plexf32<) - tenso<r4x>>complexf32
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
piso
Semántica
Realiza la función de piso para cada elemento del tensor operand y produce un tensor result.
Implementa la operación roundToIntegralTowardNegative de la especificación IEEE-754. Para los tipos cuantificados, realiza dequantize_op_quantize(floor, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand)< : (t>ens>or5xf32<) - t>ensor5xf32
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
recopila
Semántica
Recopila segmentos del tensor operand a partir de las compensaciones especificadas en start_indices y produce un tensor result.
En el siguiente diagrama, se muestra cómo se asignan los elementos de result a los elementos de operand con un ejemplo concreto. El diagrama elige algunos índices de result como ejemplo y explica en detalle a qué índices de operand corresponden.
De manera más formal, result[result_index] = operand[operand_index], en la que:
batch_dims = [d for d in axes(result) and d not in offset_dims].batch_index = result_index[batch_dims...].start_indexse define de la siguiente manera:start_indices[bi0, ..., :, ..., biN], dondebison elementos individuales enbatch_indexy:se inserta en el índiceindex_vector_dim, siindex_vector_dim<rank(start_indices).- En caso contrario,
[start_indices[batch_index]].
- Para
d_operandenaxes(operand),full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])sid_operand = start_index_map[d_start].- En caso contrario,
full_start_index[d_operand] = 0.
- Para
d_operandenaxes(operand),full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]sid_operand = operand_batching_dims[i_batching]yd_start = start_indices_batching_dims[i_batching].- En caso contrario,
full_batching_index[d_operand] = 0.
offset_index = result_index[offset_dims...].full_offset_index = [oi0, ..., 0, ..., oiN], dondeoison elementos individuales enoffset_indexy0se inserta en los índices decollapsed_slice_dimsyoperand_batching_dims.operand_index = full_start_index + full_batching_index + full_offset_index.
Si indices_are_sorted es true, la implementación puede suponer que start_indices se ordena con respecto a start_index_map; de lo contrario, el comportamiento es indefinido. De manera más formal, para todo i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
| (I2) | start_indices |
Tensor de tipo entero | (C2-C3), (C14), (C17), (C22) |
| (I3) | offset_dims |
Tensor constante unidimensional de tipo si64 |
(C1), (C4-C5), (C22) |
| (I4) | collapsed_slice_dims |
Tensor constante unidimensional de tipo si64 |
(C1), (C6-C9), (C22) |
| (I5) | operand_batching_dims |
Tensor constante unidimensional de tipo si64 |
(C1), (C6), (C10 a C12), (C16 a C18), (C22) |
| (I6) | start_indices_batching_dims |
Tensor constante unidimensional de tipo si64 |
(C13 a C17) |
| (I7) | start_index_map |
Tensor constante unidimensional de tipo si64 |
(C3), (C18-C19) |
| (I8) | index_vector_dim |
constante de tipo si64 |
(C2-C3), (C15), (C22) |
| (I9) | slice_sizes |
Tensor constante unidimensional de tipo si64 |
(C9), (C12), (C20 a C22) |
| (I10) | indices_are_sorted |
constante de tipo i1 |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5), (C22-C23) |
Limitaciones
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims). - (C2)
0 <= index_vector_dim <= rank(start_indices). - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims). - (C5)
0 <= offset_dims < rank(result). - (C6)
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims)) - (C7)
is_sorted(collapsed_slice_dims). - (C8)
0 <= collapsed_slice_dims < rank(operand). - (C9)
slice_sizes[collapsed_slice_dims...] <= 1. - (C10)
is_sorted(operand_batching_dims). - (C11)
0 <= operand_batching_dims < rank(operand). - (C12)
slice_sizes[operand_batching_dims...] <= 1. - (C13)
is_unique(start_indices_batching_dims). - (C14)
0 <= start_indices_batching_dims < rank(start_indices). - (C15)
index_vector_dim not in start_indices_batching_dims. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims). - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...). - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims)). - (C19)
0 <= start_index_map < rank(operand). - (C20)
size(slice_sizes) = rank(operand). - (C21)
0 <= slice_sizes <= shape(operand). - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes), donde:batch_dim_sizes = shape(start_indices), excepto que no se incluye el tamaño de la dimensión destart_indicescorrespondiente aindex_vector_dim.offset_dim_sizes = slice_sizes, excepto que no se incluyen los tamaños de dimensión enslice_sizescorrespondientes acollapsed_slice_dimsyoperand_batching_dims.combinecolocabatch_dim_sizesen los ejes correspondientes abatch_dimsyoffset_dim_sizesen los ejes correspondientes aoffset_dims.
- (C23)
element_type(operand) = element_type(result).
Ejemplos
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stable<hlo.gather
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vect>or_dim = 3,
slice_siz<es = arrayi64: >1, 1, 2, 2,
indices_are_sorted = false
}< : (tensor2>x3x4x2xi<32, tensor2>x2x>3x2xi64<) - tensor2x2>x3x2x2xi32
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semántica
Produce el tamaño del dimension determinado del operand. De manera más formal, result = dim(operand, dimension). La semántica solo se relaciona con el componente de forma del tipo. El tipo de elemento puede ser cualquiera.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado | (C1) |
| (I2) | dimension |
constante de tipo si64 |
(C1) |
Salidas
| Nombre | Tipo |
|---|---|
result |
Tensor de 0 dimensiones del tipo si32 |
Limitaciones
- (C1)
0 <= dimension < rank(operand).
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
}< : (ten>sor>2x3xi64<) -> tensori32
// %result: 3
get_tuple_element
Semántica
Extrae el elemento en la posición index de la tupla operand y produce un result. De manera más formal, result = operand[index].
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tuple | (C1), (C2) |
| (I2) | index |
constante de tipo si32 |
(C1), (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
cualquier valor | (C2) |
Limitaciones
- (C1)
0 <= index < size(operand). - (C2)
type(result) = tuple_element_types(operand)[index].
Ejemplos
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(<%operand) {index >= 0 : i32<} : (t<uplet>ensor2x<f64, t<upl>>>ete>nsori64<) - t>ensor2xf64
// %result: [1.0, 2.0]
si
Semántica
Genera el resultado de ejecutar exactamente una función de true_branch o false_branch según el valor de pred. De manera más formal, result =
pred ? true_branch() : false_branch().
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | pred |
Tensor de 0 dimensiones del tipo i1 |
|
| (I2) | true_branch |
función | (C1 a C3) |
| (I3) | false_branch |
función | (C1), (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores, tensores cuantificados o tokens | (C3) |
Limitaciones
- (C1)
input_types(true_branch) = input_types(false_branch) = []. - (C2)
output_types(true_branch) = output_types(false_branch). - (C3)
type(results...) = output_types(true_branch).
Ejemplos
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_tr<ue_>bra>nch) : (tensori32) - ()
}, {
"stablehlo.return"(%<res>ult>_false_branch) :< (>ten>sori32)< - >()
}) : (tensori1) - tensori32
// %result: 10
imag
Semántica
Extrae la parte imaginaria, elemento por elemento, de operand y produce un tensor result. De manera más formal, para cada elemento x, se cumple que imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante | (C1), (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de punto flotante | (C1), (C2) |
Limitaciones
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)se define de la siguiente manera:complex_element_type(element_type(operand))siis_complex(operand).- En caso contrario,
element_type(operand).
Ejemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [2.0, 4.0]
infeed
Semántica
Lee datos del feed integrado y produce results.
La semántica de infeed_config se define según la implementación.
results consta de valores de carga útil que aparecen primero y un token que aparece al final. En el futuro, planeamos dividir la carga útil y el token en dos salidas separadas para mejorar la claridad (#670).
Entradas
| Etiqueta | Nombre | Tipo |
|---|---|---|
| (I1) | token |
token |
| (I2) | infeed_config |
constante de tipo string |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores, tensores cuantificados o tokens | (C1 a C3) |
Limitaciones
- (C1)
0 < size(results). - (C2)
is_empty(result[:-1])ois_tensor(type(results[:-1])). - (C3)
is_token(type(results[-1])).
Ejemplos
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : >(!stable<hlo.tok>en) - (tensor2x2xi64, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config> = "<;">
} : (!stablehlo.token) - (tensor2x2xi64, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semántica
Completa un tensor output con valores en orden creciente a partir de cero a lo largo de la dimensión iota_dimension. De manera más formal,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | iota_dimension |
si64 |
(C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
output |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
0 <= iota_dimension < rank(output).
Ejemplos
%output = "stablehlo.iota"() {
iota_dimension = 0 : i6>4
} : (<) - ten>sor4x5xi32
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimensio>n = 1 :< i64
} >: () - tensor4x5xi32
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semántica
Realiza una verificación de cada elemento para determinar si el valor en x es finito (es decir, no es +Inf, -Inf ni NaN) y produce un tensor y. Implementa la operación isFinite de la especificación IEEE-754. En el caso de los tipos cuantificados, el resultado siempre es true.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | x |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
y |
Tensor de tipo booleano | (C1) |
Limitaciones
- (C1)
shape(x) = shape(y).
Ejemplos
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x)< : (tens>or7xf64<) - >tensor7xi1
// %y: [false, false, false, true, true, true, true]
log
Semántica
Realiza una operación de logaritmo para cada elemento del tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
logde IEEE-754. - Para números complejos: logaritmo complejo.
- Para tipos cuantificados:
dequantize_op_quantize(log, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semántica
Realiza una operación de logaritmo más uno para cada elemento en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
logp1de IEEE-754. - Para números complejos:
complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1)) - Para los tipos cuantificados, se usa
dequantize_op_quantize(log_plus_one, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logística
Semántica
Realiza la operación logística a nivel de los elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
division(1, addition(1, exp(-x)))de IEEE-754. - Para números complejos: logística compleja.
- Para los tipos cuantificados, se usa
dequantize_op_quantize(logistic, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mapa
Semántica
Aplica una función de mapa computation a inputs a lo largo de dimensions y produce un tensor result.
De manera más formal, result[result_index] = computation(inputs...[result_index]).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | inputs |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1-C4) |
| (I2) | dimensions |
Tensor constante unidimensional de tipo si64 |
(C3) |
| (I3) | computation |
función | (C4) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1), (C4) |
Limitaciones
- (C1)
shape(inputs...) = shape(result). - (C2)
0 < size(inputs) = N. - (C3)
dimensions = range(rank(inputs[0])). - (C4)
computationtiene el tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, dondeEi = element_type(inputs[i])yE' = element_type(result).
Ejemplos
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = stablehlo.multiply %arg0, %arg<1 :> tensori64
stablehlo.return %<0 :> tensori64
}) {
dimensio<ns = arra>yi64: 0, 1
}< : (ten>sor2x2xi<64, ten>sor>2x2xi64<) - ten>sor2x2xi64
// %result: [[0, 5], [12, 21]]
máximo
Semántica
Realiza una operación de máximo a nivel de los elementos en los tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: Es el valor máximo del número entero.
- Para números de punto flotante:
maximumde IEEE-754. - Para números complejos: máximo lexicográfico para el par
(real, imaginary). Imponer un orden en los números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560). - Para los tipos cuantificados:
dequantize_op_quantize(maximum, lhs, rhs, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor o tensor cuantificado por tensor | (C1) |
| (I2) | rhs |
tensor o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Ejemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 8]]
mínima
Semántica
Realiza una operación de min a nivel de los elementos en los tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: AND lógico.
- Para números enteros: Es el valor mínimo del número entero.
- Para números de punto flotante:
minimumde IEEE-754. - Para números complejos: mínimo lexicográfico para el par
(real, imaginary). Imponer un orden en los números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560). - Para los tipos cuantificados:
dequantize_op_quantize(minimum, lhs, rhs, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor o tensor cuantificado por tensor | (C1) |
| (I2) | rhs |
tensor o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Ejemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[1, 2], [3, 4]]
multiplicar
Semántica
Realiza el producto de dos tensores lhs y rhs elemento por elemento, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: AND lógico.
- Para números enteros: multiplicación de números enteros.
- Para números de punto flotante:
multiplicationde IEEE-754. - Para números complejos: multiplicación compleja.
- Para los tipos cuantificados:
dequantize_op_quantize(multiply, lhs, rhs, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor o tensor cuantificado por tensor | (C1) |
| (I2) | rhs |
tensor o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 12], [21, 32]]
negativo
Semántica
Realiza la negación de cada elemento del tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números enteros con signo: negación de números enteros.
- Para números enteros sin signo: conversión de tipo a número entero con signo, negación de número entero y conversión de tipo de nuevo a número entero sin signo.
- Para números de punto flotante:
negatede IEEE-754. - Para números complejos: negación compleja.
- Para los tipos cuantificados, se usa
dequantize_op_quantize(negate, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand)< : (t>ens>or2xi32<) - t>ensor2xi32
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"<(%operand<) :>> (t>ensor1x<complexf3<2) >>- tensor1xcomplexf32
// %result: [-2.5, -0.0]
no
Semántica
Realiza una operación NOT a nivel del elemento del tensor operand y produce un tensor result.
Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: NOT lógico.
- Para números enteros: NOT a nivel de bits.
Argumentos
| Nombre | Tipo | Limitaciones |
|---|---|---|
operand |
tensor de tipo booleano o entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo booleano o entero | (C1) |
Limitaciones
- (C1)
type(operand) = type(result).
Ejemplos
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand)< : (ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"<(%op>era>nd) : (<tens>or2xi1) - tensor2xi1
// %result: [false, true]
optimization_barrier
Semántica
Garantiza que las operaciones que producen el operand se ejecuten antes que cualquier operación que dependa del result y evita que las transformaciones del compilador muevan operaciones a través de la barrera. Aparte de eso, la operación es una identidad, es decir, result = operand.
Argumentos
| Nombre | Tipo | Limitaciones |
|---|---|---|
operand |
Cantidad variable de tensores, tensores o tokens cuantificados por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Cantidad variable de tensores, tensores o tokens cuantificados por tensor | (C1) |
Limitaciones
- (C1)
type(operand...) = type(result...).
Ejemplos
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1)< : >(tensorf<32,> te>nsorf32)< - >(tensorf<32,> tensorf32)
// %result0: 0.0
// %result1: 1.0
o
Semántica
Realiza la operación OR a nivel de los elementos de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: OR bit a bit.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
Tensor de tipo booleano o número entero | (C1) |
| (I2) | rhs |
Tensor de tipo booleano o número entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo booleano o número entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result).
Ejemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, true]]
Salida
Semántica
Escribe inputs en el feed externo y produce un token de result.
La semántica de outfeed_config se define según la implementación.
Entradas
| Etiqueta | Nombre | Tipo |
|---|---|---|
| (I1) | inputs |
Cantidad variable de tensores o tensores cuantificados |
| (I2) | token |
token |
| (I3) | outfeed_config |
constante de tipo string |
Salidas
| Nombre | Tipo |
|---|---|
result |
token |
Ejemplos
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = &quo<t;"<>/span>
} : (tensor2x2x2xi64,> !stablehlo.token) - !stablehlo.token
tope
Semántica
Expande operand agregando padding alrededor del tensor y entre los elementos del tensor con el padding_value determinado.
edge_padding_low y edge_padding_high especifican la cantidad de padding que se agrega en el extremo inferior (junto al índice 0) y el extremo superior (junto al índice más alto) de cada dimensión, respectivamente. La cantidad de padding puede ser negativa, en cuyo caso el valor absoluto del padding negativo indica la cantidad de elementos que se quitarán de la dimensión especificada.
interior_padding especifica la cantidad de padding que se agrega entre dos elementos cualesquiera en cada dimensión, que no puede ser negativa. El padding interior se aplica antes del padding de borde, de modo que el padding de borde negativo quitará elementos del operando con padding interior.
De manera más formal, result[result_index] se define de la siguiente manera:
operand[operand_index]siresult_index = edge_padding_low + operand_index * (interior_padding + 1).- En caso contrario,
padding_value.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
| (I2) | padding_value |
Tensor de 0 dimensiones o tensor cuantificado por tensor | (C1) |
| (I3) | edge_padding_low |
Tensor constante unidimensional de tipo si64 |
(C1), (C4) |
| (I4) | edge_padding_high |
Tensor constante unidimensional de tipo si64 |
(C1), (C4) |
| (I5) | interior_padding |
Tensor constante unidimensional de tipo si64 |
(C2-C4) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3 a C6) |
Limitaciones
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result). - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand). - (C3)
0 <= interior_padding. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.
Ejemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_l<ow = arra>yi64: 0, 1,
edge_padding_hi<gh = arra>yi64: 2, 1,
interior_paddi<ng = arra>yi64: 1, 2
}< : (ten>sor2x3xi<32,> te>nsori32<) - ten>sor5x9xi32
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semántica
Produce partition_id del proceso actual.
Salidas
| Nombre | Tipo |
|---|---|
result |
Tensor de 0 dimensiones del tipo ui32 |
Ejemplos
%result = "stablehlo.partition_id">;() : (<) - >tensorui32
popcnt
Semántica
Realiza un recuento bit a bit de la cantidad de bits establecidos en el tensor operand y produce un tensor result.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo entero | (C1) |
Limitaciones
- (C1)
type(operand) = type(result).
Ejemplos
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand)< : (t>ens>or4xi64<) - t>ensor4xi64
// %result: [0, 1, 1, 7]
energía
Semántica
Realiza la exponenciación según los elementos del tensor lhs por el tensor rhs y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: potenciación de números enteros.
- Para números de punto flotante:
powde IEEE-754. - Para números complejos: exponenciación compleja.
- Para tipos cuantificados:
dequantize_op_quantize(power, lhs, rhs, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs)< : (t>ensor6xf<64, t>ens>or6xf64<) - t>ensor6xf64
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semántica
Extrae la parte real, elemento por elemento, de operand y produce un tensor result. De manera más formal, para cada elemento x, se cumple que real(x) = is_complex(x) ? real_part(x) : x.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante | (C1), (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de punto flotante | (C1), (C2) |
Limitaciones
- (C1)
shape(result) = shape(operand). - (C2)
element_type(result)se define de la siguiente manera:complex_element_type(element_type(operand))siis_complex(operand).- En caso contrario,
element_type(operand).
Ejemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand)< : (tenso<r2x>>com>plexf32<) - t>ensor2xf32
// %result: [1.0, 3.0]
recv
Semántica
Recibe datos de un canal con channel_id y produce results.
Si is_host_transfer es true, la operación transfiere datos desde el host. De lo contrario, transfiere datos desde otro dispositivo según los valores de source_target_pairs. Esta marca duplica la información proporcionada en channel_type, por lo que, en el futuro, planeamos conservar solo una de ellas (#666). Si is_host_transfer = false y source_target_pairs es None o está vacío, se considera un comportamiento indefinido.
results consta de valores de carga útil que aparecen primero y un token que aparece al final. En el futuro, planeamos dividir la carga útil y el token en dos salidas separadas para mejorar la claridad (#670).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | token |
token |
|
| (I2) | channel_id |
constante de tipo si64 |
|
| (I3) | channel_type |
Enumeración de DEVICE_TO_DEVICE y DEVICE_TO_HOST |
(C5) |
| (I4) | is_host_transfer |
constante de tipo i1 |
(C5-C6) |
| (I5) | source_target_pairs |
Tensor constante bidimensional de tipo si64 |
(C1-C4), (C6) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores, tensores cuantificados o tokens | (C2-C4) |
Limitaciones
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, dondeNse define de la siguiente manera:num_replicassi se usacross_replica.num_partitionssi se usacross_partition.
- (C5)
channel_typese define de la siguiente manera:DEVICE_TO_HOSTsiis_host_transfer = true,- En caso contrario,
DEVICE_TO_DEVICE.
Ejemplos
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
} : (!stablehl>o.token)< - (ten>sor2x2xi64, !stablehlo.token)
Reducir
Semántica
Aplica una función de reducción body a inputs y init_values a lo largo de dimensions y produce tensores results.
El orden de las reducciones se define en la implementación, lo que significa que body y init_values deben formar un monoide para garantizar que la operación produzca los mismos resultados para todas las entradas en todas las implementaciones. Sin embargo, esta condición no se cumple para muchas reducciones populares. Por ejemplo, la suma de punto flotante para body y cero para init_values no forman un monoide porque la suma de punto flotante no es asociativa.
De manera más formal, results...[j0, ..., jR-1] = reduce(input_slices_converted), en la que:
input_slices = inputs...[j0, ..., :, ..., jR-1], donde:se insertan endimensions.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).reduce(input_slices_converted) = exec(schedule)para algún árbol binarioscheduledonde:exec(node) = body(exec(node.left), exec(node.right)).exec(leaf) = leaf.value.
schedulees un árbol binario completo definido por la implementación cuya transversal en orden consiste en lo siguiente:- Valores de
input_slices_converted...[index]para todos losindexenindex_space(input_slices_converted)en el orden lexicográfico ascendente deindex. - Se intercalan con una cantidad de
init_values_converteddefinida por la implementación en posiciones definidas por la implementación.
- Valores de
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | inputs |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1 a C4), (C6) y (C7) |
| (I2) | init_values |
Cantidad variable de tensores de 0 dimensiones o tensores cuantificados por tensor | (C2), (C3) |
| (I3) | dimensions |
Tensor constante unidimensional de tipo si64 |
(C4), (C5) y (C7) |
| (I4) | body |
función | (C6) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores o tensores cuantificados por tensor | (C3), (C7), (C8) |
Limitaciones
- (C1)
same(shape(inputs...)). - (C2)
element_type(inputs...) = element_type(init_values...). - (C3)
0 < size(inputs) = size(init_values) = size(results) = N. - (C4)
0 <= dimensions < rank(inputs[0]). - (C5)
is_unique(dimensions). - (C6)
bodytiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), dondeis_promotable(element_type(inputs[i]), Ei). - (C7)
shape(results...) = shape(inputs...), excepto que no se incluyen los tamaños de dimensión deinputs...correspondientes adimensions. - (C8)
element_type(results[i]) = Eipara todos losien[0,N).
Ejemplos
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) <- ()
}>) {
dimens<ions = >arrayi64<: 1>
} >: (tens<or1x6>xi64, tensori64) - tensor1xi64
// %result = [15]
reduce_precision
Semántica
Realiza la conversión de cada elemento de operand a otro tipo de punto flotante que usa exponent_bits y mantissa_bits, y viceversa, para volver al tipo de punto flotante original y producir un tensor output.
De forma más formal:
- Los bits de la mantisa del valor original se actualizan para redondear el valor original al valor más cercano que se puede representar con
mantissa_bitsusando la semántica deroundToIntegralTiesToEven. - Luego, si
mantissa_bitses menor que la cantidad de bits de mantisa del valor original, los bits de mantisa se truncan amantissa_bits. - Luego, si los bits del exponente del resultado intermedio no se ajustan al rango proporcionado por
exponent_bits, el resultado intermedio se desborda hasta el infinito con el signo original o se desborda hacia abajo hasta cero con el signo original. - Para los tipos cuantificados, realiza
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
| (I2) | exponent_bits |
constante de tipo si32 |
(C2) |
| (I3) | mantissa_bits |
constante de tipo si32 |
(C3) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
output |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(output). - (C2)
1 <= exponent_bits. - (C3)
0 <= mantissa_bits.
Ejemplos
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
}< : (t>ens>or6xf64<) - t>ensor6xf64
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, realiza una reducción con computations sobre los valores del tensor operand de cada proceso, divide el resultado de la reducción a lo largo de scatter_dimension en partes y dispersa las partes divididas entre los procesos para producir el tensor result.
La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:
cross_replica(replica_groups)sichannel_id <= 0 and use_global_device_ids = false.cross_replica_and_partition(replica_groups)sichannel_id > 0 and use_global_device_ids = false.flattened_ids(replica_groups)sichannel_id > 0 and use_global_device_ids = true.
Luego, dentro de cada process_group, haz lo siguiente:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).result@receiver = parts@sender[receiver_index]para todosenderenprocess_group, dondereceiver_index = process_group.index(receiver).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C7) y (C8) |
| (I2) | scatter_dimension |
constante de tipo si64 |
(C1), (C2), (C8) |
| (I3) | replica_groups |
Tensor constante bidimensional de tipo si64 |
(C3 a C5) |
| (I4) | channel_id |
constante de tipo si64 |
(C6) |
| (I5) | use_global_device_ids |
constante de tipo i1 |
(C6) |
| (I6) | computation |
función | (C7) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C8-C9) |
Limitaciones
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0. - (C2)
0 <= scatter_dimension < rank(operand). - (C3)
is_unique(replica_groups). - (C4)
size(replica_groups)se define de la siguiente manera:num_replicassi se usacross_replica.num_replicassi se usacross_replica_and_partition.num_processessi se usaflattened_ids.
- (C5)
0 <= replica_groups < size(replica_groups). - (C6) Si
use_global_device_ids = true, entonceschannel_id > 0. - (C7)
computationtiene el tipo(tensor<E>, tensor<E>) -> (tensor<E>), dondeis_promotable(element_type(operand), E). - (C8)
shape(result) = shape(operand), excepto:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
- (C9)
element_type(result) = E.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimension = 1 :< i64,
>replica_g<roups => dense[[0, 1]] : tensor1x2xi64,
channel_hand<le = #stablehlo.chan>nel_handleha<ndle = >0, >type = <0
} : (>tensor2x4xi64) - tensor2x2xi64
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semántica
Aplica una función de reducción body a ventanas de inputs y init_values, y produce results.
En el siguiente diagrama, se muestra cómo se calculan los elementos en results... a partir de inputs... con un ejemplo concreto.
De manera más formal, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (consulta reduce), donde:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).window_start = result_index * window_strides.window_end = window_start + (window_dimensions - 1) * window_dilations + 1.windows = slice(padded_inputs..., window_start, window_end, window_dilations).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | inputs |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1 a C4), (C6), (C8), (C10), (C12), (C13) y (C15) |
| (I2) | init_values |
Cantidad variable de tensores de 0 dimensiones o tensores cuantificados por tensor | (C1), (C13) |
| (I3) | window_dimensions |
Tensor constante unidimensional de tipo si64 |
(C4), (C5) y (C15) |
| (I4) | window_strides |
Tensor constante unidimensional de tipo si64 |
(C6), (C7) y (C15) |
| (I5) | base_dilations |
Tensor constante unidimensional de tipo si64 |
(C8), (C9) y (C15) |
| (I6) | window_dilations |
Tensor constante unidimensional de tipo si64 |
(C10), (C11), (C15) |
| (I7) | padding |
Tensor constante bidimensional de tipo si64 |
(C12), (C15) |
| (I8) | body |
función | (C13) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1), (C14 a C16) |
Limitaciones
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N. - (C2)
same(shape(inputs...)). - (C3)
element_type(inputs...) = element_type(init_values...). - (C4)
size(window_dimensions) = rank(inputs[0]). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(inputs[0]). - (C7)
0 < window_strides. - (C8)
size(base_dilations) = rank(inputs[0]). - (C9)
0 < base_dilations. - (C10)
size(window_dilations) = rank(inputs[0]). - (C11)
0 < window_dilations. - (C12)
shape(padding) = [rank(inputs[0]), 2]. - (C13)
bodytiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), dondeis_promotable(element_type(inputs[i]), Ei). - (C14)
same(shape(results...)). - (C15)
shape(results[0]) = num_windowsdonde:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
- (C16)
element_type(results[i]) = Eipara todoien[0,N).
Ejemplos
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
})< {
wind>ow_dimensions = arrayi64: <2, 1,
w>indow_strides = arrayi64: <4, 1,
b>ase_dilations = arrayi64: 2,< 1,
win>dow_dilations = arr<ayi64: 3, 1,
p>adding = <dense[[>2, 1], [0, 0<]] : te>nsor2x2x<i64>
} >: (tens<or3x2xi>64, tensori64) - tensor2x2xi64
// %result = [[0, 0], [3, 4]]
resto
Semántica
Realiza el resto de la división de los tensores de dividendo lhs y divisor rhs elemento por elemento, y produce un tensor result.
De manera más formal, el signo del resultado se toma del dividendo, y el valor absoluto del resultado siempre es menor que el valor absoluto del divisor.
El resto se calcula como lhs - d * rhs, donde d se obtiene de la siguiente manera:
- Para números enteros:
stablehlo.divide(lhs, rhs). - Para números de punto flotante:
division(lhs, rhs)de IEEE-754 con el atributo de redondeoroundTowardZero. - Para números complejos: Por definir (#997).
- Para los tipos cuantificados:
dequantize_op_quantize(remainder, lhs, rhs, type(result)).
Para los tipos de elementos de punto flotante, esta operación contrasta con la operación remainder de la especificación IEEE-754, en la que d es un valor integral más cercano al valor exacto de lhs/rhs con empates a pares.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
Tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs)< : (t>ensor4xi<64, t>ens>or4xi64<) - t>ensor4xi64
// %result: [2, -2, 2, -2]
replica_id
Semántica
Produce replica_id del proceso actual.
Salidas
| Nombre | Tipo |
|---|---|
result |
Tensor de 0 dimensiones del tipo ui32 |
Ejemplos
%result = "stablehlo.replica_id">;() : (<) - >tensorui32
reshape
Semántica
Cambia la forma del tensor operand a un tensor result. Conceptualmente, equivale a mantener la misma representación canónica, pero posiblemente cambiar la forma, p.ej., de tensor<2x3xf32> a tensor<3x2xf32> o tensor<6xf32>.
De manera más formal, result[result_index] = operand[operand_index], donde result_index y operand_index tienen la misma posición en el orden lexicográfico de index_space(result) y index_space(operand).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado | (C1 a C3) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C1 a C3) |
Limitaciones
- (C1)
element_type(result)se calcula de la siguiente manera:element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), excepto quequantization_dimension(operand)yquantization_dimension(result)pueden diferir.
- (C2)
size(operand) = size(result). - (C3) Si
is_per_axis_quantized(operand):reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand)< : (ten>sor>2x3xi32<) - ten>sor3x2xi32
// %result: [[1, 2], [3, 4], [5, 6]]
revertir
Semántica
Invierte el orden de los elementos en el operand a lo largo del dimensions especificado y produce un tensor result. De manera más formal, result[result_index] = operand[operand_index], en el que:
operand_index[d] = dim(result, d) - result_index[d] - 1sidendimensions- En caso contrario,
operand_index[d] = result_index[d].
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C3) |
| (I2) | dimensions |
Tensor constante unidimensional de tipo si64 |
(C2), (C3) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1), (C3) |
Limitaciones
- (C1)
type(operand) = type(result). - (C2)
is_unique(dimensions). - (C3)
0 <= dimensions < rank(result).
Ejemplos
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensio<ns = a>rrayi64: 1
}< : (ten>sor>3x2xi32<) - ten>sor3x2xi32
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semántica
Genera números aleatorios con el algoritmo rng_distribution y produce un tensor result con una forma determinada shape.
Si es rng_distribution = UNIFORM, los números aleatorios se generan según la distribución uniforme en el intervalo [a, b). Si es a >= b, el comportamiento no está definido.
Si es rng_distribution = NORMAL, los números aleatorios se generan según la distribución normal con una media = a y una desviación estándar = b.
Si es b < 0, el comportamiento no está definido.
La forma exacta en que se generan los números aleatorios se define en la implementación. Por ejemplo, pueden ser determinísticos o no, y pueden usar o no un estado oculto.
En conversaciones con muchas partes interesadas, esta operación se ha considerado como efectivamente obsoleta, por lo que, en el futuro, planeamos explorar la posibilidad de quitarla (#597).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | a |
Tensor de 0 dimensiones de tipo entero, booleano o de punto flotante | (C1), (C2) |
| (I2) | b |
Tensor de 0 dimensiones de tipo entero, booleano o de punto flotante | (C1), (C2) |
| (I3) | shape |
Tensor constante unidimensional de tipo si64 |
(C3) |
| (I4) | rng_distribution |
Enumeración de UNIFORM y NORMAL |
(C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo entero, booleano o de punto flotante | (C1 a C3) |
Limitaciones
- (C1)
element_type(a) = element_type(b) = element_type(result). - (C2) Si
rng_distribution = NORMAL, entoncesis_float(a). - (C3)
shape(result) = shape.
Ejemplos
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = <#stablehlorng_distributi>on UNIFORM
}< : >(tensori<32,> tensori<32, t>ens>or2xi64<) - ten>sor3x3xi32
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semántica
Devuelve un output completado con bits aleatorios uniformes y un estado de salida output_state actualizado con el algoritmo del generador de números pseudoaleatorios rng_algorithm, dado un estado inicial initial_state. Se garantiza que el resultado es una función determinística de initial_state, pero no se garantiza que sea determinístico entre las implementaciones.
rng_algorithm es una de las siguientes opciones:
DEFAULT: Es un algoritmo definido por la implementación.THREE_FRY: Variante del algoritmo Threefry definida por la implementación.*PHILOX: Variante del algoritmo Philox definida por la implementación.*
* Consulta: Salmon et al. SC 2011. Números aleatorios paralelos: tan fácil como contar del 1 al 3.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | rng_algorithm |
Enumeración de DEFAULT, THREE_FRY y PHILOX |
(C2) |
| (I2) | initial_state |
Tensor unidimensional de tipo ui64 |
(C1), (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
output_state |
Tensor unidimensional de tipo ui64 |
(C1) |
output |
Tensor de tipo entero o de punto flotante |
Limitaciones
- (C1)
type(initial_state) = type(output_state). - (C2)
size(initial_state)se define de la siguiente manera:- Se define la implementación si es
rng_algorithm = DEFAULT. 2sirng_algorithm = THREE_FRY.2o3si esrng_algorithm = PHILOX.
- Se define la implementación si es
Ejemplos
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = <#stablehlorng_algorithm> THREE_FRY
}< : (te>nso>r2xui64)< - (te>nsor2xui<64, tens>or2x2xui64)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semántica
Realiza un redondeo según cada elemento hacia el número entero más cercano, rompiendo los empates lejos de cero, en el tensor operand y produce un tensor result. Implementa la operación roundToIntegralTiesToAway de la especificación IEEE-754. Para los tipos cuantificados, realiza dequantize_op_quantize(round_nearest_afz, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semántica
Realiza un redondeo según cada elemento hacia el número entero más cercano, con desempates hacia el número entero par, en el tensor operand y produce un tensor result. Implementa la operación roundToIntegralTiesToEven de la especificación IEEE-754. Para los tipos cuantificados, realiza dequantize_op_quantize(round_nearest_even, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semántica
Realiza una operación de raíz cuadrada recíproca a nivel del elemento en el tensor operand y genera un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
rSqrtde IEEE-754. - Para números complejos: raíz cuadrada recíproca compleja.
- Para tipos cuantificados:
dequantize_op_quantize(rsqrt, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
Dispersión
Semántica
Produce tensores results que son iguales a los tensores inputs, excepto que varias segmentaciones especificadas por scatter_indices se actualizan con los valores updates usando update_computation.
En el siguiente diagrama, se muestra cómo se asignan los elementos de updates... a los elementos de results... con un ejemplo concreto. El diagrama elige algunos índices de updates... como ejemplo y explica en detalle a qué índices de results... corresponden.
De manera más formal, para todo update_index en index_space(updates[0]), se cumple lo siguiente:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].update_scatter_index = update_index[update_scatter_dims...].start_indexse define de la siguiente manera:scatter_indices[si0, ..., :, ..., siN], dondesison elementos individuales enupdate_scatter_indexy:se inserta en el índiceindex_vector_dim, siindex_vector_dim<rank(scatter_indices).- En caso contrario,
[scatter_indices[update_scatter_index]].
- Para
d_inputenaxes(inputs[0]),full_start_index[d_input] = start_index[d_start]sid_input = scatter_dims_to_operand_dims[d_start].- En caso contrario,
full_start_index[d_input] = 0.
- Para
d_inputenaxes(inputs[0]),full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]sid_input = input_batching_dims[i_batching]yd_start = scatter_indices_batching_dims[i_batching].- En caso contrario,
full_batching_index[d_input] = 0.
update_window_index = update_index[update_window_dims...].full_window_index = [wi0, ..., 0, ..., wiN], dondewison elementos individuales enupdate_window_indexy0se inserta en los índices deinserted_window_dimsyinput_batching_dims.result_index = full_start_index + full_batching_index + full_window_index.
Dado que results = exec(schedule, inputs), donde:
schedulees una permutación definida por la implementación deindex_space(updates[0]).exec([update_index, ...], results) = exec([...], updated_results), donde:- Si
result_indexestá dentro de los límites deshape(results...) updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )updated_values = update_computation(results...[result_index], updates_converted)updated_resultses una copia deresultsconresults...[result_index]establecido enupdated_values....- De lo contrario
updated_results = results.
- Si
exec([], results) = results.
Si indices_are_sorted es true, la implementación puede suponer que los elementos de scatter_indices están ordenados con respecto a scatter_dims_to_operand_dims. De lo contrario, el comportamiento es indefinido. De manera más formal, para todo i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).
Si unique_indices es true, la implementación puede suponer que todos los índices de result_index a los que se dispersan son únicos. Si unique_indices es true, pero los índices a los que se dispersan no son únicos, el comportamiento no está definido.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | inputs |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
| (I2) | scatter_indices |
Tensor de tipo entero | (C4), (C15), (C19), (C22) |
| (I3) | updates |
Cantidad variable de tensores o tensores cuantificados por tensor | (C3 a C6) y (C8) |
| (I4) | update_window_dims |
Tensor constante unidimensional de tipo si64 |
(C2), (C4), (C7-C8) |
| (I5) | inserted_window_dims |
Tensor constante unidimensional de tipo si64 |
(C2), (C4), (C9-C11) |
| (I6) | input_batching_dims |
Tensor constante unidimensional de tipo si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
| (I7) | scatter_indices_batching_dims |
Tensor constante unidimensional de tipo si64 |
(C14-C18) |
| (I8) | scatter_dims_to_operand_dims |
Tensor constante unidimensional de tipo si64 |
(C19-C21) |
| (I9) | index_vector_dim |
constante de tipo si64 |
(C4), (C16), (C19) y (C22) |
| (I10) | indices_are_sorted |
constante de tipo i1 |
|
| (I11) | unique_indices |
constante de tipo i1 |
|
| (I12) | update_computation |
función | (C23) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores o tensores cuantificados por tensor | (C24-C25) |
Limitaciones
- (C1)
same(shape(inputs...)). - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims) + size(input_batching_dims). - (C3)
same(shape(updates...)). - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)donde:update_scatter_dim_sizes = shape(scatter_indices), excepto que no se incluye el tamaño de la dimensión descatter_indicescorrespondiente aindex_vector_dim.update_window_dim_sizes <= shape(inputs[0]), excepto que no se incluyen los tamaños de dimensión eninputs[0]correspondientes ainserted_window_dimsyinput_batching_dims.combinecolocaupdate_scatter_dim_sizesen los ejes correspondientes aupdate_scatter_dimsyupdate_window_dim_sizesen los ejes correspondientes aupdate_window_dims.
- (C5)
0 < size(inputs) = size(updates) = N. - (C6)
element_type(updates...) = element_type(inputs...). - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims). - (C8)
0 <= update_window_dims < rank(updates[0]). - (C9)
is_unique(concatenate(inserted_window_dims, input_batching_dims)) - (C10)
is_sorted(inserted_window_dims). - (C11)
0 <= inserted_window_dims < rank(inputs[0]). - (C12)
is_sorted(input_batching_dims). - (C13)
0 <= input_batching_dims < rank(inputs[0])). - (C14)
is_unique(scatter_indices_batching_dims). - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices). - (C16)
index_vector_dim not in scatter_indices_batching_dims. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims). - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...). - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims)). - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0]). - (C22)
0 <= index_vector_dim <= rank(scatter_indices). - (C23)
update_computationtiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), dondeis_promotable(element_type(inputs[i]), Ei). - (C24)
shape(inputs...) = shape(results...). - (C25)
element_type(results[i]) = Eipara todoien[0,N).
Ejemplos
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ],
// [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.add"(%ar<g0,> %arg1) <: (>ten>sori64,< te>nsori64) - tensori64
"stable<hlo>.re>turn"(%0) : (tensori64) - ()
}) {
scatter_dimensio<n_numbers = #stablehlo.scatter
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2>, 1],
index_vector_dim = 3,
indices_are_sorted = false,
uniq<ue_indices >= false
<} : (tensor>2x3x4x2x<i64, tensor2x>2x3>x2xi64,< tensor2x2x>3x2x2xi64) - tensor2x3x4x2xi64
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
seleccionar
Semántica
Produce un tensor result en el que cada elemento se selecciona del tensor on_true o on_false según el valor del elemento correspondiente de pred.
De manera más formal, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index], donde pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]. Para los tipos cuantificados, realiza dequantize_select_quantize(pred, on_true, on_false, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | pred |
tensor de tipo i1 |
(C1) |
| (I2) | on_true |
tensor o tensor cuantificado por tensor | (C1-C2) |
| (I3) | on_false |
tensor o tensor cuantificado por tensor | (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C2) |
Limitaciones
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true). - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).
Ejemplos
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false)< : (te>nsor2x2x<i1, ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semántica
Dispersa los valores del tensor source con scatter según el resultado de reduce_window del tensor input con select y produce un tensor result.
En el siguiente diagrama, se muestra cómo se calculan los elementos de result a partir de operand y source con un ejemplo concreto.
De forma más formal:
selected_values = reduce_window_without_init(...)con las siguientes entradas:inputs = [operand].window_dimensions,window_stridesypadding, que se usan tal cual.base_dilations = windows_dilations = 1.bodyse define de la siguiente manera:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;donde
E = element_type(operand)yreduce_window_without_initfuncionan exactamente igual quereduce_window, excepto que elscheduledelreducesubyacente (consulta reduce) no incluye valores de inicialización. Actualmente, no se especifica qué sucede si la ventana correspondiente no tiene valores (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)donde:source_values = [source[source_index] for source_index in source_indices].selected_index(source_index) = operand_indexsiselected_values[source_index]tiene el elementooperanddeoperand_index.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1-C4), (C6), (C8-C11) |
| (I2) | source |
tensor o tensor cuantificado por tensor | (C1), (C2) |
| (I3) | init_value |
Tensor de 0 dimensiones o tensor cuantificado por tensor | (C3) |
| (I4) | window_dimensions |
Tensor constante unidimensional de tipo si64 |
(C2), (C4) y (C5) |
| (I5) | window_strides |
Tensor constante unidimensional de tipo si64 |
(C2), (C6) y (C7) |
| (I6) | padding |
Tensor constante bidimensional de tipo si64 |
(C2), (C8) |
| (I7) | select |
función | (C9) |
| (I8) | scatter |
función | (C10) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C11-C12) |
Limitaciones
- (C1)
element_type(operand) = element_type(source). - (C2)
shape(source) = num_windows, donde:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
- (C3)
element_type(init_value) = element_type(operand). - (C4)
size(window_dimensions) = rank(operand). - (C5)
0 < window_dimensions. - (C6)
size(window_strides) = rank(operand). - (C7)
0 < window_strides. - (C8)
shape(padding) = [rank(operand), 2]. - (C9)
selecttiene el tipo(tensor<E>, tensor<E>) -> tensor<i1>, dondeE = element_type(operand). - (C10)
scattertiene el tipo(tensor<E>, tensor<E>) -> tensor<E>, dondeis_promotable(element_type(operand), E). - (C11)
shape(operand) = shape(result). - (C12)
element_type(result) = E.
Ejemplos
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>E
} <: (>ten>sori64,< t>ensori64) - tensori1
"stable<hl>o.r>eturn"(%0) : (tensori1) <- (>)
}, {
^bb0(%<arg>0: tensori64, %arg1: tensori64):
%0 = "sta<ble>hlo.add&<quo>t;(>%arg0, <%ar>g1) : (tensori64, tensori64) - tensor<i64>
> "stablehlo.return"(%0) :< (tensori>64) - ()
}) {
window_dim<ensions => arrayi64: 3, 1,
<window_strides => arrayi64<: 2, 1,>
padding =< dense[>[0, 1], <[0, 0]]> : tenso<r2x>2xi>64
} : <(tensor>4x2xi64, tensor2x2xi64, tensori64) - tensor4x2xi64
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
enviar
Semántica
Envía inputs a un canal channel_id. Luego, las entradas se envían a otros dispositivos en el orden especificado por source_target_pairs. La operación produce un token result.
Si is_host_transfer es true, la operación transfiere datos al host. De lo contrario, transfiere datos a otro dispositivo según los valores de source_target_pairs. Esta marca duplica la información proporcionada en channel_type, por lo que, en el futuro, planeamos conservar solo una de ellas (#666). Si is_host_transfer = false y source_target_pairs es None o está vacío, se considera un comportamiento indefinido.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | inputs |
Cantidad variable de tensores o tensores cuantificados | |
| (I2) | token |
token |
|
| (I3) | channel_id |
constante de tipo si64 |
|
| (I4) | channel_type |
Enumeración de DEVICE_TO_DEVICE y DEVICE_TO_HOST |
(C5) |
| (I5) | is_host_transfer |
constante de tipo i1 |
(C5-C6) |
| (I6) | source_target_pairs |
Tensor constante bidimensional de tipo si64 |
(C1-C4), (C6) |
Salidas
| Nombre | Tipo |
|---|---|
result |
token |
Limitaciones
- (C1)
dim(source_target_pairs, 1) = 2. - (C2)
is_unique(source_target_pairs[:, 0]). - (C3)
is_unique(source_target_pairs[:, 1]). - (C4)
0 <= source_target_pairs < N, dondeNse define de la siguiente manera:num_replicassi se usacross_replica.num_partitionssi se usacross_partition.
- (C5)
channel_typese define de la siguiente manera:DEVICE_TO_HOSTsiis_host_transfer = true,- En caso contrario,
DEVICE_TO_DEVICE.
Ejemplos
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.chan<nel_handlehandle = 0>, type = 1,
is_host_transfer = false,
source_target_pai<rs = dense[[0, 1>], [1, 2]<] : ten>sor2x2xi64
}< : (ten>sor2x2xi64, !stablehl>o.token) - !stablehlo.token
shift_left
Semántica
Realiza una operación de desplazamiento a la izquierda a nivel del elemento en el tensor lhs por la cantidad de bits rhs y produce un tensor result.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
Tensor de tipo entero | (C1) |
| (I2) | rhs |
Tensor de tipo entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result).
Ejemplos
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-2, 0, 8]
shift_right_arithmetic
Semántica
Realiza una operación de desplazamiento a la derecha aritmética a nivel del elemento en el tensor lhs por una cantidad de bits de rhs y genera un tensor result.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
Tensor de tipo entero | (C1) |
| (I2) | rhs |
Tensor de tipo entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result).
Ejemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [-1, 0, 1]
shift_right_logical
Semántica
Realiza la operación de desplazamiento a la derecha lógico a nivel del elemento en el tensor lhs por una cantidad de bits rhs y genera un tensor result.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
Tensor de tipo entero | (C1) |
| (I2) | rhs |
Tensor de tipo entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result).
Ejemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs<): (t>ensor3xi<64, t>ens>or3xi64<) - t>ensor3xi64
// %result: [9223372036854775807, 0, 1]
firmar
Semántica
Devuelve el signo del elemento operand y genera un tensor result.
De manera más formal, para cada elemento x, la semántica se puede expresar con la sintaxis de Python de la siguiente manera:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Para los tipos cuantificados, realiza dequantize_op_quantize(sign, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo complejo, de punto flotante o de número entero con signo, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo complejo, de punto flotante o de número entero con signo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand)< : (t>ens>or5xf64<) - t>ensor5xf64
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semántica
Realiza una operación seno a nivel del elemento en el tensor operand y genera un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
sinde IEEE-754. - Para números complejos: seno complejo.
- Para tipos cuantificados:
dequantize_op_quantize(sine, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [0.0, -1.0]]
porción
Semántica
Extrae un segmento del tensor operand con índices iniciales calculados de forma estática y produce un tensor result. start_indices contiene los índices iniciales del segmento para cada dimensión, limit_indices contiene los índices finales (exclusivos) del segmento para cada dimensión y strides contiene los pasos para cada dimensión.
De manera más formal, result[result_index] = operand[operand_index], donde operand_index = start_indices + result_index * strides.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado por tensor | (C1 a C3) y (C5) |
| (I2) | start_indices |
Tensor constante unidimensional de tipo si64 |
(C2), (C3) y (C5) |
| (I3) | limit_indices |
Tensor constante unidimensional de tipo si64 |
(C2), (C3) y (C5) |
| (I4) | strides |
Tensor constante unidimensional de tipo si64 |
(C2), (C4) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) y (C5) |
Limitaciones
- (C1)
element_type(operand) = element_type(result). - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand). - (C3)
0 <= start_indices <= limit_indices <= shape(operand). - (C4)
0 < strides. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides).
Ejemplos
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indic<es = arra>yi64: 1, 2,
limit_indic<es = arra>yi64: 3, 4,
strid<es = arra>yi64: 1, 1
}< : (ten>sor>3x4xi64<) - ten>sor2x2xi64
// % result: [
// [1, 1],
// [1, 1]
// ]
sort
Semántica
Ordena las porciones unidimensionales de inputs a lo largo de la dimensión dimension, según un comparator, y produce results.
A diferencia de las entradas similares en otras operaciones, dimension permite valores negativos, con la semántica que se describe a continuación. En el futuro, es posible que esto no se permita por motivos de coherencia (#1377).
Si is_stable es verdadero, el ordenamiento es estable, es decir, se conserva el orden relativo de los elementos que el comparador considera iguales. En el caso en que haya una sola entrada, el comparador considera que dos elementos e1 y e2 son iguales si y solo si comparator(e1, e2) = comparator(e2, e1) = false. Consulta la formalización a continuación para ver cómo se generaliza esto para múltiples entradas.
De manera más formal, para todo result_index en index_space(results[0]), se cumple lo siguiente:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.result_slice = [ri0, ..., :, ..., riR-1], donderiNson elementos individuales enresult_indexy:se inserta enadjusted_dimension.inputs_together = (inputs[0]..., ..., inputs[N-1]...).results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).- donde
sortordena un segmento unidimensional en orden no descendente, y se espera quecomparator_togetherdevuelvatruesi el argumento del lado izquierdo es menor que el segundo argumento del lado derecho. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)(results[0]..., ..., results[N-1]...) = results_together.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | inputs |
Cantidad variable de tensores o tensores cuantificados por tensor | (C1 a C5) |
| (I2) | dimension |
constante de tipo si64 |
(C4) |
| (I3) | is_stable |
constante de tipo i1 |
|
| (I4) | comparator |
función | (C5) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de tensores o tensores cuantificados por tensor | (C2), (C3) |
Limitaciones
- (C1)
0 < size(inputs). - (C2)
type(inputs...) = type(results...). - (C3)
same(shape(inputs...) + shape(results...)). - (C4)
-R <= dimension < R, dondeR = rank(inputs[0]). - (C5)
comparatortiene el tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, dondeEi = element_type(inputs[i]).
Ejemplos
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64, %ar<g2:> tensori64, %ar<g3:> tensori64):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_di<rection = #stablehlocom>parison_directio<n G>T
} <: (>ten>sori64,< t>ensori64) - tensori1
"stablehlo.retu<rn>&qu>ot;(%predicate) : (tensori1) - ()
}) {
dimension = 0 : i64,
< is_st>able = t<rue
} :> (t>ensor2x3<xi64, t>ensor2x3<xi64) -> (tensor2x3xi64, tensor2x3xi64)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semántica
Realiza la operación de raíz cuadrada para cada elemento del tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
squareRootde IEEE-754. - Para números complejos: raíz cuadrada compleja.
- Para tipos cuantificados:
dequantize_op_quantize(sqrt, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand)< : (ten>sor>2x2xf32<) - ten>sor2x2xf32
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semántica
Realiza la resta de dos tensores lhs y rhs elemento por elemento, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: resta de números enteros.
- Para números de punto flotante:
subtractionde IEEE-754. - Para números complejos, resta de números complejos.
- Para los tipos cuantificados:
dequantize_op_quantize(subtract, lhs, rhs, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
| (I2) | rhs |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).
Ejemplos
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs)< : (ten>sor2x2xf<32, ten>sor>2x2xf32)< - (ten>sor2x2xf32)
// %result: [[1, 2], [3, 4]]
tan
Semántica
Realiza una operación tangente a nivel del elemento en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
tande IEEE-754. - Para números complejos: tangente compleja.
- Para tipos cuantificados:
dequantize_op_quantize(tan, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand)< : (ten>sor>2x2xf64<) - ten>sor2x2xf64
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semántica
Realiza la operación de tangente hiperbólica para cada elemento del tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
tanhde IEEE-754. - Para números complejos: tangente hiperbólica compleja.
- Para los tipos cuantificados:
dequantize_op_quantize(tanh, operand, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result).
Ejemplos
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand)< : (t>ens>or3xf32<) - t>ensor3xf32
// %result: [-0.76159416, 0.0, 0.76159416]
transposición
Semántica
Permuta las dimensiones del tensor operand con permutation y genera un tensor result. De manera más formal, result[result_index] = operand[operand_index], donde result_index[d] = operand_index[permutation[d]].
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor o tensor cuantificado | (C1-C4) |
| (I2) | permutation |
Tensor constante unidimensional de tipo si64 |
(C2-C4) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3-C4) |
Limitaciones
- (C1)
element_type(result)se calcula de la siguiente manera:element_type(operand), si!is_per_axis_quantized(operand).element_type(operand), excepto quequantization_dimension(operand)yquantization_dimension(result)pueden diferir.
- (C2)
permutationes una permutación derange(rank(operand)). - (C3)
shape(result) = dim(operand, permutation...). - (C4) Si
is_per_axis_quantized(result), entoncesquantization_dimension(operand) = permutation(quantization_dimension(result)).
Ejemplos
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutati<on = arrayi6>4: 2, 1, 0
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semántica
Resuelve lotes de sistemas de ecuaciones lineales con matrices de coeficientes triangulares inferiores o superiores.
De manera más formal, dados a y b, result[i0, ..., iR-3, :, :] es la solución para op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] cuando left_side es true o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] cuando left_side es false, y se resuelve para la variable x, en la que op(a) se determina con transpose_a, que puede ser uno de los siguientes:
NO_TRANSPOSE: Realiza la operación usandoatal como está.TRANSPOSE: Realiza la operación en la transpuesta dea.ADJOINT: Realiza la operación en la transpuesta conjugada dea.
Los datos de entrada se leen solo desde el triángulo inferior de a, si lower es true o el triángulo superior de a, de lo contrario. Los datos de salida se devuelven en el mismo triángulo; los valores del otro triángulo se definen según la implementación.
Si unit_diagonal es verdadero, la implementación puede suponer que los elementos diagonales de a son iguales a 1; de lo contrario, el comportamiento no está definido.
Para los tipos cuantificados, realiza dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | a |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1 a C3) |
| (I2) | b |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1-C4) |
| (I3) | left_side |
constante de tipo i1 |
(C3) |
| (I4) | lower |
constante de tipo i1 |
|
| (I5) | unit_diagonal |
constante de tipo i1 |
|
| (I6) | transpose_a |
Enumeración de NO_TRANSPOSE, TRANSPOSE y ADJOINT |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo complejo o de punto flotante, o tensor cuantificado por tensor | (C1) |
Limitaciones
- (C1)
baseline_element_type(a) = baseline_element_type(b). - (C2)
2 <= rank(a) = rank(b) = R. - (C3) La relación entre
shape(a)yshape(b)se define de la siguiente manera:shape(a)[:-3] = shape(b)[:-3].dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
- (C4)
baseline_type(b) = baseline_type(result).
Ejemplos
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = <#stablehlotranspose NO>_TRANSPOSE
}< : (ten>sor3x3xf<32, ten>sor>3x3xf32<) - ten>sor3x3xf32
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semántica
Produce una tupla result a partir de los valores val.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | val |
Cantidad variable de valores | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tuple | (C1) |
Limitaciones
- (C1)
resulttiene el tipotuple<E0, ..., EN-1>, dondeEi = type(val[i]).
Ejemplos
// %val0: memref[1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1)< : (m>emref2x<f32, t<upl>>ete>nsori3<2) - t<uplem>emref2x<f32, t<upl>>>etensori32
// %result: (memref[1.0, 2.0], (3))
uniform_dequantize
Semántica
Realiza la conversión de cada elemento del tensor cuantizado operand en un tensor de punto flotante result según los parámetros de cuantización definidos por el tipo operand.
De manera más formal, result = dequantize(operand).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
tensor cuantificado | (C1), (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
Tensor de tipo de punto flotante | (C1), (C2) |
Limitaciones
- (C1)
shape(operand) = shape(result). - (C2)
element_type(result) = expressed_type(operand).
Ejemplos
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand)< : (tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0>.5:-20}<) - t>ensor2xf32
// %result: [4.0, 15.0]
uniform_quantize
Semántica
Realiza la conversión de cada elemento del tensor de punto flotante o del tensor cuantizado operand en un tensor cuantizado result según los parámetros de cuantización definidos por el tipo result.
De manera más formal,
- Si
is_float(operand), haz lo siguiente:result = quantize(operand, type(result)).
- Si
is_quantized(operand), haz lo siguiente:float_result = dequantize(operand).result = quantize(float_result, type(result)).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Tensor de tipo de punto flotante o cuantificado | (C1), (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor cuantificado | (C1), (C2) |
Limitaciones
- (C1)
shape(operand) = shape(result). - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).
Ejemplos
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand)< : (t>ens>or2xf32<) - tensor2x!qua<nt.uniformi8:f32:0, {0.1:-3>>0,0.5:-20}
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"<(%operand) : (te<nsor2x!quant.uniformi8:f32:>>0, >{0.1:-3<0,0.5:-20}) - te<nsor2x!quant.uniformi8:f32:>>0, {0.1:-20,0.2:-30}
// %result: [20, 45]
mientras
Semántica
Produce el resultado de la ejecución de la función body 0 o más veces mientras la función cond genera true. De manera más formal, la semántica se puede expresar con la sintaxis de Python de la siguiente manera:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
El comportamiento de un bucle infinito está por definirse (#383).
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | operand |
Cantidad variable de valores | (C1 a C3) |
| (I2) | cond |
función | (C1) |
| (I3) | body |
función | (C2) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
results |
Cantidad variable de valores | (C3) |
Limitaciones
- (C1)
condtiene el tipo(T0, ..., TN-1) -> tensor<i1>, dondeTi = type(operand[i]). - (C2)
bodytiene el tipo(T0, ..., TN-1) -> (T0, ..., TN-1), dondeTi = type(operand[i]). - (C3)
type(results...) = type(operand...).
Ejemplos
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%ar<g0:> tensori64, %ar<g1:> tensori64):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_di<rection = #stablehlocom>parison_directio<n L>T
} <: (>ten>sori64,< t>ensori64) - tensori1
stablehlo.r<et>urn %cond : tensori1
}, {
< ^>bb0(%arg0: tens<ori>64, %arg1: tensori64):
%new_sum = stablehlo.add <%ar>g1, %one : tensori64
%new_i = stablehlo.add <%ar>g0, %one : tensori64
stablehlo.return %new_<i, >%new_sum< : >tensori64, te<nso>ri64
}) <: (>ten>sori64, <ten>sori64) <- (>tensori64, tensori64)
// %results0: 10
// %results1: 10
xor
Semántica
Realiza una operación XOR a nivel de los elementos de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: XOR lógico.
- Para números enteros: XOR bit a bit.
Entradas
| Etiqueta | Nombre | Tipo | Limitaciones |
|---|---|---|---|
| (I1) | lhs |
tensor de tipo booleano o entero | (C1) |
| (I2) | rhs |
tensor de tipo booleano o entero | (C1) |
Salidas
| Nombre | Tipo | Limitaciones |
|---|---|---|
result |
tensor de tipo booleano o entero | (C1) |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result).
Ejemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs)< : (ten>sor2x2xi<32, ten>sor>2x2xi32<) - ten>sor2x2xi32
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%<lhs, %>rhs) : (<tensor>2x2>xi1, te<nsor2x>2xi1) - tensor2x2xi1
// %result: [[false, true], [true, false]]
Interoperabilidad de dialectos
Por el momento, los programas de StableHLO en uso a veces contienen operaciones que no están definidas por StableHLO.
Módulo, función, llamada y devolución
StableHLO usa operaciones MLIR ascendentes para ModuleOp, FuncOp, CallOp y ReturnOp. Esto se hizo para mejorar la interoperabilidad con la infraestructura existente de MLIR, ya que muchos pases útiles se escriben para FuncOp y ModuleOp, y muchas canalizaciones de compilación esperan que estos ops estén presentes. Estas operaciones tienen garantías de compatibilidad completas. Si alguna vez cambia algo sobre estas operaciones de una manera incompatible (es decir, se quitan), se agregarán equivalentes de StableHLO para preservar la compatibilidad.
CHLO
El conjunto de operaciones de CHLO contiene operaciones de nivel superior que se descomponen en StableHLO. Actualmente, no hay garantías de compatibilidad para CHLO. Para garantizar la compatibilidad, se debe usar el pase chlo-legalize-to-stablehlo antes de la serialización.
Operaciones de formas
En la comunidad, es un caso de uso común utilizar ciertas operaciones de dialectos centrales de MLIR en programas dinámicos de StableHLO para realizar cálculos de formas.
Por lo general, incluyen operaciones de shape dialect como shape_of o num_elements, operaciones de tensor dialect como dim o from_elements, y el tipo index integrado.
La RFC de dinamismo > O2 indica que estos están fuera del alcance, pero se incluye cierta compatibilidad con los tipos index para fines de interoperabilidad. No hay garantías de compatibilidad para estas operaciones o tipos. El pase shape-legalize-to-stablehlo se puede usar para convertir estas operaciones en operaciones de StableHLO totalmente compatibles.
Operaciones obsoletas
Hay varias operaciones de StableHLO que se heredaron de MHLO, que ya no están disponibles y están en proceso de dejar de formar parte de StableHLO. Puedes encontrar todos los detalles sobre estas eliminaciones en StableHLO v1.0 Cleanup #2283. El problema de seguimiento de estas bajas es #2340.
Estas operaciones se dividen en algunas categorías:
- Categoría "No en HLO" de las operaciones de StableHLO: Inicialmente, formaban parte del conjunto de operaciones de StableHLO, pero luego se consideró que no se ajustaban bien a él:
broadcast,create_token,cross-replica-sum,dot,einsum,torch_index_select,unary_einsum(#3). - Operaciones no utilizadas: Es posible que estas operaciones hayan sido útiles en algún momento, pero estaban subdesarrolladas o las canalizaciones que las usaban se refactorizaron para que ya no las requieran. Esto incluye
map,tuple(#598), comparaciones deget_tuple_element,rngycomplex#560, y convoluciónwindow_reversal(#1181).
Algunas de estas operaciones se pueden quitar fácilmente, ya que se pueden expresar con operaciones existentes (broadcast, create_token, cross-replica-sum, dot, unary_einsum) y se quitarán después de que pase el período de compatibilidad existente (6 meses). Se sigue explorando la posibilidad de quitar otras (einsum, get_tuple_element, map, rng, torch_index_select, tuple, comparaciones de complex, window_reversal). Según los comentarios de la comunidad, estas operaciones se quitarán o se agregarán a la especificación con compatibilidad total. Hasta que se conozcan los futuros de estas operaciones, solo se garantiza una compatibilidad de 6 meses.
Ejecución
Ejecución secuencial
Un programa de StableHLO se ejecuta proporcionando valores de entrada a la función main y calculando valores de salida. Los valores de salida de una función se calculan ejecutando el gráfico de operaciones que tiene como raíz la operación return correspondiente.
El orden de ejecución se define en la implementación, siempre y cuando se alinee con el flujo de datos, es decir, si las operaciones se ejecutan antes de sus usos. En StableHLO, todas las operaciones con efectos secundarios consumen un token y producen un token (se pueden multiplexar varios tokens en uno solo a través de after_all), por lo que el orden de ejecución de los efectos secundarios también se alinea con el flujo de datos. Por ejemplo, en el siguiente programa, hay dos órdenes de ejecución posibles: %0 → %1 → %2 → return y %1 → %0 → %2 → return.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
De manera más formal, un proceso de StableHLO es una combinación de lo siguiente:
1) un programa de StableHLO, 2) estados de operación (aún no se ejecutó, ya se ejecutó) y 3) valores intermedios en los que trabaja el proceso.
El proceso comienza con valores de entrada para la función main, avanza a través del gráfico de operaciones que actualizan los estados de las operaciones y los valores intermedios, y finaliza con valores de salida. La formalización adicional está por definirse (#484).
Ejecución paralela
Los programas de StableHLO se pueden ejecutar en paralelo y se organizan en una cuadrícula de procesos 2D de num_replicas por num_partitions, en la que ambos tienen el tipo ui32.
En la cuadrícula de procesos de StableHLO, se ejecutan num_replicas * num_partitions procesos de StableHLO al mismo tiempo. Cada proceso tiene un process_id = (replica_id, partition_id) único, donde replica_id en replica_ids = range(num_replicas) y partition_id en partition_ids = range(num_partitions), ambos con el tipo ui32.
El tamaño de la cuadrícula de procesos se conoce de forma estática para cada programa (en el futuro, planeamos que sea una parte explícita de los programas de StableHLO #650), y la posición dentro de la cuadrícula de procesos se conoce de forma estática para cada proceso. Cada proceso tiene acceso a su posición dentro de la cuadrícula de procesos a través de las operaciones replica_id y partition_id.
Dentro de la cuadrícula de procesos, los programas pueden ser todos iguales (en el estilo "Un solo programa, varios datos"), pueden ser todos diferentes (en el estilo "Varios programas, varios datos") o algo intermedio. En el futuro, planeamos introducir compatibilidad con otros modismos para definir programas de StableHLO paralelos, incluido GSPMD (#619).
Dentro de la cuadrícula de procesos, los procesos son, en su mayoría, independientes entre sí: tienen estados de operación separados, valores de entrada, intermedios y de salida separados, y la mayoría de las operaciones se ejecutan por separado entre los procesos, con la excepción de una pequeña cantidad de operaciones colectivas que se describen a continuación.
Dado que la ejecución de la mayoría de las operaciones solo usa valores del mismo proceso, suele ser inequívoco hacer referencia a estos valores por sus nombres.
Sin embargo, cuando se describen las semánticas de las operaciones colectivas, esto no es suficiente, y da lugar a la notación name@process_id para hacer referencia al valor name dentro de un proceso en particular. (Desde esa perspectiva, el name no calificado se puede considerar como una abreviatura de name@(replica_id(), partition_id())).
El orden de ejecución entre los procesos se define en la implementación, excepto por la sincronización que se introduce con la comunicación punto a punto y las operaciones colectivas, como se describe a continuación.
Comunicación de punto a punto
Los procesos de StableHLO pueden comunicarse entre sí a través de canales de StableHLO. Un canal se representa con un ID positivo de tipo si64. A través de varias operaciones, es posible enviar valores a los canales y recibirlos de ellos.
La formalización adicional, p.ej., de dónde provienen estos IDs de canal, cómo los procesos de los programas los conocen y qué tipo de sincronización introducen, está pendiente (#484).
Comunicación de transmisión
Cada proceso de StableHLO tiene acceso a dos interfaces de transmisión:
- Es el objeto Infeed del que se puede leer.
- Outfeed en el que se puede escribir.
A diferencia de los canales, que se usan para la comunicación entre procesos y, por lo tanto, tienen procesos en ambos extremos, los infeeds y los outfeeds tienen su otro extremo definido por la implementación.
La formalización adicional, p.ej., cómo la comunicación de transmisión influye en el orden de ejecución y qué tipo de sincronización introduce, está pendiente (#484).
Operaciones colectivas
Hay seis operaciones colectivas en StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute y reduce_scatter. Todas estas operaciones dividen los procesos en la cuadrícula de procesos de StableHLO en grupos de procesos de StableHLO y ejecutan un cálculo conjunto dentro de cada grupo de procesos, independientemente de otros grupos de procesos.
Dentro de cada grupo de procesos, las operaciones colectivas pueden introducir una barrera de sincronización. La formalización adicional, por ejemplo, la explicación de cuándo ocurre exactamente esta sincronización, cómo llegan exactamente los procesos a esta barrera y qué sucede si no lo hacen, está pendiente (#484).
Si el grupo de procesos implica comunicación entre particiones, es decir, hay procesos en el grupo de procesos cuyos IDs de partición son diferentes, la ejecución de la operación colectiva necesita un canal, y la operación colectiva debe proporcionar un channel_id positivo de tipo si64. La comunicación entre réplicas no necesita canales.
Los cálculos que realizan las operaciones colectivas son específicos para cada operación y se describen en las secciones de operaciones individuales anteriores. Sin embargo, las estrategias por las que la cuadrícula de procesos se divide en grupos de procesos se comparten entre estas operaciones y se describen en esta sección. De manera más formal, StableHLO admite las siguientes cuatro estrategias.
cross_replica
Solo se producen comunicaciones entre réplicas dentro de cada grupo de procesos. Esta estrategia toma replica_groups, una lista de listas de IDs de réplicas, y calcula un producto cartesiano de replica_groups por partition_ids. replica_groups debe tener elementos únicos y abarcar todos los replica_ids. De manera más formal, con la sintaxis de Python, se escribe de la siguiente manera:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para replica_groups = [[0, 1], [2, 3]] y num_partitions = 2, cross_replica producirá [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].
cross_partition
Solo se producen comunicaciones entre particiones dentro de cada grupo de procesos. Esta estrategia toma partition_groups, una lista de listas de IDs de partición, y calcula un producto cartesiano de partition_groups por replica_ids.
partition_groups debe tener elementos únicos y cubrir todos los partition_ids.
De manera más formal, con la sintaxis de Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para partition_groups = [[0, 1]] y num_replicas = 4, cross_partition producirá [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].
cross_replica_and_partition
Las comunicaciones entre réplicas y entre particiones pueden ocurrir dentro de cada grupo de procesos. Esta estrategia toma replica_groups, una lista de listas de IDs de réplicas, y calcula los productos cartesianos de cada replica_group por partition_ids. replica_groups debe tener elementos únicos y cubrir todos los replica_ids. De manera más formal, con la sintaxis de Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para replica_groups = [[0, 1], [2, 3]] y num_partitions = 2, cross_replica_and_partition producirá [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].
flattened_ids
Esta estrategia toma flattened_id_groups, una lista de listas de IDs de procesos "aplanados" en forma de replica_id * num_partitions + partition_id, y los convierte en IDs de procesos. flattened_id_groups debe tener elementos únicos y cubrir todos los process_ids. De manera más formal, con la sintaxis de Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 y num_partitions = 2, flattened_ids producirá [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].
Exactitud
Por el momento, StableHLO no ofrece garantías sobre la precisión numérica, pero esto podría cambiar en el futuro (#1156).
Semántica de ejecución de la operación cuantificada
La interpretación de las operaciones de StableHLO cuantificadas puede variar según los requisitos y las capacidades del hardware. Por ejemplo, es posible que algunos dispositivos de hardware opten por interpretar las operaciones cuantificadas con una estrategia de "descuantificar, realizar la operación de punto flotante y, finalmente, cuantificar". Otros pueden realizar todo el cálculo con aritmética de números enteros. Por lo tanto, la interpretación de las operaciones de StableHLO cuantizadas se determina exclusivamente por la implementación específica. La interpretación de la cuantificación híbrida (#1575) debe basarse en su semántica, según lo prescrito en la especificación (a través de 1792).
Errores
Los programas de StableHLO se validan a través de un amplio conjunto de restricciones para las operaciones individuales, lo que descarta muchas clases de errores antes del tiempo de ejecución. Sin embargo, aún son posibles las condiciones de error, p.ej., a través de desbordamientos de números enteros, accesos fuera de límites, etcétera. A menos que se indique explícitamente, todos estos errores generan un comportamiento definido por la implementación, pero esto puede cambiar en el futuro (#1157).
Excepciones de punto flotante
Como excepción a esta regla, las excepciones de punto flotante en los programas de StableHLO tienen un comportamiento bien definido. Las operaciones que generan excepciones definidas por el estándar IEEE-754 (operación no válida, división por cero, desbordamiento, subdesbordamiento o excepciones inexactas) producen resultados predeterminados (según se definen en el estándar) y continúan la ejecución sin activar la marca de estado correspondiente, de manera similar al control de excepciones raiseNoFlag del estándar. Las excepciones para las operaciones no estándar (p.ej., operaciones aritméticas complejas y ciertas funciones trascendentales) se definen en la implementación.
Discrepancias de forma
StableHLO admite tensores con formas dinámicas. Sin embargo, las formas deben coincidir en el tiempo de ejecución; de lo contrario, el comportamiento es indefinido. StableHLO no proporciona de forma explícita una operación que pueda afirmar que un tensor tiene una forma determinada en el tiempo de ejecución. El productor es responsable de generar el código correcto.
Como ejemplo específico, el siguiente programa es válido. Sin embargo, en el tiempo de ejecución, las formas exactas de %arg0 y %arg1 deberán ser las mismas; de lo contrario, el comportamiento del programa no estará definido:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Para describir la sintaxis, este documento usa la variante ISO modificada de la sintaxis de EBNF (ISO/IEC 14977:1996, Wikipedia), con dos modificaciones: 1) Las reglas se definen con ::= en lugar de =.
2) La concatenación se expresa con yuxtaposición en lugar de ,.
Para describir la semántica (es decir, dentro de las secciones "Tipos", "Constantes" y "Operaciones"), usamos fórmulas basadas en la sintaxis de Python extendida con compatibilidad para expresar de forma concisa las operaciones de array, como se describe a continuación. Esto funciona bien para fragmentos de código pequeños, pero, en casos excepcionales en los que se necesitan fragmentos de código más grandes, usamos la sintaxis de Python estándar, que siempre se presenta de forma explícita.
Fórmulas
Exploremos cómo funcionan las fórmulas con un ejemplo de la especificación de dot_general. Una de las restricciones para esta operación se ve de la siguiente manera:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
Los nombres que se usan en esta fórmula provienen de dos fuentes: 1) funciones globales, es decir, dim, y 2) definiciones de miembros del elemento del programa correspondiente, es decir, entradas lhs, lhs_batching_dimensions, rhs y rhs_batching_dimensions definidas en la sección "Entradas" de dot_general.
Como se mencionó anteriormente, la sintaxis de esta fórmula se basa en Python con algunas extensiones orientadas a la concisión. Para comprender la fórmula, transformémosla en sintaxis de Python simple.
A) En estas fórmulas, usamos = para representar la igualdad, por lo que el primer paso para obtener la sintaxis de Python es reemplazar = por ==, como se muestra a continuación:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).
B) Además, estas fórmulas admiten elipses (...) que convierten las expresiones escalares en expresiones de tensor. En pocas palabras, f(xs...) significa aproximadamente "para cada escalar x en el tensor xs, calcula un escalar f(x) y, luego, devuelve todos estos resultados escalares juntos como un resultado de tensor". En la sintaxis de Python estándar, nuestra fórmula de ejemplo se convierte en: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions].
Gracias a los puntos suspensivos, a menudo es posible evitar trabajar a nivel de los escalares individuales. Sin embargo, en algunos casos complejos, se puede usar una sintaxis semi informal de nivel inferior, como en la fórmula start_indices[bi0, ..., :, ..., biN] de la especificación gather. En aras de la concisión, no proporcionamos un formalismo exacto para traducir esa sintaxis a Python estándar, con la esperanza de que siga siendo intuitivamente comprensible caso por caso.
Avísanos si algunas fórmulas específicas te parecen poco claras y trataremos de mejorarlas.
Además, notarás que las fórmulas usan puntos suspensivos para expandir todo tipo de listas, incluidos los tensores, las listas de tensores (que, p.ej., pueden surgir de una cantidad variable de tensores), etcétera. Esta es otra área en la que no proporcionamos un formalismo exacto (p.ej., las listas ni siquiera forman parte del sistema de tipos de StableHLO) y, en cambio, nos basamos en la comprensión intuitiva.
C) El último vehículo notacional notable que empleamos es la transmisión implícita. Si bien el conjunto de operaciones de StableHLO no admite la transmisión implícita, las fórmulas sí lo hacen, también en pos de la concisión. En resumen, si se usa un escalar en un contexto en el que se espera un tensor, el escalar se transmite a la forma esperada.
Para continuar con el ejemplo de dot_general, aquí tienes otra restricción:
0 <= lhs_batching_dimensions < rank(lhs). Según se define en la especificación de dot_general, lhs_batching_dimensions es un tensor, pero 0 y rank(lhs) son escalares. Después de aplicar la transmisión implícita, la fórmula se convertirá en [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].
Cuando se aplica a una operación dot_general en particular, esta fórmula se evaluará como un tensor de valores booleanos. Cuando las fórmulas se usan como restricciones, la restricción se mantiene si la fórmula se evalúa como true o como un tensor que solo tiene elementos true.
Nombres
En las fórmulas, el alcance léxico incluye: 1) funciones globales y 2) definiciones de miembros.
3) Definiciones locales A continuación, se proporciona la lista de funciones globales. La lista de definiciones de elementos depende del elemento del programa al que se aplica la notación:
- En el caso de las operaciones, las definiciones de miembros incluyen los nombres que se introducen en las secciones "Entradas" y "Salidas".
- Para todo lo demás, las definiciones de miembros incluyen partes estructurales del elemento del programa, que se denominan según los no terminales de la EBNF correspondientes. La mayoría de las veces, los nombres de estas partes estructurales se obtienen convirtiendo los nombres de los no terminales a snake case (p.ej.,
IntegerLiteral=>integer_literal), pero, a veces, los nombres se abrevian en el proceso (p.ej.,QuantizationStorageType=>storage_type), en cuyo caso los nombres se introducen explícitamente de manera similar a las secciones "Entradas" y "Salidas" en las especificaciones de operación. - Además, las definiciones de miembros siempre incluyen
selfpara hacer referencia al elemento del programa correspondiente.
Valores
Cuando se evalúan las fórmulas, funcionan con los siguientes tipos de valores:
1) Value (valores reales, p. ej., dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; siempre conocen sus tipos)
2) Placeholder (valores futuros, p. ej., lhs, rhs o result; aún no se conocen sus valores reales, solo sus tipos)
3) Type (tipos según se definen en la sección "Tipos")
4) Function (funciones globales según se definen en la sección "Funciones")
Según el contexto, los nombres pueden hacer referencia a diferentes valores. Más específicamente, la sección "Semantics" para las operaciones (y los equivalentes para otros elementos del programa) define la lógica de tiempo de ejecución, por lo que todas las entradas están disponibles como Value.
En cambio, la sección "Constraints" para las operaciones (y sus equivalentes) define la lógica de "tiempo de compilación", es decir, algo que se ejecuta antes del tiempo de ejecución, por lo que solo las entradas constantes están disponibles como Value y otras entradas solo están disponibles como Placeholder.
| Nombres | En "Semantics" | En "Restricciones" |
|---|---|---|
| Funciones globales | Function |
Function |
| Entradas constantes | Value |
Value |
| Entradas no constantes | Value |
Placeholder |
| Salidas | Value |
Placeholder |
| Definiciones locales | Depende de la definición | Depende de la definición |
Consideremos una operación de transpose como ejemplo:
%result = "stablehlo.transpose"(%operand) {
permutati<on = dens>e[2, 1, 0<] : t>ensor3xi64
}< : (tenso>r2x>3x2xi32<) - tenso>r2x3x2xi32
Para esta operación, permutation es una constante, por lo que está disponible como Value en la semántica y las restricciones. En cambio, operand y result están disponibles como Value en la semántica, pero solo como Placeholder en las restricciones.
Funciones
Construcción de tipos
No hay funciones que se puedan usar para construir tipos. En cambio, usamos directamente la sintaxis de tipos porque suele ser más concisa. p. ej., (tensor<E>, tensor<E>) -> (tensor<E>) en lugar de function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
Funciones en tipos
element_typese define en los tipos de tensores y los tipos de tensores cuantizados, y devuelve, respectivamente, la parteTensorElementTypeoQuantizedTensorElementTypedelTensorTypeoQuantizedTensorTypecorrespondiente.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Valuees un atajo parais_quantized(x) and quantization_dimension(x) is not None.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Valuees un atajo parais_quantized(x) and quantization_dimension(x) is None.is_promotable(x: Type, y: Type) -> boolverifica si el tipoxse puede promover al tipoy. CuandoxyysonQuantizedTensorElementType, la promoción se aplica solo alstorage_type. Actualmente, esta versión específica de la promoción se usa en el contexto del cálculo de la reducción (consulta la RFC para obtener más detalles).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Valuees un acceso directo parais_quantized_tensor_element_type(x).is_type_name(x: Value | Placeholder | Type) -> Value. Disponible para todos los tipos. Por ejemplo,is_float(x)devuelvetruesixes unFloatType. Sixes un valor o un marcador de posición, esta función es un atajo parais_type_name(type(x)).max_value(x: Type) -> Valuedevuelve el valor máximo de unTensorElementType. Sixno es unTensorElementType, devuelveNone.min_value(x: Type) -> Valuedevuelve el valor mínimo posible de unTensorElementType. Sixno es unTensorElementType, devuelveNone.member_name(x: Value | Placeholder | Type) -> Any. Disponible para todas las definiciones de miembrosmember_namede todos los tipos. Por ejemplo,tensor_element_type(x)devuelve la parteTensorElementTypede unTensorTypecorrespondiente. Sixes un valor o un marcador de posición, esta función es un atajo paramember_name(type(x)). Sixno es un tipo que tiene un miembro adecuado, o un valor o un marcador de posición de ese tipo, devuelveNone.is_empty_algorithm(*args: Type)verifica si todos los campos del algoritmo de puntos están configurados comoNone. Esto es necesario, ya que los algoritmos de puntos tienen comportamientos predeterminados definidos por la implementación, por lo que especificar un valor predeterminado sería incorrecto.
Construcción de valores
operation_name(*xs: Value | Type) -> Value. Disponible para todas las operaciones. Por ejemplo,add(lhs, rhs)toma dos valores de tensorlhsyrhs, y devuelve el resultado de evaluar la operaciónaddcon estas entradas. En algunas operaciones, p. ej.,broadcast_in_dim, los tipos de sus resultados son "de carga", es decir, necesarios para evaluar una operación. En este caso, la función toma estos tipos como argumentos.
Funciones sobre valores
Todos los operadores y las funciones de Python están disponibles. Por ejemplo, las notaciones de subscripción y segmentación de Python están disponibles para indexar tensores, tensores cuantificados y tuplas.
to_destination_type(x: Value, destination_type: Type) -> Valuese define en tensores y devuelve el valor convertido dexsegúntype(x)ydestination_typede la siguiente manera:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Se está debatiendo la posibilidad de combinar las operaciones convert, uniform_quantize y uniform_dequantize (#1576).
Después de la combinación, no necesitamos la función anterior y podemos usar el nombre de la operación para convert en su lugar.
is_nan(x: Value) -> Valuese define en tensores y devuelvetruesi todos los elementos dexsonNaNofalseen caso contrario. Sixno es un tensor, devuelveNone.is_sorted(x: Value) -> Valuese define en tensores y devuelvetruesi los elementos dexse ordenan de forma ascendente con respecto al orden lexicográfico ascendente de sus índices ofalseen caso contrario. Sixno es un tensor, devuelveNone.is_unique(x: Value) -> Valuese define en tensores y devuelvetruesixno tiene elementos duplicados ofalseen caso contrario. Sixno es un tensor, devuelveNone.member_name(x: Value) -> Anyse define para todas las definiciones de miembrosmember_namede todos los valores. Por ejemplo,real_part(x)devuelve la parteRealPartde unComplexConstantcorrespondiente. Sixno es un valor que tenga un miembro adecuado, devuelveNone.same(x: Value) -> Valuese define en tensores y devuelvetruesi todos los elementos dexson iguales entre sí ofalseen caso contrario. Si el tensor no tiene elementos, se considera que "todos son iguales entre sí", es decir, la función devuelvetrue. Sixno es un tensor, devuelveNone.split(x: Value, num_results: Value, axis: Value) -> Valuese define en tensores y devuelve segmentos denum_resultsdexa lo largo del ejeaxis. Sixno es un tensor odim(x, axis) % num_results != 0, devuelveNone.is_defined_in_parent_scope(x: Value) -> Valuese define en cadenas y devuelvetruesixes el nombre de una función definida en el mismo alcance que la función principal de la operación pertinente.is_namespaced_op_name(x: Value) -> Valuese define en cadenas y devuelvetruesixes un nombre de operación válido, es decir, si respeta la siguiente expresión regular:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Cálculos de formas
axes(x: Value | Placeholder | Type) -> Valuees un acceso directo pararange(rank(x)).dim(x: Value | Placeholder | Type, axis: Value) -> Valuees un acceso directo parashape(x)[axis].dims(x: Value | Placeholder | Type, axes: List) -> Listes un acceso directo paralist(map(lambda axis: dim(x, axis), axes)).index_space(x: Value | Placeholder | Type) -> Valuese define en tensores y devuelve índicessize(x)para elTensorTypecorrespondiente ordenado en orden lexicográfico ascendente, es decir,[0, ..., 0],[0, ..., 1], …,shape(x) - 1. Sixno es un tipo de tensor, un tipo de tensor cuantificado, o un valor o un marcador de posición de uno de estos tipos, devuelveNone.rank(x: Value | Placeholder | Type) -> Valuees un acceso directo parasize(shape(x)).shape(x: Value | Placeholder | Type) -> Valuese define en la sección "Funciones sobre tipos" a través demember_name.size(x: Value | Placeholder | Type) -> Valuees un acceso directo parareduce(lambda x, y: x * y, shape(x)).
Cálculos de cuantización
def baseline_element_type(x: Value | Placeholder | Type) -> Typees un atajo paraelement_type(baseline_type(x)).baseline_typese define en tipos de tensores y tipos de tensores cuantizados, y los transforma en un "valor de referencia", es decir, un tipo con la misma forma, pero con los parámetros de cuantización del tipo de elemento restablecidos a los valores predeterminados. Esto se usa como un truco útil para comparar los tipos de tensores y tensores cuantificados de manera uniforme, lo que se necesita con bastante frecuencia. Para los tipos cuantificados, esto permite comparar tipos ignorando los parámetros de cuantificación, es decir,shape,storage_type,expressed_type,storage_min,storage_maxyquantization_dimension(para el tipo cuantificado por eje) deben coincidir, peroscalesyzero pointspueden diferir.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantizese define en tipos de tensores cuantificados y los convierte en tipos de tensores de punto flotante. Esto sucede a través de la conversión de elementos cuantificados que representan valores enteros del tipo de almacenamiento en valores de punto flotante correspondientes del tipo expresado con el punto cero y la escala asociados al tipo de elemento cuantificado.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantizese define en los tipos de tensores de punto flotante y los convierte en tipos de tensores cuantizados. Esto sucede a través de la conversión de valores de punto flotante del tipo expresado en valores enteros correspondientes del tipo de almacenamiento, con el punto cero y la escala asociados al tipo de elemento cuantificado.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantizese usa para especificar cálculos por elementos en tensores cuantificados. Descuantiza, es decir, convierte los elementos cuantizados en sus tipos expresados, luego realiza una operación y, luego, cuantiza, es decir, convierte los resultados en sus tipos de almacenamiento. Por el momento, esta función solo funciona para la cuantificación por tensor. La cuantificación por eje es un trabajo en curso (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_opse usa para especificar la cuantización solo de pesos para la operación híbrida que acepta el lado izquierdo en punto flotante y el lado derecho en tipos cuantificados. Este método desantifica las entradas cuantificadas en sus tipos expresados y realiza el cálculo en números de punto flotante. El tipo de elemento del tensor a la izquierda de números de punto flotante y el tipo expresado del tensor a la derecha cuantificado deben ser idénticos.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Cálculos de cuadrícula
cross_partition(replica_groups: Value) -> Value. Consulta la sección "cross_replica" anterior.cross_replica(replica_groups: Value) -> Value. Consulta la sección "cross_replica" anterior.cross_replica_and_partition(replica_groups: Value) -> Value. Consulta la sección "cross_replica_and_partition" anterior.flattened_ids(replica_groups: Value) -> Value. Consulta la sección "flattened_ids" anterior.
Dinamismo
Los valores de StableHLO pueden tener tamaños de dimensión dinámicos, p.ej., tensor<?xi64>.
Sin embargo, los valores de StableHLO no pueden tener una cantidad dinámica de dimensiones (dinamismo sin clasificación, p.ej., tensor<*xi64>). Se permite que los operandos y los resultados usen tamaños de dimensión dinámicos, incluso si hay restricciones en los tamaños. Las restricciones se verificarán de forma estática si es posible; de lo contrario, se diferirán para el tiempo de ejecución y las discrepancias generarán un comportamiento indefinido. Consulta los ejemplos a continuación.
Discrepancias de forma para operaciones unarias a nivel del elemento
Considera el siguiente programa de ejemplo:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Este tipo de programa es inusual, ya que no es común conocer la forma del resultado, pero no la forma de la entrada. Sin embargo, este es un programa de StableHLO válido. No es posible validar de forma estática la operación abs en este programa, ya que se desconoce la forma exacta del operando. Sin embargo, las formas son compatibles, y esto se puede verificar de forma estática: ? podría resultar ser 2 en el tiempo de ejecución, y no habría ningún problema. Sin embargo, ? también podría ser algún otro número entero, en cuyo caso el comportamiento es indefinido.
Ten en cuenta que, si el tamaño de una dimensión es dinámico en el resultado, no puede haber un comportamiento indefinido. De hecho, no hay un tamaño "esperado", por lo que no puede haber una discrepancia.
Discrepancias de forma para operaciones binarias por elemento
Considera el siguiente programa de ejemplo:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
En el caso de las operaciones binarias por elemento, las formas de las entradas y el resultado deben coincidir en el tiempo de ejecución. En el tiempo de compilación, las dimensiones estáticas deben ser iguales; de lo contrario, solo deben ser compatibles. Si alguna dimensión es dinámica en las entradas, podría haber un comportamiento indefinido en el tiempo de ejecución, ya que el tamaño dinámico podría no coincidir con el tamaño correspondiente en el otro operando (ya sea estático o dinámico). Si todas las entradas son estáticas, no importa si el resultado es dinámico o no: las dimensiones conocidas de forma estática se verificarán de forma estática, y las dimensiones dinámicas no imponen ninguna restricción.
Desajustes de forma para las operaciones que toman su forma de salida como un operando
Considera el siguiente programa de ejemplo:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Los valores del operando de forma en el tiempo de ejecución deben coincidir con la forma del resultado; de lo contrario, el comportamiento no está definido. Es decir, en el tiempo de ejecución, %arg0 debe tener un valor de dense<[3, 4]> : tensor<2xi32>. Si el operando de forma es constante, esto se puede verificar de forma estática. Si la forma del resultado es completamente dinámica, no puede haber una falta de coincidencia.