StableHLO es un conjunto de operaciones para operaciones de alto nivel (HLO) en modelos de aprendizaje automático (AA). StableHLO funciona como una capa de portabilidad entre diferentes frameworks de AA y compiladores de AA: los frameworks de AA que producen programas StableHLO son compatibles con los compiladores de AA que consumen programas StableHLO.
Nuestro objetivo es simplificar y acelerar el desarrollo del AA creando una mayor interoperabilidad entre varios frameworks de AA (como TensorFlow, JAX y PyTorch) y compiladores de AA (como IREE y XLA). En ese sentido, este documento proporciona una especificación para el lenguaje de programación StableHLO.
Esta especificación contiene tres secciones principales. Primero, en la sección Programas, se describe la estructura de los programas StableHLO que constan de funciones de StableHLO que, a su vez, constan de operaciones de StableHLO. Dentro de esa estructura, la sección Ops especifica la semántica de las operaciones individuales. La sección Execution proporciona semántica para todas estas operaciones que se ejecutan juntas dentro de un programa. Por último, en la sección Notación, se analiza la notación utilizada en toda la especificación.
Programas
Program ::= {Func}
Los programas StableHLO constan de una cantidad arbitraria de funciones StableHLO.
A continuación, se muestra un programa de ejemplo con una función @main
que tiene 3 entradas (%image
, %weights
y %bias
) y 1 resultado. El cuerpo de la función
tiene 6 operaciones.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funciones
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Las funciones estables (que también se denominan funciones con nombre) tienen un identificador, entradas y salidas, y un cuerpo. En el futuro, planeamos agregar metadatos adicionales para las funciones a fin de lograr una mejor compatibilidad con HLO (#425, #626, #740 y #744).
Identificadores
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Los identificadores estables son similares a los identificadores en muchos lenguajes de programación, con dos peculiaridades: 1) todos los identificadores tienen sigils que distinguen diferentes tipos de identificadores, 2) los identificadores de valores pueden ser completamente numéricos para simplificar la generación de programas StableHLO.
Tipos
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Los tipos de StableHLO se clasifican en tipos de valores (que también se denominan tipos de primera clase), que representan valores de StableHLO y tipos que no son de valores que describen otros elementos del programa. Los tipos StableHLO son similares a los tipos en muchos lenguajes de programación, cuya particularidad principal es la naturaleza específica del dominio de StableHLO, que genera algunos resultados inusuales (p.ej., los tipos escalares no son tipos de valores).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}
Los tipos de tensor representan tensores, es decir, arrays multidimensionales. Tienen una forma y un tipo de elemento, en los que una forma representa tamaños de dimensión no negativos en orden ascendente de las dimensiones correspondientes (que también se denominan ejes) numeradas del 0
al R-1
. La cantidad de dimensiones R
se denomina clasificación. Por ejemplo, tensor<2x3xf32>
es un tipo de tensor con forma 2x3
y tipo de elemento f32
. Tiene dos dimensiones (o, en otras palabras, dos ejes): 0 y 1, cuyos tamaños son 2 y 3. Su clasificación es 2.
Esto define la compatibilidad con formas estáticas en las que los tamaños de dimensión se conocen estáticamente. En el futuro, planeamos agregar compatibilidad con formas dinámicas en las que los tamaños de las dimensiones sean desconocidos de forma parcial o total (#8). Además, planeamos explorar la extensión de los tipos de tensores más allá de los tamaños de dimensión y los tipos de elementos, por ejemplo, para incluir diseños (#629) y dispersión (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nombre | Tipo | Restricciones |
---|---|---|
storage_type |
tipo de número entero | (C1-C4) o C9 |
storage_min |
constante de número entero | (C2), (C4) y (C8) |
storage_max |
constante de número entero | (C3), (C4) y (C8) |
expressed_type |
tipo de punto flotante | (C1), (C5) |
quantization_dimension |
constante de número entero opcional | (C11-C13). |
scales |
número variádico de constantes de punto flotante | (C5-C7), (C10), (C11) y (C13) |
zero_points |
número variádico de constantes de números enteros | (C8-C10) |
Los tipos de elementos cuantificados representan valores de números enteros de un tipo de almacenamiento en el rango de storage_min
a storage_max
(inclusive) que corresponden a valores de punto flotante de un tipo expresado. Para un número entero determinado i
, el valor de punto flotante correspondiente f
se puede calcular como f = (i - zero_point) * scale
, en el que scale
y zero_point
se denominan parámetros de cuantización. Los valores storage_min
y storage_max
son opcionales en la gramática, pero tienen valores predeterminados de min_value(storage_type)
y max_value(storage_type)
, respectivamente. Los tipos de elementos cuantizados tienen las siguientes restricciones:
- (C1)
num_bits(storage_type) < num_bits(expressed_type)
. - (C2)
type(storage_min) = storage_type
. - (C3)
type(storage_max) = storage_type
. - (C4)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
. - (C5)
type(scales...) = expressed_type
. - (C6)
0 < scales
. - (C7)
is_finite(scales...)
. - (C8)
storage_min <= zero_points <= storage_max
. - (C9)
type(zero_points...) = storage_type
. - (C10)
size(scales) = size(zero_points)
. - (C11) Si es
is_empty(quantization_dimension)
, entoncessize(scales) = 1
. - (C12)
0 <= quantization_dimension
.
Por el momento, QuantizationScale
es una constante de punto flotante, pero hay gran interés en las escalas basadas en números enteros, representadas con multiplicadores y cambios. Tenemos pensado explorar este tema en un futuro cercano
(#1404).
Hay un debate en curso sobre la semántica de QuantizationZeroPoint
, que incluye el tipo, los valores y si puede haber solo uno o varios puntos cero en un tipo de tensor cuantificado. En función de los resultados de este análisis, la especificación en torno a los puntos cero podría cambiar en el futuro (#1405).
Otro debate en curso involucra la semántica de QuantizationStorageMin
y QuantizationStorageMax
para determinar si se debe imponer alguna restricción a estos valores y a los de tensores cuantificados (#1406).
Por último, planeamos explorar la representación de escalas desconocidas y puntos cero, de manera similar a cómo planeamos explorar la representación de tamaños de dimensión desconocidos (#1407).
Los tipos de tensores cuantificados representan tensores con elementos cuantificados. Estos tensores son exactamente los mismos que los tensores normales, con la excepción de que sus elementos tienen tipos de elementos cuantificados, en lugar de tipos de elementos regulares.
En los tensores cuantizados, la cuantización puede ser por tensor, es decir, tener un scale
y zero_point
para todo el tensor o puede ser por eje, es decir, tener varios scales
y zero_points
, un par por porción de una dimensión en particular quantization_dimension
. De manera más formal, en un tensor t
con cuantización por eje, hay segmentos dim(t, quantization_dimension)
de quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
, etc. Todos los elementos de la i
.a porción usan scales[i]
y zero_points[i]
como sus parámetros de cuantización. Los tipos de tensores cuantificados tienen las siguientes restricciones:
- Para la cuantización por tensor:
- Sin restricciones adicionales.
- Para la cuantización por eje:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Los tipos de token representan tokens, es decir, valores opacos que producen y consumen algunas operaciones. Los tokens se usan para imponer el orden de ejecución a las operaciones como se describe en la sección Ejecución.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Los tipos de tupla representan tuplas, es decir, listas heterogéneas. Las tuplas son una función heredada que solo existe para ser compatible con HLO. En HLO, las tuplas se usan para representar entradas y salidas variables. En StableHLO, las entradas y salidas variables son compatibles de forma nativa y el único uso de tuplas en StableHLO es representar de manera integral la ABI de HLO, en la que, p.ej., T
, tuple<T>
y tuple<tuple<T>>
pueden ser sustancialmente diferentes según una implementación en particular. En el futuro, planeamos realizar cambios en la ABI de HLO, lo que podría permitirnos quitar los tipos de tuplas de StableHLO (#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Los tipos de elementos representan elementos de tipos de tensores. A diferencia de muchos lenguajes de programación, estos tipos no son de primera clase en StableHLO. Esto significa que los programas estables no pueden representar directamente valores de estos tipos (como resultado, es idiomático representar valores escalares de tipo T
con valores de tensores de 0 dimensiones de tipo tensor<T>
).
- Tipo booleano representa los valores booleanos
true
yfalse
. - Los tipos de número entero pueden ser con firma (
si
) o sin firma (ui
) y tener uno de los anchos de bits admitidos (4
,8
,16
,32
o64
). Los tipos desiN
firmados representan valores de números enteros de-2^(N-1)
a2^(N-1)-1
inclusive, y los tiposuiN
sin firma representan valores de números enteros de0
a2^N-1
inclusive. - Los tipos de punto flotante pueden ser uno de los siguientes:
- Los tipos
f8E4M3FN
yf8E5M2
correspondientes a las codificacionesE4M3
yE5M2
respectivamente del formato FP8 descrito en Formatos FP8 para el aprendizaje profundo. - Los tipos
f8E4M3FNUZ
yf8E5M2FNUZ
correspondientes a las codificacionesE4M3
yE5M2
de los formatos FP8 descritos en Formatos numéricos de 8 bits para redes neuronales profundas. - Es un tipo
f8E4M3B11FNUZ
que corresponde a la codificaciónE4M3
de los formatos FP8 descritos en Inferencia y entrenamiento de punto flotante híbrido de 8 bits (HFP8) para redes neuronales profundas. - Tipo
bf16
correspondiente al formatobfloat16
descrito en BFloat16: El secreto para un alto rendimiento en Cloud TPU. - Los tipos
f16
,f32
yf64
correspondientes a los formatosbinary16
(“precisión media”),binary32
(“precisión simple”) ybinary64
(“precisión doble”) respectivamente que se describen en el estándar IEEE 754
- Los tipos
- Los tipos complejos representan valores complejos que tienen una parte real y una parte imaginaria del mismo tipo de elemento. Los tipos complejos admitidos son
complex<f32>
(ambas partes son del tipof32
) ycomplex<f64>
(ambas partes son del tipof64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Los tipos de funciones representan funciones con nombre y funciones anónimas. Tienen tipos de entrada (la lista de tipos en el lado izquierdo de ->
) y tipos de salida (la lista de tipos a la derecha de ->
). En muchos lenguajes de programación, los tipos de funciones son de primera clase, pero no en StableHLO.
StringType ::= 'string'
Tipo de string representa secuencias de bytes. A diferencia de muchos lenguajes de programación, el tipo de string no es la primera clase en StableHLO y solo se usa para especificar metadatos estáticos para los elementos del programa.
Operaciones
Las operaciones estables (que también se denominan ops) representan un conjunto cerrado de operaciones de alto nivel en modelos de aprendizaje automático. Como se mencionó antes, la sintaxis de StableHLO está inspirada en gran medida en MLIR, que no es necesariamente la alternativa más ergonómica, pero podría decirse que es la mejor opción para el objetivo de StableHLO de crear más interoperabilidad entre los frameworks de AA y los compiladores de AA.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Las operaciones estables (que también se llaman ops) tienen un nombre, entradas y salidas, y una firma. El nombre consta del prefijo stablehlo.
y un nombre mnemotécnico que identifica de forma única una de las operaciones admitidas. Consulta la siguiente lista para obtener una lista completa de todas las operaciones admitidas.
Por el momento, los programas StableHLO en la naturaleza a veces contienen operaciones que no se describen en este documento. En el futuro, planeamos absorber estas operaciones en el conjunto de operaciones StableHLO o prohibir que aparezcan en los programas StableHLO. Mientras tanto, la siguiente es la lista de estas operaciones:
builtin.module
,func.func
,func.call
yfunc.return
(#425).chlo
(#602).- Categoría "No en HLO" de operaciones StableHLO. Inicialmente, eran parte del conjunto de operaciones StableHLO, pero luego se determinó que no encajaban bien:
broadcast
,create_token
,cross-replica-sum
,dot
,einsum
,torch_index_select
,unary_einsum
(#3). - Categoría “Dynamism” de las operaciones StableHLO. Se iniciaron a partir de MHLO, pero aún no las especificamos:
compute_reshape_shape
,cstr_reshapable
,dynamic_broadcast_in_dim
,dynamic_conv
,dynamic_gather
,dynamic_iota
,dynamic_pad
,dynamic_reshape
,real_dynamic_slice
,set_dimension_size
(#8). - Cálculos de forma, incluidas las operaciones
arith
,shape
ytensor
(#8).
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Las operaciones consumen entradas y producen salidas. Las entradas se categorizan en valores de entrada (que se procesan durante la ejecución), funciones de entrada (proporcionadas de forma estática, porque en StableHLO las funciones no son valores de primera clase) y atributos de entrada (también proporcionados de forma estática). El tipo de entradas y salidas que consume y produce una op depende de su nemotécnica. Por ejemplo, la op add
consume 2 valores de entrada y produce 1 valor de salida. En comparación, la op select_and_scatter
consume 3 valores de entrada, 2 funciones de entrada y 3 atributos de entrada.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Las funciones de entrada (también llamadas funciones anónimas) son muy similares a las funciones con nombre, con la excepción de que 1) no tienen un identificador (por eso el nombre es "anónimo"), y 2) no declaran tipos de salida (los tipos de salida se deducen de la operación return
dentro de la función).
La sintaxis de las funciones de entrada incluye una parte que no se usa en este momento (consulta la producción de Unused
más arriba) que brinda compatibilidad con MLIR. En MLIR, existe un concepto más general de “regiones” que pueden tener varios “bloques” de operaciones conectados a través de operaciones de salto. Estos bloques tienen IDs que corresponden a la producción de Unused
, de modo que se puedan distinguir entre sí.
StableHLO no tiene operaciones de salto, por lo que la parte correspondiente de la sintaxis de MLIR no se usa (pero sigue estando allí).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Los atributos de entrada tienen un nombre y un valor que es una de las constantes admitidas. Son la forma principal de especificar metadatos estáticos para elementos del programa. Por ejemplo, la operación concatenate
usa el atributo dimension
para especificar la dimensión con la que se concatenan sus valores de entrada. De manera similar, la op slice
usa varios atributos, como start_indices
y limit_indices
, a fin de especificar los límites que se usan para dividir el valor de entrada.
Por el momento, los programas StableHLO en la naturaleza a veces contienen atributos que no se describen en este documento. En el futuro, planeamos absorber estos atributos en el conjunto de operaciones StableHLO o prohibir que aparezcan en los programas StableHLO. Mientras tanto, esta es la lista de estos atributos:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- Metadatos de ubicación (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La firma de operaciones consta de los tipos de todos los valores de entrada (la lista de tipos en el lado izquierdo de ->
) y los tipos de todos los valores de salida (la lista de tipos en el lado derecho de ->
). En sentido estricto, los tipos de entrada son redundantes y los tipos de salida casi siempre son redundantes (porque para la mayoría de las operaciones StableHLO, los tipos de salida se pueden inferir de las entradas). No obstante, la firma de operaciones forma parte deliberadamente de la sintaxis de StableHLO para brindar compatibilidad con MLIR.
A continuación, se muestra un ejemplo de operación mnemotécnica select_and_scatter
. Consume 3 valores de entrada (%operand
, %source
y %init_value
), 2 funciones de entrada y 3 atributos de entrada (window_dimensions
, window_strides
y padding
). Ten en cuenta que la firma de la op solo incluye los tipos de sus valores de entrada (pero no los tipos de funciones y atributos de entrada que se proporcionan intercalados).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Las constantes StableHLO tienen un literal y un tipo que, en conjunto, representan un valor StableHLO. Por lo general, el tipo forma parte de la sintaxis constante, excepto cuando no es ambiguo (p.ej., una constante booleana tiene el tipo i1
sin ninguna ambigüedad, mientras que una constante de número entero puede tener varios tipos posibles).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Las constantes booleanas representan valores booleanos true
y false
. Las constantes booleanas tienen el tipo i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Las constantes de números enteros representan valores de números enteros a través de strings que usan notación decimal o hexadecimal. No se admiten otras bases, p.ej., binaria u octal. Las constantes de número entero tienen las siguientes restricciones:
- (C1)
is_wellformed(integer_literal, integer_type)
.
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Las constantes de punto flotante representan valores de punto flotante a través de cadenas que usan notación decimal o científica. Además, la notación hexadecimal se puede usar para especificar directamente los bits subyacentes en el formato de punto flotante del tipo correspondiente. Las constantes de punto flotante tienen las siguientes restricciones:
- (C1) Si se usa la notación no hexadecimal,
is_wellformed(float_literal, float_type)
. - (C2) Si se usa la notación hexadecimal,
size(hexadecimal_digits) = num_bits(float_type) / 4
.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Las constantes complejas representan valores complejos mediante listas de una parte real (va primero) y una parte imaginaria (va por segundo). Por ejemplo, (1.0, 0.0) : complex<f32>
representa a 1.0 + 0.0i
y (0.0, 1.0) : complex<f32>
representa 0.0 + 1.0i
. El orden en el que estas partes se almacenan en la memoria está definido por la implementación. Las constantes complejas tienen las siguientes restricciones:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
. - (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
.
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Las constantes tensoriales representan valores de tensor mediante listas anidadas especificadas a través de notación NumPy. Por ejemplo, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
representa un valor de tensor con la siguiente asignación de índices a elementos: {0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
, {1, 1} => 5
y {1, 2} => 6
. El orden en el que estos elementos se almacenan en la memoria se define por la implementación. Las constantes de Tensor tienen las siguientes restricciones:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, donde:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, donde:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- de lo contrario,
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Las constantes de tensor cuantificadas representan valores de tensor cuantificados que usan la misma notación que las constantes del tensor, con elementos especificados como constantes de su tipo de almacenamiento. Las constantes de tensor cuantificadas tienen las siguientes restricciones:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
. - (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
.
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Los literales de string constan de bytes especificados mediante caracteres ASCII y secuencias de escape. Dado que son independientes de la codificación, la interpretación de estos bytes está definida por la implementación. Los literales de string tienen el tipo string
.
Ops
abs
Semántica
Realiza una operación abs a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros con firma: módulo de números enteros.
- Para números de punto flotante:
abs
de IEEE-754. - Para números complejos: módulo complejo.
- Para tipos cuantizados:
dequantize_op_quantize(abs, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de número entero con signo, punto flotante o tipo complejo, o tensor cuantificado por tensor | (C1-C2). |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de número entero con signo, tipo de punto flotante o tensor cuantificado por tensor | (C1-C2). |
Restricciones
- (C1)
shape(result) = shape(operand)
. - (C2)
baseline_element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
si esis_complex(operand)
.- De lo contrario,
baseline_element_type(operand)
.
Ejemplos
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Semántica
Realiza una adición a nivel de elementos de dos tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: suma de números enteros.
- Para números de punto flotante:
addition
de IEEE-754. - Para números complejos: suma compleja.
- Para tipos cuantizados:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor o por tensor cuantizado | (C1) |
(I2). | rhs |
tensor o por tensor cuantizado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1) |
Restricciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semántica
Garantiza que las operaciones que producen el inputs
se ejecuten antes de cualquier operación que dependa de result
. La ejecución de esta operación no hace nada; solo existe para establecer dependencias de datos de result
a inputs
.
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1). | inputs |
número variable de token |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Ejemplos
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, concatena los valores del tensor operand
de cada proceso junto con all_gather_dim
y produce un tensor result
.
La operación divide la cuadrícula de procesos de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
si eschannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
si eschannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
operands@receiver = [operand@sender for sender in process_group]
para todos losreceiver
enprocess_group
.result@process = concatenate(operands@process, all_gather_dim)
para todos losprocess
enprocess_group
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1), (C6) |
(I2). | all_gather_dim |
constante de tipo si64 |
(C1), (C6) |
(I3). | replica_groups |
Constante tensorial bidimensional de tipo si64 |
(C2-C4). |
(I4). | channel_id |
constante de tipo si64 |
(C5) |
(I5). | use_global_device_ids |
constante de tipo i1 |
(C5) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C6) |
Restricciones
- (C1)
0 <= all_gather_dim < rank(operand)
. - (C2)
is_unique(replica_groups)
. - (C3)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C4)
0 <= replica_groups < size(replica_groups)
. - (C5) Si es
use_global_device_ids = true
, entonceschannel_id > 0
. - (C6)
type(result) = type(operand)
, excepto:dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1)
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
all_reduce
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, aplica una función de reducción computation
a los valores del tensor operand
de cada proceso y produce un tensor result
.
La operación divide la cuadrícula de procesos de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
si eschannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
si eschannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
result@process[result_index] = exec(schedule)
para algún árbol binarioschedule
, en el que:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
es un árbol binario definido por la implementación cuyo recorrido en orden esto_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0]))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C5) y (C6) |
(I2). | replica_groups |
número variable de constantes de tensor unidimensionales de tipo si64 |
(C1-C3). |
(I3). | channel_id |
constante de tipo si64 |
(C4) |
(I4). | use_global_device_ids |
constante de tipo i1 |
(C4) |
(I5). | computation |
la función | (C5) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C6-C7). |
Restricciones
- (C1)
is_unique(replica_groups)
. - (C2)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C3)
0 <= replica_groups < size(replica_groups)
. - (C4) Si es
use_global_device_ids = true
, entonceschannel_id > 0
. - (C5)
computation
tiene el tipo(tensor<E>, tensor<E>) -> (tensor<E>)
, en el queis_promotable(element_type(operand), E)
. - (C6)
shape(result) = shape(operand)
. - (C7)
element_type(result) = E
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]
all_to_all
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, se dividen los valores del tensor operand
a lo largo de split_dimension
en partes, dispersa las partes divididas entre los procesos, concatena las partes dispersas junto con concat_dimension
y produce un tensor result
.
La operación divide la cuadrícula de procesos de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0
.cross_partition(replica_groups)
si eschannel_id > 0
.
Luego, dentro de cada process_group
, haz lo siguiente:
split_parts@sender = split(operand@sender, split_count, split_dimension)
para todos lossender
enprocess_group
.scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group]
, dondereceiver_index = process_group.index(receiver)
.result@process = concatenate(scattered_parts@process, concat_dimension)
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1-C3) o C9 |
(I2). | split_dimension |
constante de tipo si64 |
(C1), (C2) y (C9) |
(I3). | concat_dimension |
constante de tipo si64 |
(C3) y (C9) |
(I4). | split_count |
constante de tipo si64 |
(C2), (C4), (C8) y (C9) |
(I5). | replica_groups |
Constante tensorial bidimensional de tipo si64 |
(C5-C8). |
(I6). | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C9) |
Restricciones
- (C1)
0 <= split_dimension < rank(operand)
. - (C2)
dim(operand, split_dimension) % split_count = 0
. - (C3)
0 <= concat_dimension < rank(operand)
. - (C4)
0 < split_count
. - (C5)
is_unique(replica_groups)
. - (C6)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
. - (C9)
type(result) = type(operand)
, excepto:dim(result, split_dimension) = dim(operand, split_dimension) / split_count
.dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
// [5, 6],
// [9, 10],
// [13, 14]]
// %result@(1, 0): [[3, 4],
// [7, 8],
// [11, 12],
// [15, 16]]
y
Semántica
Realiza el operador AND a nivel de elementos de dos tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: lógico AND.
- Para números enteros: AND a nivel de bits.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de tipo booleano o número entero | (C1) |
(I2). | rhs |
tensor de tipo booleano o número entero | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo booleano o número entero | (C1) |
Restricciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semántica
Realiza la operación atan2 a nivel de los elementos en los tensor lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
atan2
de IEEE-754. - Para números complejos: atan2 complejo.
- Para tipos cuantizados:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2). | rhs |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semántica
Calcula los gradientes de varias entradas de batch_norm_training
que se propagan hacia atrás desde grad_output
y produce los tensores grad_operand
, grad_scale
y grad_offset
. De manera más formal, esta operación se puede expresar como una descomposición de operaciones StableHLO existentes mediante la sintaxis de Python de la siguiente manera:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Para tipos cuantizados, realiza dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C1-C3) o C5 |
(I2). | scale |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2), (C4) y (C5) |
(I3). | mean |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
(I4). | variance |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
(I5). | grad_output |
tensor de tipo de punto flotante o por tensor cuantificado | (C2) y (C3) |
(I6). | epsilon |
constante de tipo f32 |
|
(I7). | feature_index |
constante de tipo si64 |
(C1), (C5) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
grad_operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C2) y (C3) |
grad_scale |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
grad_offset |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
Restricciones
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
ygrad_offset
tienen el mismobaseline_element_type
. - (C3)
operand
,grad_output
ygrad_operand
tienen la misma forma. - (C4)
scale
,mean
,variance
,grad_scale
ygrad_offset
tienen la misma forma. - (C5)
size(scale) = dim(operand, feature_index)
.
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semántica
Normaliza el tensor operand
en todas las dimensiones, excepto feature_index
, y produce un tensor result
. De manera más formal, esta operación se puede expresar como una descomposición de operaciones StableHLO existentes mediante la sintaxis de Python de la siguiente manera:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Para tipos cuantizados, realiza dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C1-C7). |
(I2). | scale |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C3) |
(I3). | offset |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
(I4). | mean |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C5) |
(I5). | variance |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C6) |
(I6). | epsilon |
constante de tipo f32 |
|
(I7). | feature_index |
constante de tipo si64 |
(C1) o (C3-C6) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de punto flotante o por tensor cuantificado | (C2) y (C7) |
Restricciones
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,mean
,variance
yresult
tienen el mismobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(mean) = dim(operand, feature_index)
. - (C6)
size(variance) = dim(operand, feature_index)
. - (C7)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semántica
Calcula la media y la varianza de todas las dimensiones, excepto la dimensión feature_index
, y normaliza el tensor operand
, lo que produce los tensores output
, batch_mean
y batch_var
. De manera más formal, esta operación se puede expresar como una descomposición a operaciones StableHLO existentes mediante la sintaxis de Python de la siguiente manera:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Para tipos cuantizados, realiza dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
(I2). | scale |
Tensor unidimensional de punto flotante o por tensor cuantificado | (C2) y (C3) |
(I3). | offset |
Tensor unidimensional de punto flotante o por tensor cuantificado | (C2) y (C4) |
(I4). | epsilon |
constante de tipo f32 |
(C1) o (C3-C6) |
(I5). | feature_index |
constante de tipo si64 |
(C1) o (C3-C6) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
output |
tensor de tipo de punto flotante o por tensor cuantificado | (C7) |
batch_mean |
Tensor unidimensional de punto flotante o por tensor cuantificado | (C2) y (C5) |
batch_var |
Tensor unidimensional de punto flotante o por tensor cuantificado | (C2) y (C6) |
Restricciones
- (C1)
0 <= feature_index < rank(operand)
. - (C2)
operand
,scale
,offset
,batch_mean
,batch_var
youtput
tienen el mismobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
. - (C4)
size(offset) = dim(operand, feature_index)
. - (C5)
size(batch_mean) = dim(operand, feature_index)
. - (C6)
size(batch_var) = dim(operand, feature_index)
. - (C7)
baseline_type(output) = baseline_type(operand)
.
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semántica
Realiza una operación de transmisión de bits en el tensor operand
y produce un tensor result
en el que los bits de todo el tensor operand
se vuelven a interpretar con el tipo de tensor result
.
De manera más formal, con E = element_type(operand)
, E' = element_type(result)
y R = rank(operand)
:
- Si es
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
. - Si es
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
. - Si es
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
.
bits
muestra la representación en la memoria de un valor determinado, y su comportamiento está definido por la implementación porque la representación exacta de los tensores está definida por la implementación, y la representación exacta de los tipos de elementos también está definida por la implementación.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o tensor cuantificado | (C1-C2). |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C2). |
Restricciones
- (C1) Dados
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
yR = rank(operand)
:- Si es
num_bits(E') = num_bits(E)
, esshape(result) = shape(operand)
. - Si es
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
para todos los0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Si es
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
para todos los0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Si es
- (C2) Si es
is_complex(operand) or is_complex(result)
, entoncesis_complex(operand) and is_complex(result)
.
Ejemplos
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semántica
Expande las dimensiones o la clasificación de un tensor de entrada mediante la duplicación de los datos en el tensor operand
y produce un tensor result
. De manera más formal, result[result_index] = operand[operand_index]
, donde, para todos los d
en axes(operand)
:
operand_index[d] = 0
si esdim(operand, d) = 1
.- De lo contrario,
operand_index[d] = result_index[broadcast_dimensions[d]]
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o tensor cuantificado | (C1-C2) o (C5-C6) |
(I2). | broadcast_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C2-C6). |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3) y (C5-C6) |
Restricciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si es!is_per_axis_quantized(operand)
.element_type(operand)
, con la excepción de quequantization_dimension(operand)
,scales(operand)
yzero_points(operand)
pueden diferir dequantization_dimension(result)
,scales(result)
yzero_points(result)
en caso contrario.
- (C2)
size(broadcast_dimensions) = rank(operand)
. - (C3)
0 <= broadcast_dimensions < rank(result)
. - (C4)
is_unique(broadcast_dimensions)
. - (C5) Para todos los
d
enaxes(operand)
:dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si es
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si es
dim(operand, quantization_dimension(operand)) = 1
, entoncesscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
.
Ejemplos
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
caso
Semántica
Genera el resultado a partir de la ejecución de exactamente una función de branches
según el valor de index
. De manera más formal, result = selected_branch()
, en el que sucede lo siguiente:
selected_branch = branches[index]
si es0 <= index < size(branches)
.- De lo contrario,
selected_branch = branches[-1]
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | index |
Tensor de 0 dimensiones de tipo si32 |
|
(I2). | branches |
número variable de funciones | (C1-C4). |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variable de tensores, tensores cuantificados o tokens | (C4) |
Restricciones
- (C1)
0 < size(branches)
. - (C2)
input_types(branches...) = []
. - (C3)
same(output_types(branches...))
. - (C4)
type(results...) = output_types(branches[0])
.
Ejemplos
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CTR
Semántica
Realiza una operación de raíz cúbica a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
rootn(x, 3)
de IEEE-754. - Para números complejos: raíz cúbica compleja.
- Para tipos cuantizados:
dequantize_op_quantize(cbrt, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
ceil
Semántica
Ejecuta el ceil a nivel de elementos del tensor operand
y produce un tensor result
.
Implementa la operación roundToIntegralTowardPositive
de la especificación IEEE-754. Para tipos cuantizados, realiza dequantize_op_quantize(ceil, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
colectivo
Semántica
Calcula la descomposición de Cholesky de un lote de matrices.
De manera más formal, para todo i
en index_space(result)
, result[i0, ..., iR-3, :, :]
es una descomposición de Cholesky de a[i0, ..., iR-3, :, :]
en forma de una matriz triangular inferior (si lower
es true
) o una matriz triangular superior (si lower
es false
).
Los valores de salida del triángulo opuesto, es decir, el triángulo superior estricto o al triángulo inferior estricto, respectivamente, están definidos por la implementación.
Si existe i
en la que la matriz de entrada no es una matriz ermitiana definida de forma positiva, entonces el comportamiento es indefinido.
Para tipos cuantizados, realiza dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | a |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1-C3). |
(I2). | lower |
Constante tensorial de 0 dimensiones de tipo i1 |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(a) = baseline_type(result)
. - (C2)
2 <= rank(a)
. - (C3)
dim(a, -2) = dim(a, -1)
.
Ejemplos
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
restringir
Semántica
Sujeta cada elemento del tensor operand
entre un valor mínimo y máximo, y produce un tensor result
. De manera más formal, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
, donde min_element = rank(min) = 0 ? min[] : min[result_index]
, max_element = rank(max) = 0 ? max[] : max[result_index]
. Para tipos cuantizados, realiza dequantize_op_quantize(clamp, min, operand, max, type(result))
.
La imposición de un orden en números complejos implica una semántica sorprendente, por lo que en el futuro planeamos quitar la compatibilidad con números complejos para esta operación (#560).
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | min |
tensor o por tensor cuantizado | (C1), (C3) |
(I2). | operand |
tensor o por tensor cuantizado | (C1-C4). |
(I3). | max |
tensor o por tensor cuantizado | (C2) y (C3) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C4) |
Restricciones
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
. - (C2)
rank(max) = 0 or shape(max) = shape(operand)
. - (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
. - (C4)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, envía el valor del tensor operand
del proceso de origen a los procesos de destino y produce un tensor result
.
La operación divide la cuadrícula de procesos de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0
.cross_partition(replica_groups)
si eschannel_id > 0
.
Luego, el valor de result@process
se obtiene de la siguiente manera:
- Es
operand@process_groups[i, 0]
si existe uni
tal que el proceso esté enprocess_groups[i]
. broadcast_in_dim(constant(0, element_type(result)), [], type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor | (C3) |
(I2). | replica_groups |
número variable de constantes de tensor unidimensionales de tipo si64 |
(C1), (C2) |
(I3). | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor | (C3) |
Restricciones
- (C1)
is_unique(replica_groups)
. - (C2)
0 <= replica_groups < N
, en el queN
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C3)
type(result) = type(operand)
.
Ejemplos
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, envía el valor del tensor operand
del proceso de origen al proceso de destino y produce un tensor result
.
La operación divide la cuadrícula de procesos de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(source_target_pairs)
si eschannel_id <= 0
.cross_partition(source_target_pairs)
si eschannel_id > 0
.
Luego, el valor de result@process
se obtiene de la siguiente manera:
operand@process_groups[i, 0]
, si existe uni
tal queprocess_groups[i, 1] = process
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C5) |
(I2). | source_target_pairs |
Constante tensorial bidimensional de tipo si64 |
(C1-C4). |
(I3). | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1) |
Restricciones
- (C1)
dim(source_target_pairs, 1) = 2
. - (C2)
is_unique(source_target_pairs[:, 0])
. - (C3)
is_unique(source_target_pairs[:, 1])
. - (C4)
0 <= source_target_pairs < N
, en la queN
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C5)
type(result) = type(operand)
.
Ejemplos
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
compare
Semántica
Realiza una comparación con elementos de los tensores lhs
y rhs
de acuerdo con comparison_direction
y compare_type
, y produce un tensor result
.
Los valores de comparison_direction
y compare_type
tienen la siguiente semántica:
Para los tipos de elementos booleanos y enteros:
EQ
:lhs = rhs
NE
:lhs != rhs
GE
:lhs >= rhs
GT
:lhs > rhs
LE
:lhs <= rhs
LT
:lhs < rhs
Para los tipos de elementos de punto flotante con compare_type = FLOAT
, la op implementa las siguientes operaciones IEEE-754:
EQ
:compareQuietEqual
NE
:compareQuietNotEqual
GE
:compareQuietGreaterEqual
GT
:compareQuietGreater
LE
:compareQuietLessEqual
LT
:compareQuietLess
Para los tipos de elementos de punto flotante con compare_type = TOTALORDER
, la operación usa la combinación de operaciones totalOrder
y compareQuietEqual
de IEEE-754. Esta función parece no estar en uso, por lo que, en el futuro, planeamos quitarla (#584).
Para los tipos de elementos complejos, la comparación lexicográfica de los pares (real, imag)
se realiza con los comparison_direction
y compare_type
proporcionados.
La imposición de un orden en números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con números complejos cuando comparison_direction
sea GE
, GT
, LE
o LT
(#560).
Para tipos cuantizados, se realiza dequantize_compare(lhs, rhs,
comparison_direction)
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor o por tensor cuantizado | (C1-C3). |
(I2). | rhs |
tensor o por tensor cuantizado | (C1-C2). |
(I3). | comparison_direction |
enum de EQ , NE , GE , GT , LE y LT |
|
(I4). | compare_type |
enum de FLOAT , TOTALORDER , SIGNED y UNSIGNED |
(C3) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo booleano | (C2) |
Restricciones
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
. - (C2)
shape(lhs) = shape(rhs) = shape(result)
. - (C3)
compare_type
se define de la siguiente manera:SIGNED
si esis_signed_integer(element_type(lhs))
.UNSIGNED
si esis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
oTOTALORDER
si esis_float(element_type(lhs))
.FLOAT
si esis_complex(element_type(lhs))
.
Ejemplos
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complejo
Semántica
Realiza la conversión a nivel de los elementos en un valor complejo a partir de un par de valores imaginarios y reales, lhs
y rhs
, y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de tipo f32 o f64 |
(C1-C3). |
(I2). | rhs |
tensor de tipo f32 o f64 |
(C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo complejo | (C2) y (C3) |
Restricciones
- (C1)
type(lhs) = type(rhs)
. - (C2)
shape(result) = shape(lhs)
. - (C3)
element_type(result)
tiene el tipocomplex<E>
, dondeE = element_type(lhs)
.
Ejemplos
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
concatenate
Semántica
Concatena inputs
a lo largo de la dimensión dimension
en el mismo orden que los argumentos dados y produce un tensor result
. De manera más formal, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, en el que sucede lo siguiente:
id = d0 + ... + dk-1 + kd
.d
es igual adimension
, yd0
, ... son los tamaños ded
a dimensión deinputs
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C6). |
(I2). | dimension |
constante de tipo si64 |
(C2), (C4) y (C6) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C5-C6). |
Restricciones
- (C1)
same(element_type(inputs...))
. - (C2)
same(shape(inputs...))
, exceptodim(inputs..., dimension)
. - (C3)
0 < size(inputs)
. - (C4)
0 <= dimension < rank(inputs[0])
. - (C5)
element_type(result) = element_type(inputs[0])
. - (C6)
shape(result) = shape(inputs[0])
, excepto para:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Ejemplos
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Semántica
Produce un tensor output
a partir de una constante value
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | value |
constante | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
output |
tensor o tensor cuantificado | (C1) |
Restricciones
- (C1)
type(value) = type(output)
.
Ejemplos
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
generar una conversión
Semántica
Realiza una conversión a nivel de elementos de un tipo de elemento a otro en el tensor operand
y produce un tensor result
.
Para las conversiones de boolean-to-any-supported-type, el valor false
se convierte en cero, y el valor true
se convierte en uno. Para las conversiones any-supported-type-to-boolean, un valor cero se convierte en false
y los valores distintos de cero se convierten en true
. A continuación, se muestra cómo funciona esto en tipos complejos.
En el caso de las conversiones que incluyen integer-to-integer, integer-to-floating-point o floating-point-to-floating-point, si el valor de origen se puede representar con exactitud en el tipo de destino, el valor del resultado es esa representación exacta. De lo contrario, el comportamiento está por definir (#180).
En el caso de las conversiones que utilizan floating-point-to-integer, la parte fraccionaria se trunca. Si el valor truncado no se puede representar en el tipo de destino, el comportamiento se define por definir (#180).
Las conversiones que implican complejo a complejo siguen el mismo comportamiento de las conversiones de punto flotante a punto flotante para convertir partes reales y imaginarias.
Para las conversiones de complex-to-any-other-type y complex-to-any-other-type, se ignora el valor imaginario de origen o el valor imaginario de destino se pone en cero, respectivamente. La conversión de la parte real sigue a las conversiones de punto flotante.
En principio, esta operación podría expresar descuantización (conversión de tensores cuantificados a tensores regulares), cuantización (conversión de tensores regulares a tensores cuantificados) y recuantización (conversión entre tensores cuantificados), pero, por el momento, tenemos operaciones dedicadas: uniform_dequantize
para el primer caso de uso y uniform_quantize
para el segundo y el tercero. En el futuro, estas dos operaciones se pueden combinar en convert
(#1576).
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor | (C1) |
Restricciones
- (C1)
shape(operand) = shape(result)
.
Ejemplos
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolución
Semántica
Calcula productos de puntos entre las ventanas de lhs
y porciones de rhs
y produce result
. En el siguiente diagrama, se muestra cómo se calculan los elementos en result
a partir de lhs
y rhs
con un ejemplo concreto.
De manera más formal, considera el siguiente reencuadre de las entradas en términos de lhs
para poder expresar ventanas de lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Este reencuadre usa las siguientes funciones auxiliares:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
, dondej[d] = i[permutation[d]]
.
Si es feature_group_count = 1
y batch_group_count = 1
, para todos los output_spatial_index
en index_space(dim(result, output_spatial_dimensions...))
, result[result_shape(:, output_spatial_index, :)] = dot_product
, en el que sucede lo siguiente:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
. Parece que esta función no se usa, por lo que planeamos quitarla (#1181) en el futuro.dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Si es feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Si es batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Para tipos cuantizados, se realiza dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor o por tensor cuantizado | (C1), (C10-C11), (C14) (C25), (C27-C30) |
(I2). | rhs |
tensor o tensor cuantificado | (C1), (C14-C16), (C25) y (C27-C32) |
(I3). | window_strides |
Constante tensorial unidimensional de tipo si64 |
(C2-C3) o (C25) |
(I4). | padding |
Constante tensorial bidimensional de tipo si64 |
(C4) o (C25) |
(I5). | lhs_dilation |
Constante tensorial unidimensional de tipo si64 |
(C5-C6) o (C25) |
(I6). | rhs_dilation |
Constante tensorial unidimensional de tipo si64 |
(C7-C8) o (C25) |
(I7). | window_reversal |
Constante tensorial unidimensional de tipo i1 |
(C9) |
(I8). | input_batch_dimension |
constante de tipo si64 |
(C10), (C13) y (C25) |
(I9). | input_feature_dimension |
constante de tipo si64 |
(C11) o (C13-C14) |
(I10). | input_spatial_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C12), (C13) y (C25) |
(I11). | kernel_input_feature_dimension |
constante de tipo si64 |
(C14) o (C18) |
(I12). | kernel_output_feature_dimension |
constante de tipo si64 |
(C15-C16), (C18), (C25), (C32) |
(I13). | kernel_spatial_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C17-C18) o (C25) |
(I14). | output_batch_dimension |
constante de tipo si64 |
(C20) o C25 |
(I15). | output_feature_dimension |
constante de tipo si64 |
(C20), (C25) y (C33) |
(I16). | output_spatial_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C19-C20) o (C25) |
(I17). | feature_group_count |
constante de tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18). | batch_group_count |
constante de tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19). | precision_config |
cantidad variable de enumeraciones de DEFAULT , HIGH y HIGHEST |
(C24) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o tensor cuantificado | (C25-C28), (C30-C31) y (C33) |
Restricciones
- (C1)
N = rank(lhs) = rank(rhs)
. - (C2)
size(window_strides) = N - 2
. - (C3)
0 < window_strides
. - (C4)
shape(padding) = [N - 2, 2]
. - (C5)
size(lhs_dilation) = N - 2
. - (C6)
0 < lhs_dilation
. - (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
. - (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
. - (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
se define de la siguiente manera:dim(lhs, input_batch_dimension) / batch_group_count
si esresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
si esresult_dim = output_feature_dimension
.- De lo contrario,
num_windows
, donde: output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si la operación usa tensores cuantificados, haz lo siguiente:
- (C28)
is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result)
. - (C29)
storage_type(lhs) = storage_type(rhs)
. - (C30)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C31) Si es
is_per_tensor_quantized(rhs)
, entoncesis_per_tensor_quantized(result)
. - (C32) Si es
is_per_axis_quantized(rhs)
, entoncesquantization_dimension(rhs) = kernel_output_feature_dimension
. - (C33) Si es
is_per_axis_quantized(result)
, entoncesquantization_dimension(result) = output_feature_dimension
.
- (C28)
Ejemplos
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs : [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = dense<4> : tensor<2xi64>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = dense<2> : tensor<2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_reversal = dense<false> : tensor<2xi1>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
coseno
Semántica
Realiza la operación de coseno a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
cos
de IEEE-754. - Para números complejos: coseno complejo.
- Para tipos cuantizados:
dequantize_op_quantize(cosine, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semántica
Realiza un recuento en elementos del número de bits cero iniciales en el tensor operand
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de número entero | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de número entero | (C1) |
Restricciones
- (C1)
type(operand) = type(result)
.
Ejemplos
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semántica
Encapsula una operación call_target_name
definida por la implementación que toma inputs
y called_computations
, y produce results
. Se pueden usar has_side_effect
, backend_config
y api_version
para proporcionar metadatos adicionales definidos por la implementación.
Por el momento, esta operación contiene una colección de metadatos bastante desorganizada que refleja la evolución orgánica de su operación equivalente en el compilador XLA. En el futuro, planeamos unificar estos metadatos (#741).
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1). | inputs |
número variable de valores |
(I2). | call_target_name |
constante de tipo string |
(I3). | has_side_effect |
constante de tipo i1 |
(I4). | backend_config |
constante de tipo string |
(I5). | api_version |
constante de tipo si32 |
(I6). | called_computations |
número variable de constantes de tipo string |
Salidas
Nombre | Tipo |
---|---|
results |
número variable de valores |
Ejemplos
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = "bar",
api_version = 1 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
dividir
Semántica
Realiza la división a nivel de los elementos de los tensores lhs
y divisor rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: división de números enteros que produce el cociente algebraico con cualquier parte fraccionaria descartada.
- Para números de punto flotante:
division
de IEEE-754. - Para números complejos: división compleja.
- Para tipos cuantizados:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2). | rhs |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semántica
Calcula productos de puntos entre porciones de lhs
y porciones de rhs
, y produce un tensor result
.
De manera más formal, result[result_index] = dot_product
, donde:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
, dondesize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
ysize(result_rhs_index) = size(rhs_result_dimensions)
.transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
.
Para tipos cuantizados, se realiza dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Esto solo especifica la semántica para la cuantización por tensor. La cuantización por eje está en desarrollo (#1574). Además, en el futuro, es posible que consideremos agregar compatibilidad con la cuantización híbrida (#1575).
precision_config
controla la compensación entre velocidad y precisión para los cálculos en backends del acelerador. Puede ser una de las siguientes opciones (por el momento, la semántica de estos valores enum no se especifica de forma adecuada, pero planeamos abordar esto en #755):
DEFAULT
: Es el cálculo más rápido, pero la aproximación menos precisa al número original.HIGH
: Es un cálculo más lento, pero una aproximación más precisa al número original.HIGHEST
: Es el cálculo más lento, pero la aproximación más precisa al número original.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor o por tensor cuantizado | (C5-C6), (C9-C10) o (C12-C16) |
(I2). | rhs |
tensor o por tensor cuantizado | (C7-C10) o C12 |
(I3). | lhs_batching_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C1), (C3), (C5), (C9) y (C12) |
(I4). | rhs_batching_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C1), (C4), (C7) y (C9) |
(I5). | lhs_contracting_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C2), (C3), (C6) y (C10) |
(I6). | rhs_contracting_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C2), (C4), (C8) y (C10) |
(I7). | precision_config |
cantidad variable de enumeraciones de DEFAULT , HIGH y HIGHEST |
(C11) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C12), (C14) y (C16) |
Restricciones
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
. - (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
. - (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
. - (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
. - (C5)
0 <= lhs_batching_dimensions < rank(lhs)
. - (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
. - (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
. - (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
. - (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Si la operación usa tensores no cuantificados, haz lo siguiente:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Si la operación usa tensores cuantificados, haz lo siguiente:
- (C14)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
. - (C15)
storage_type(lhs) = storage_type(rhs)
. - (C16)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C17)
zero_points(rhs) = 0
.
- (C14)
Ejemplos
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_slice
Semántica
Extrae una porción de operand
mediante índices de inicio calculados de forma dinámica y produce un tensor result
. start_indices
contiene los índices iniciales de la porción para cada dimensión sujeta a un posible ajuste, y slice_sizes
contiene los tamaños de la porción para cada dimensión. De manera más formal, result[result_index] = operand[operand_index]
, en el que sucede lo siguiente:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1), (C2) y (C4) |
(I2). | start_indices |
número variádico de tensores de 0 dimensiones de tipo de número entero | (C2) y (C3) |
(I3). | slice_sizes |
Constante tensorial unidimensional de tipo si64 |
(C2), (C4) y (C5) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1), (C5) |
Restricciones
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
. - (C3)
same(type(start_indices...))
. - (C4)
0 <= slice_sizes <= shape(operand)
. - (C5)
shape(result) = slice_sizes
.
Ejemplos
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semántica
Produce un tensor result
que es igual al tensor operand
, excepto que la porción que comienza en start_indices
se actualiza con los valores en update
.
De manera más formal, result[result_index]
se define de la siguiente manera:
update[update_index]
si es0 <= update_index < shape(update)
, donde:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- De lo contrario,
operand[result_index]
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1-C4) o (C6) |
(I2). | update |
tensor o por tensor cuantizado | (C2), (C3) y (C6) |
(I3). | start_indices |
número variádico de tensores de 0 dimensiones de tipo de número entero | (C4), (C5) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1) |
Restricciones
- (C1)
type(operand) = type(result)
. - (C2)
element_type(update) = element_type(operand)
. - (C3)
rank(update) = rank(operand)
. - (C4)
size(start_indices) = rank(operand)
. - (C5)
same(type(start_indices...))
. - (C6)
0 <= shape(update) <= shape(operand)
.
Ejemplos
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponencial
Semántica
Realiza una operación exponencial a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
exp
de IEEE-754. - Para números complejos: exponencial compleja.
- Para tipos cuantizados:
dequantize_op_quantize(exponential, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semántica
Realiza exponenciales a nivel de elementos menos una operación en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
expm1
de IEEE-754. - Para números complejos: exponencial compleja menos uno.
- Para tipos cuantizados:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
FFT
Semántica
Realiza las transformaciones inversas y directas de Fourier para entradas y salidas reales y complejas.
fft_type
es una de las siguientes opciones:
FFT
: Reenvía la FFT compleja a compleja.IFFT
: FFT de complejo a complejo inverso.RFFT
: Reenvía la FFT real a compleja.IRFFT
: FFT inverso de real a complejo (es decir, toma complejo, muestra real)
Más formalmente, dada la función fft
, que toma tensores unidimensionales de tipos complejos como entrada, produce tensores unidimensionales de los mismos tipos como salida y calcula la transformación discreta de Fourier:
Para fft_type = FFT
, result
se define como el resultado final de una serie de cálculos L en los que L = size(fft_length)
. Por ejemplo, para L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Además, dada la función ifft
, que tiene la misma firma de tipo y procesa el inverso de fft
, sucede lo siguiente:
En fft_type = IFFT
, result
se define como el inverso de los cálculos para fft_type = FFT
. Por ejemplo, para L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Además, dada la función rfft
, que toma tensores unidimensionales de tipos de punto flotante, produce tensores unidimensionales de tipos complejos con la misma semántica de punto flotante y funciona de la siguiente manera:
rfft(real_operand) = truncated_result
, dondecomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Cuando se calcula la transformación discreta de Fourier para operandos reales, los primeros elementos N/2 + 1
del resultado definen de manera inequívoca el resto del resultado, por lo que el resultado de rfft
se trunca para evitar el cálculo de elementos redundantes).
Para fft_type = RFFT
, result
se define como el resultado final de una serie de cálculos L en los que L = size(fft_length)
. Por ejemplo, para L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Por último, con la función irfft
, que tiene la misma firma de tipo y calcula el inverso de rfft
, haz lo siguiente:
En fft_type = IRFFT
, result
se define como el inverso de los cálculos para fft_type = RFFT
. Por ejemplo, para L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o de tipo complejo | (C1), (C2), (C4) y (C5) |
(I2). | fft_type |
enum de FFT , IFFT , RFFT y IRFFT |
(C2) y (C5) |
(I3). | fft_length |
Constante tensorial unidimensional de tipo si64 |
(C1), (C3) y (C4) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o de tipo complejo | (C2), (C4) y (C5) |
Restricciones
- (C1)
size(fft_length) <= rank(operand)
. - (C2) La relación entre los tipos de elementos
operand
yresult
varía:- Si
fft_type = FFT
,element_type(operand)
yelement_type(result)
tienen el mismo tipo complejo. - Si
fft_type = IFFT
,element_type(operand)
yelement_type(result)
tienen el mismo tipo complejo. - Si es
fft_type = RFFT
,element_type(operand)
es un tipo de punto flotante yelement_type(result)
es un tipo complejo de la misma semántica de punto flotante. - Si es
fft_type = IRFFT
,element_type(operand)
es un tipo complejo yelement_type(result)
es un tipo de punto flotante de la misma semántica de punto flotante.
- Si
- (C3)
1 <= size(fft_length) <= 3
. - (C4) Si entre
operand
yresult
, hay un tensorreal
de un tipo de punto flotante, entoncesshape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
, excepto para:- Si es
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
. - Si es
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
.
- Si es
Ejemplos
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
piso
Semántica
Realiza el piso a nivel de elementos del tensor operand
y produce un tensor result
.
Implementa la operación roundToIntegralTowardNegative
de la especificación IEEE-754. Para tipos cuantizados, realiza dequantize_op_quantize(floor, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
recopilar
Semántica
Recopila segmentos del tensor operand
de los desplazamientos especificados en start_indices
y produce un tensor result
.
En el siguiente diagrama, se muestra cómo se asignan los elementos de result
a los elementos de operand
con un ejemplo concreto. En el diagrama, se eligen algunos índices result
de ejemplo y se explican en detalle a qué índices de operand
corresponden.
De manera más formal, result[result_index] = operand[operand_index]
, donde:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
se define de la siguiente manera:start_indices[bi0, ..., :, ..., biN]
, en el quebi
son elementos individuales enbatch_index
y:
se inserta en el índiceindex_vector_dim
, siindex_vector_dim
<rank(start_indices)
.- De lo contrario,
[start_indices[batch_index]]
.
- Para
d_operand
enaxes(operand)
:full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
si esd_operand = start_index_map[d_start]
.- De lo contrario,
full_start_index[d_operand] = 0
.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
, en el queoi
son elementos individuales enoffset_index
y0
se inserta en los índices decollapsed_slice_dims
.operand_index = full_start_index + full_offset_index
.
Si indices_are_sorted
es true
, la implementación puede suponer que start_indices
se ordenan con respecto a start_index_map
; de lo contrario, el comportamiento no está definido. De manera más formal, para todos los i1 < i2
de indices(result)
, full_start_index(i1) <= full_start_index(i2)
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1), (C7), (C10-C12) y (C14) |
(I2). | start_indices |
tensor de tipo de número entero | (C2), (C3) y (C13) |
(I3). | offset_dims |
Constante tensorial unidimensional de tipo si64 |
(C1), (C4-C5) y (C13) |
(I4). | collapsed_slice_dims |
Constante tensorial unidimensional de tipo si64 |
(C1), (C6-C8) y (C13) |
(I5). | start_index_map |
Constante tensorial unidimensional de tipo si64 |
(C3), (C9) y (C10) |
(I6). | index_vector_dim |
constante de tipo si64 |
(C2), (C3) y (C13) |
(I7). | slice_sizes |
Constante tensorial unidimensional de tipo si64 |
(C8) o (C11-C13) |
(I8). | indices_are_sorted |
constante de tipo i1 |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C5) o (C13-C14) |
Restricciones
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
. - (C2)
0 <= index_vector_dim <= rank(start_indices)
. - (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
. - (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
. - (C5)
0 <= offset_dims < rank(result)
. - (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
. - (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
. - (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
, donde:batch_dim_sizes = shape(start_indices)
, excepto que no se incluye el tamaño de la dimensión destart_indices
correspondiente aindex_vector_dim
.offset_dim_sizes = shape(slice_sizes)
, excepto que no se incluyen los tamaños de dimensión enslice_sizes
correspondientes acollapsed_slice_dims
.combine
colocabatch_dim_sizes
en los ejes correspondientes abatch_dims
yoffset_dim_sizes
en los ejes correspondientes aoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Ejemplos
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
get_dimension_size
Semántica
Produce el tamaño del dimension
determinado de operand
. De manera más formal, result = dim(operand, dimension)
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor | (C1) |
(I2). | dimension |
constante de tipo si64 |
(C1) |
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de 0 dimensiones de tipo si32 |
Restricciones
- (C1)
0 <= dimension < rank(operand)
.
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semántica
Extrae el elemento en la posición index
de la tupla operand
y produce un result
. Más formalmente, result = operand[index]
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tuple | (C1), (C2) |
(I2). | index |
constante de tipo si32 |
(C1), (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
Cualquier tipo admitido | (C2) |
Restricciones
- (C1)
0 <= index < size(operand)
. - (C2)
type(result) = tuple_element_types(operand)[index]
.
Ejemplos
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
if
Semántica
Produce el resultado a partir de la ejecución de exactamente una función de true_branch
o false_branch
según el valor de pred
. Más formalmente, result =
pred ? true_branch() : false_branch()
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | pred |
Tensor de 0 dimensiones de tipo i1 |
|
(I2). | true_branch |
la función | (C1-C3). |
(I3). | false_branch |
la función | (C1), (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variable de tensores, tensores cuantificados o tokens | (C3) |
Restricciones
- (C1)
input_types(true_branch) = input_types(false_branch) = []
. - (C2)
output_types(true_branch) = output_types(false_branch)
. - (C3)
type(results...) = output_types(true_branch)
.
Ejemplos
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imagen
Semántica
Extrae la parte imaginaria a nivel de elementos de operand
y produce un tensor result
. De manera más formal, para cada elemento x
: imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o de tipo complejo | (C1), (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1), (C2) |
Restricciones
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
si esis_complex(operand)
.- De lo contrario,
element_type(operand)
.
Ejemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
entrada
Semántica
Lee datos de la entrada y produce results
.
La semántica de infeed_config
está definida por la implementación.
results
consisten en valores de carga útil que van primero y un token que va en último lugar. En el futuro, planeamos dividir la carga útil y el token en dos resultados separados para mejorar la claridad (#670).
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1). | token |
token |
(I2). | infeed_config |
constante de tipo string |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variable de tensores, tensores cuantificados o tokens | (C1-C3). |
Restricciones
- (C1)
0 < size(results)
. - (C2)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
.
Ejemplos
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
Iota
Semántica
Rellena un tensor output
con valores en orden creciente a partir de cero a lo largo de la dimensión iota_dimension
. Más formalmente,
output[result_index] = constant(is_quantized(output) ?
quantize(result_index[iota_dimension], element_type(output)) :
result_index[iota_dimension], element_type(output))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | iota_dimension |
si64 |
(C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
output |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
0 <= iota_dimension < rank(output)
.
Ejemplos
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semántica
Realiza una verificación a nivel de elementos si el valor en x
es finito (es decir, no es +Inf, -Inf ni NaN) y produce un tensor y
. Implementa la operación isFinite
de la especificación IEEE-754. Para los tipos cuantizados, el resultado es siempre true
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | x |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
y |
tensor de tipo booleano | (C1) |
Restricciones
- (C1)
shape(x) = shape(y)
.
Ejemplos
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semántica
Realiza una operación de logaritmo a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
log
de IEEE-754. - Para números complejos: logaritmo complejo.
- Para tipos cuantizados:
dequantize_op_quantize(log, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semántica
Realiza el logaritmo a nivel de elementos más una operación en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
logp1
de IEEE-754. - Para números complejos: logaritmo complejo más uno.
- Para tipos cuantizados:
dequantize_op_quantize(log_plus_one, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logística
Semántica
Realiza una operación logística a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
division(1, addition(1, exp(-x)))
de IEEE-754. - Para números complejos: logística compleja.
- Para tipos cuantizados:
dequantize_op_quantize(logistic, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
map
Semántica
Aplica una función de asignación computation
a inputs
junto al dimensions
y produce un tensor result
.
Más formalmente, result[result_index] = computation(inputs...[result_index])
.
Ten en cuenta que, actualmente, no se usan dimensions
y es probable que se quiten en el futuro (#487).
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C4). |
(I2). | dimensions |
Constante tensorial unidimensional de tipo si64 |
(C3) |
(I3). | computation |
la función | (C4) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1), (C4) |
Restricciones
- (C1)
shape(inputs...) = shape(result)
. - (C2)
0 < size(inputs) = N
. - (C3)
dimensions = range(rank(inputs[0]))
. - (C4)
computation
tiene el tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
, en el queEi = element_type(inputs[i])
yE' = element_type(result)
.
Ejemplos
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
máxima
Semántica
Realiza una operación máxima a nivel de elementos en los tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: número entero máximo.
- Para números de punto flotante:
maximum
de IEEE-754. - Para números complejos: máximo lexicográfico para el par
(real, imaginary)
. La imposición de un orden en números complejos implica una semántica sorprendente, por lo que en el futuro planeamos quitar la compatibilidad con números complejos para esta operación (#560). - Para tipos cuantizados:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor o por tensor cuantizado | (C1) |
(I2). | rhs |
tensor o por tensor cuantizado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1) |
Restricciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
mínima
Semántica
Realiza una operación mínima a nivel de elementos en los tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: lógico AND.
- Para números enteros: número entero mínimo.
- Para números de punto flotante:
minimum
de IEEE-754. - En el caso de números complejos: mínimo lexicográfico para el par
(real, imaginary)
. La imposición de un orden en números complejos implica una semántica sorprendente, por lo que en el futuro planeamos quitar la compatibilidad con números complejos para esta operación (#560). - Para tipos cuantizados:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor o por tensor cuantizado | (C1) |
(I2). | rhs |
tensor o por tensor cuantizado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1) |
Restricciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
multiplicar
Semántica
Realiza el producto a nivel de elementos de dos tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: lógico AND.
- Para números enteros: multiplicación de números enteros.
- Para números de punto flotante:
multiplication
de IEEE-754. - Para números complejos: multiplicación compleja.
- Para tipos cuantizados:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor o por tensor cuantizado | (C1) |
(I2). | rhs |
tensor o por tensor cuantizado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negate
Semántica
Realiza una negación a nivel de elementos del tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros con firma: negación de números enteros.
- Para números enteros sin firma: transmisión de bits a número entero con firma, negación de número entero, conversión de bits a número entero sin firma.
- Para números de punto flotante:
negate
de IEEE-754. - Para números complejos: negación compleja.
- Para tipos cuantizados:
dequantize_op_quantize(negate, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
no me encuentro en
Semántica
Realiza NOT a nivel de elementos del tensor operand
y produce un tensor result
.
Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: NOT lógico.
- Para números enteros: NOT a nivel de bits.
Argumentos
Nombre | Tipo | Restricciones |
---|---|---|
operand |
tensor de tipo booleano o número entero | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo booleano o número entero | (C1) |
Restricciones
- (C1)
type(operand) = type(result)
.
Ejemplos
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semántica
Garantiza que las operaciones que producen operand
se ejecuten antes que cualquier operación que dependa de result
y evita que las transformaciones del compilador muevan las operaciones a través de la barrera. Aparte de eso, la operación es una identidad, es decir, result = operand
.
Argumentos
Nombre | Tipo | Restricciones |
---|---|---|
operand |
cantidad variable de tensores, tokens cuantificados por tensor o tokens | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
cantidad variable de tensores, tokens cuantificados por tensor o tokens | (C1) |
Restricciones
- (C1)
type(operand...) = type(result...)
.
Ejemplos
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
o
Semántica
Realiza OR a nivel de elementos de dos tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: OR a nivel de bits.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de tipo entero o booleano | (C1) |
(I2). | rhs |
tensor de tipo entero o booleano | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo entero o booleano | (C1) |
Restricciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
salida
Semántica
Escribe inputs
en el feed de salida y produce un token de result
.
La semántica de outfeed_config
está definida por la implementación.
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1). | inputs |
cantidad variable de tensores o tensores cuantificados |
(I2). | token |
token |
(I3). | outfeed_config |
constante de tipo string |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Ejemplos
%result = "stablehlo.outfeed"(%inputs0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
almohadilla
Semántica
Expande operand
mediante el relleno alrededor del tensor y entre los elementos del tensor con el padding_value
determinado.
edge_padding_low
y edge_padding_high
especifican la cantidad de padding agregado en el extremo inferior (junto al índice 0) y en el extremo alto (junto al índice más alto) de cada dimensión, respectivamente. La cantidad de padding puede ser negativa, y el valor absoluto del padding negativo indica la cantidad de elementos que se quitarán de la dimensión especificada.
interior_padding
especifica la cantidad de padding agregado entre dos elementos en cada dimensión que no puede ser negativo. El padding interior se produce antes del padding de bordes, de modo que el padding de borde negativo quitará elementos del operando con padding interno.
De manera más formal, result[result_index]
se define de la siguiente manera:
operand[operand_index]
si esresult_index = edge_padding_low + operand_index * (interior_padding + 1)
.- De lo contrario,
padding_value
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1), (C2) y (C4) |
(I2). | padding_value |
Tensor de 0 dimensiones o tensor cuantificado por tensor | (C1) |
(I3). | edge_padding_low |
Constante tensorial unidimensional de tipo si64 |
(C1), (C4) |
(I4). | edge_padding_high |
Constante tensorial unidimensional de tipo si64 |
(C1), (C4) |
(I5). | interior_padding |
Constante tensorial unidimensional de tipo si64 |
(C2-C4). |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C3-C6). |
Restricciones
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
. - (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
. - (C3)
0 <= interior_padding
. - (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
.
Ejemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semántica
Se produce el partition_id
del proceso actual.
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de 0 dimensiones de tipo ui32 |
Ejemplos
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
popcnt
Semántica
Realiza un recuento a nivel de elementos del número de bits configurado en el tensor operand
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de número entero | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de número entero | (C1) |
Restricciones
- (C1)
type(operand) = type(result)
.
Ejemplos
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
potencia
Semántica
Realiza la exponente a nivel de elementos del tensor lhs
con el tensor rhs
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: exponente de números enteros.
- Para números de punto flotante:
pow
de IEEE-754. - Para números complejos: exponente complejo.
- Para tipos cuantizados:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2). | rhs |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semántica
Extrae la parte real a nivel de elementos del operand
y produce un tensor result
. De manera más formal, para cada elemento x
: real(x) = is_complex(x) ? real_part(x) : x
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o de tipo complejo | (C1), (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1), (C2) |
Restricciones
- (C1)
shape(result) = shape(operand)
. - (C2)
element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
si esis_complex(operand)
.- De lo contrario,
element_type(operand)
.
Ejemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recibir
Semántica
Recibe datos de un canal con channel_id
y produce results
.
Si is_host_transfer
es true
, la operación transfiere datos desde el host. De lo contrario, se transferirán los datos desde otro dispositivo. Esto significa que está definido
por la implementación. Esta marca duplica la información proporcionada en channel_type
, por lo que, en el futuro, planeamos conservar solo una de ellas (#666).
results
consisten en valores de carga útil que van primero y un token que va en último lugar. En el futuro, planeamos dividir la carga útil y el token en dos resultados separados para mejorar la claridad (#670).
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | token |
token |
(C4) |
(I2). | channel_id |
constante de tipo si64 |
|
(I3). | channel_type |
enum de DEVICE_TO_DEVICE y HOST_TO_DEVICE |
(C1) |
(I4). | is_host_transfer |
constante de tipo i1 |
(C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variable de tensores, tensores cuantificados o tokens | (C2-C4). |
Restricciones
- (C1)
channel_type
se define de la siguiente manera:HOST_TO_DEVICE
si esis_host_transfer = true
,- De lo contrario,
DEVICE_TO_DEVICE
.
- (C2)
0 < size(results)
. - (C3)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
.
Ejemplos
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
Reducir
Semántica
Aplica una función de reducción body
a inputs
y init_values
junto a dimensions
y produce tensores results
.
El orden de las reducciones está definido por la implementación, lo que significa que body
y init_values
deben formar un monoide a fin de garantizar que la operación produzca los mismos resultados para todas las entradas en todas las implementaciones. Sin embargo, esta condición no se aplica a muchas reducciones populares. Por ejemplo, la suma de punto flotante para body
y cero para init_values
en realidad no forman un monoide porque la suma de punto flotante no es asociativa.
De manera más formal, results...[j0, ..., jR-1] = reduce(input_slices_converted)
, donde:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, donde:
se insertan endimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
para algún árbol binarioschedule
, en el que:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
es un árbol binario completo definido por la implementación cuyo recorrido en orden consiste en lo siguiente:- Valores
input_slices_converted...[index]
, para todos losindex
enindex_space(input_slices_converted)
en el orden lexicográfico ascendente deindex
. - Se intercala con una cantidad definida por la implementación de
init_values_converted
en posiciones definidas por la implementación.
- Valores
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C4), (C6) y (C7) |
(I2). | init_values |
número variádico de tensores de 0 dimensiones o tensores cuantificados por tensor | (C2) y (C3) |
(I3). | dimensions |
Constante tensorial unidimensional de tipo si64 |
(C4), (C5) y (C7) |
(I4). | body |
la función | (C6) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C3), (C7) y (C8) |
Restricciones
- (C1)
same(shape(inputs...))
. - (C2)
element_type(inputs...) = element_type(init_values...)
. - (C3)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C4)
0 <= dimensions < rank(inputs[0])
. - (C5)
is_unique(dimensions)
. - (C6)
body
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, en el queis_promotable(element_type(inputs[i]), Ei)
. - (C7)
shape(results...) = shape(inputs...)
, excepto que no se incluyen los tamaños de dimensión deinputs...
que corresponden adimensions
. - (C8)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semántica
Realiza la conversión a nivel de elementos de operand
a otro tipo de punto flotante que usa exponent_bits
y mantissa_bits
, y de vuelta al tipo de punto flotante original, y produce un tensor output
.
Más formalmente:
- Los bits de mantisa del valor original se actualizan para redondear el valor original al valor más cercano representable con
mantissa_bits
usando la semántica deroundToIntegralTiesToEven
. - Entonces, si
mantissa_bits
son menores que la cantidad de bits mantisa del valor original, los bits mantisas se truncan comomantissa_bits
. - Luego, si los bits exponentes del resultado intermedio no caben en el rango proporcionado por
exponent_bits
, el resultado intermedio desborda hasta el infinito con el signo original o se desborda a cero con el signo original. - Para tipos cuantizados, se realiza
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
(I2). | exponent_bits |
constante de tipo si32 |
(C2) |
(I3). | mantissa_bits |
constante de tipo si32 |
(C3) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
output |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(output)
. - (C2)
1 <= exponent_bits
. - (C3)
0 <= mantissa_bits
.
Ejemplos
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, realiza la reducción mediante computations
sobre los valores del tensor operand
de cada proceso, divide el resultado de la reducción junto con scatter_dimension
en partes y dispersa las partes divididas entre los procesos para producir el result
.
La operación divide la cuadrícula de procesos de StableHLO en process_groups
, que se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
si eschannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
si eschannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
para todos lossender
enprocess_group
, dondereceiver_index = process_group.index(receiver)
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1), (C2), (C7) y (C8) |
(I2). | scatter_dimension |
constante de tipo si64 |
(C1), (C2) y (C8) |
(I3). | replica_groups |
Constante tensorial bidimensional de tipo si64 |
(C3-C5). |
(I4). | channel_id |
constante de tipo si64 |
(C6) |
(I5). | use_global_device_ids |
constante de tipo i1 |
(C6) |
(I6). | computation |
la función | (C7) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C8-C9). |
Restricciones
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
. - (C2)
0 <= scatter_dimension < rank(operand)
. - (C3)
is_unique(replica_groups)
. - (C4)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C5)
0 <= replica_groups < size(replica_groups)
. - (C6) Si es
use_global_device_ids = true
, entonceschannel_id > 0
. - (C7)
computation
tiene el tipo(tensor<E>, tensor<E>) -> (tensor<E>)
, en el queis_promotable(element_type(operand), E)
. - (C8)
shape(result) = shape(operand)
, excepto:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semántica
Aplica una función de reducción body
a las ventanas de inputs
y init_values
, y produce results
.
En el siguiente diagrama, se muestra cómo se calculan los elementos en results...
a partir de inputs...
mediante un ejemplo concreto.
De manera más formal, results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(consulta reducir) donde:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13) y (C15) |
(I2). | init_values |
número variádico de tensores de 0 dimensiones o tensores cuantificados por tensor | (C1) o (C13) |
(I3). | window_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C4), (C5) y (C15) |
(I4). | window_strides |
Constante tensorial unidimensional de tipo si64 |
(C6), (C7) y (C15) |
(I5). | base_dilations |
Constante tensorial unidimensional de tipo si64 |
(C8), (C9) y (C15) |
(I6). | window_dilations |
Constante tensorial unidimensional de tipo si64 |
(C10), (C11) y (C15) |
(I7). | padding |
Constante tensorial bidimensional de tipo si64 |
(C12) o (C15) |
(I8). | body |
la función | (C13) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1) o (C14-C16) |
Restricciones
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
. - (C2)
same(shape(inputs...))
. - (C3)
element_type(inputs...) = element_type(init_values...)
. - (C4)
size(window_dimensions) = rank(inputs[0])
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(inputs[0])
. - (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
. - (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
. - (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, en el queis_promotable(element_type(inputs[i]), Ei)
. - (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
, donde:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<[4, 1]> : tensor<2xi64>,
base_dilations = dense<[2, 1]> : tensor<2xi64>,
window_dilations = dense<[3, 1]> : tensor<2xi64>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
resto
Semántica
Realiza el resto a nivel de los elementos de los tensores lhs
y divisor rhs
, y produce un tensor result
.
Más formalmente, el signo del resultado se toma del dividendo, y el valor absoluto del resultado siempre es menor que el valor absoluto del divisor.
El resto se calcula como lhs - d * rhs
, en el que d
se calcula de la siguiente manera:
- Para números enteros:
stablehlo.divide(lhs, rhs)
. - Para números de punto flotante:
division(lhs, rhs)
de IEEE-754 con el atributo de redondeoroundTowardZero
. - Para números complejos: Por definir (#997).
- Para tipos cuantizados:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Para los tipos de elementos de punto flotante, esta operación contrasta con la operación remainder
de la especificación IEEE-754, en la que d
es un valor integral más cercano al valor exacto de lhs/rhs
con vínculos a pares.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2). | rhs |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semántica
Se produce el replica_id
del proceso actual.
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de 0 dimensiones de tipo ui32 |
Ejemplos
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
cambiar forma
Semántica
Realiza el cambio de forma del tensor operand
a uno result
. Conceptualmente, significa mantener la misma representación canónica, pero posiblemente cambiar la forma, p.ej., de tensor<2x3xf32>
a tensor<3x2xf32>
o tensor<6xf32>
.
De manera más formal, result[result_index] = operand[operand_index]
, en el que result_index
y operand_index
tienen la misma posición en el orden lexicográfico de index_space(result)
y index_space(operand)
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o tensor cuantificado | (C1-C3). |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C3). |
Restricciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si es!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
yquantization_dimension(result)
pueden diferir.
- (C2)
size(operand) = size(result)
. - (C3) Si es
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
reverse
Semántica
Invierte el orden de los elementos en operand
a lo largo del dimensions
especificado y produce un tensor result
. De manera más formal, result[result_index] = operand[operand_index]
, en el que sucede lo siguiente:
operand_index[d] = dim(result, d) - result_index[d] - 1
si esd
endimensions
.- De lo contrario,
operand_index[d] = result_index[d]
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1), (C3) |
(I2). | dimensions |
Constante tensorial unidimensional de tipo si64 |
(C2) y (C3) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1), (C3) |
Restricciones
- (C1)
type(operand) = type(result)
. - (C2)
is_unique(dimensions)
. - (C3)
0 <= dimensions < rank(result)
.
Ejemplos
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semántica
Genera números aleatorios mediante el algoritmo rng_distribution
y produce un tensor result
de una forma shape
determinada.
Si es rng_distribution = UNIFORM
, los números aleatorios se generan siguiendo la distribución uniforme en el intervalo [a, b)
. Si es a >= b
, el comportamiento no está definido.
Si es rng_distribution = NORMAL
, los números aleatorios se generan según la distribución normal con una media = a
y una desviación estándar = b
.
Si es b < 0
, el comportamiento no está definido.
La implementación define exactamente cómo se generan los números aleatorios. Por ejemplo, pueden o no ser deterministas y pueden o no usar el estado oculto.
En conversaciones con muchas partes interesadas, esta operación dejó de estar disponible, por lo que en el futuro planeamos quitarla (#597).
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | a |
Tensor de 0 dimensiones de tipo entero, booleano o punto flotante | (C1), (C2) |
(I2). | b |
Tensor de 0 dimensiones de tipo entero, booleano o punto flotante | (C1), (C2) |
(I3). | shape |
Constante tensorial unidimensional de tipo si64 |
(C3) |
(I4). | rng_distribution |
enum de UNIFORM y NORMAL |
(C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de número entero, booleano o punto flotante | (C1-C3). |
Restricciones
- (C1)
element_type(a) = element_type(b) = element_type(result)
. - (C2) Si es
rng_distribution = NORMAL
, entoncesis_float(a)
. - (C3)
shape(result) = shape
.
Ejemplos
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semántica
Muestra un output
con bits aleatorios uniformes y un estado de salida actualizado output_state
mediante el algoritmo generador de números seudoaleatorio rng_algorithm
, con un estado inicial initial_state
. Se garantiza que el resultado sea una función determinista de initial_state
, pero no que sea determinista entre implementaciones.
rng_algorithm
es una de las siguientes opciones:
DEFAULT
: Es un algoritmo definido por la implementación.THREE_FRY
: Variante definida por la implementación del algoritmo de Threefry.*PHILOX
: Variante definida por la implementación del algoritmo Philox.*
* Consultar: Salmon et al. SC 2011. Números aleatorios paralelos: tan sencillos como 1, 2, 3
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | rng_algorithm |
Enum de DEFAULT , THREE_FRY y PHILOX |
(C2) |
(I2). | initial_state |
Tensor unidimensional de tipo ui64 |
(C1), (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
output_state |
Tensor unidimensional de tipo ui64 |
(C1) |
output |
tensor de tipo de número entero o punto flotante |
Restricciones
- (C1)
type(initial_state) = type(output_state)
. - (C2)
size(initial_state)
se define de la siguiente manera:- definida por la implementación si es
rng_algorithm = DEFAULT
. 2
si esrng_algorithm = THREE_FRY
.2
o3
si esrng_algorithm = PHILOX
.
- definida por la implementación si es
Ejemplos
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semántica
Realiza un redondeo a nivel de elementos hacia el número entero más cercano, rompiendo los empates de cero en el tensor operand
y produce un tensor result
. Implementa la operación roundToIntegralTiesToAway
de la especificación IEEE-754. Para tipos cuantizados, realiza dequantize_op_quantize(round_nearest_afz, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semántica
Realiza un redondeo a nivel de elementos hacia el número entero más cercano, rompiendo los empates hacia el número entero par en el tensor operand
y produce un tensor result
. Implementa la operación roundToIntegralTiesToEven
de la especificación IEEE-754. Para tipos cuantizados, realiza dequantize_op_quantize(round_nearest_even, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de punto flotante o por tensor cuantificado | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
RQR
Semántica
Realiza una operación de raíz cuadrada recíproca a nivel de los elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
rSqrt
de IEEE-754. - Para números complejos: raíz cuadrada recíproca compleja.
- Para tipos cuantizados:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
scatter
Semántica
Produce tensores results
que son iguales a inputs
, excepto que varias porciones especificadas por scatter_indices
se actualizan con los valores updates
mediante update_computation
.
En el siguiente diagrama, se muestra cómo se asignan los elementos de updates...
a los elementos de results...
con un ejemplo concreto. En el diagrama, se eligen algunos índices updates...
de ejemplo y se explican en detalle a qué índices results...
corresponden.
De manera más formal, para todos los update_index
en index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
se define de la siguiente manera:scatter_indices[si0, ..., :, ..., siN]
, en el quesi
son elementos individuales enupdate_scatter_index
y:
se inserta en el índiceindex_vector_dim
, siindex_vector_dim
<rank(scatter_indices)
.- De lo contrario,
[scatter_indices[update_scatter_index]]
.
- Para
d_input
enaxes(inputs[0])
:full_start_index[d_input] = start_index[d_start]
si esd_input = scatter_dims_to_operand_dims[d_start]
.- De lo contrario,
full_start_index[d_input] = 0
.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
, en el quewi
son elementos individuales enupdate_window_index
y0
se inserta en los índices deinserted_window_dims
.result_index = full_start_index + full_window_index
.
Por lo tanto, results = exec(schedule, inputs)
, donde:
schedule
es una permutación definida por la implementación deindex_space(updates[0])
.exec([update_index, ...], results) = exec([...], updated_results)
, donde:- Si
result_index
está dentro de los límites deshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
es una copia deresults
conresults...[result_index]
establecido enupdated_values...
.- De lo contrario
updated_results = results
.
- Si
exec([], results) = results
.
Si indices_are_sorted
es true
, la implementación puede suponer que scatter_indices
se ordenan con respecto a scatter_dims_to_operand_dims
; de lo contrario, el comportamiento no está definido. De manera más formal, para todos los i1 < i2
de indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Si unique_indices
es true
, la implementación puede suponer que todos los índices result_index
que se dispersan son únicos. Si unique_indices
es true
, pero los índices que se dispersan no son únicos, el comportamiento no está definido.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1), (C2), (C4-C6), (C10), (C13), (C15-C16) |
(I2). | scatter_indices |
tensor de tipo de número entero | (C4), (C11) y (C14) |
(I3). | updates |
cantidad variádica de tensores o tensores cuantificados por tensor | (C3-C6) o (C8) |
(I4). | update_window_dims |
Constante tensorial unidimensional de tipo si64 |
(C2), (C4), (C7) y (C8) |
(I5). | inserted_window_dims |
Constante tensorial unidimensional de tipo si64 |
(C2), (C4), (C9) y (C10) |
(I6). | scatter_dims_to_operand_dims |
Constante tensorial unidimensional de tipo si64 |
(C11-C13). |
(I7). | index_vector_dim |
constante de tipo si64 |
(C4), (C11) y (C14) |
(I8). | indices_are_sorted |
constante de tipo i1 |
|
(I9). | unique_indices |
constante de tipo i1 |
|
(I10). | update_computation |
la función | (C15) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C15-C17). |
Restricciones
- (C1)
same(shape(inputs...))
. - (C2)
rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
. - (C3)
same(shape(updates...))
. - (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
, donde:update_scatter_dim_sizes = shape(scatter_indices)
, excepto que no se incluye el tamaño de la dimensión descatter_indices
correspondiente aindex_vector_dim
.update_window_dim_sizes <= shape(inputs[0])
, excepto que no se incluyen los tamaños de dimensión eninputs[0]
correspondientes ainserted_window_dims
.combine
colocaupdate_scatter_dim_sizes
en los ejes correspondientes aupdate_scatter_dims
yupdate_window_dim_sizes
en los ejes correspondientes aupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
. - (C6)
element_type(updates...) = element_type(inputs...)
. - (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
. - (C9)
is_unique(inserted_window_dims) and is_sorted(update_window_dims)
. - (C10)
0 <= inserted_window_dims < rank(inputs[0])
. - (C11)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C12)
is_unique(scatter_dims_to_operand_dims)
. - (C13)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C14)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C15)
update_computation
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, en el queis_promotable(element_type(inputs[i]), Ei)
. - (C16)
shape(inputs...) = shape(results...)
. - (C17)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10], [11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
// [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [2, 3],
inserted_window_dims = [0],
scatter_dims_to_operand_dims = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
// [[1, 2], [5, 6], [7, 8], [7, 8]],
// [[10, 11], [12, 13], [14, 15], [16, 17]],
// [[18, 19], [20, 21], [21, 22], [23, 24]]
// ]
select
Semántica
Produce un tensor result
en el que cada elemento se selecciona del tensor on_true
o on_false
según el valor del elemento correspondiente de pred
.
De manera más formal, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, donde pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Para tipos cuantizados, realiza dequantize_select_quantize(pred, on_true, on_false, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | pred |
tensor de tipo i1 |
(C1) |
(I2). | on_true |
tensor o por tensor cuantizado | (C1-C2). |
(I3). | on_false |
tensor o por tensor cuantizado | (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C2) |
Restricciones
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
. - (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
.
Ejemplos
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semántica
Pasa los valores del tensor source
mediante scatter
en función del resultado de reduce_window
del tensor input
con select
y produce un tensor result
.
En el siguiente diagrama, se muestra cómo se calculan los elementos en result
a partir de operand
y source
con un ejemplo concreto.
Más formalmente:
selected_values = reduce_window_without_init(...)
por las siguientes entradas:- `inputs = [operando].
window_dimensions
,window_strides
ypadding
, que se usan tal como están.base_dilations = windows_dilations = 1
.body
se define de la siguiente manera:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
donde
E = element_type(operand)
yreduce_window_without_init
funcionan exactamente comoreduce_window
, excepto que elschedule
delreduce
subyacente (consulta reducir) no incluye valores de inicio. Por el momento, no se especifica qué sucede si la ventana correspondiente no tiene valores (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
, en el que:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
siselected_values[source_index]
tiene el elementooperand
deoperand_index
source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1-C4), (C6) y (C8-C11) |
(I2). | source |
tensor o por tensor cuantizado | (C1), (C2) |
(I3). | init_value |
Tensor de 0 dimensiones o tensor cuantificado por tensor | (C3) |
(I4). | window_dimensions |
Constante tensorial unidimensional de tipo si64 |
(C2), (C4) y (C5) |
(I5). | window_strides |
Constante tensorial unidimensional de tipo si64 |
(C2), (C6) y (C7) |
(I6). | padding |
Constante tensorial bidimensional de tipo si64 |
(C2) y (C8) |
(I7). | select |
la función | (C9) |
(I8). | scatter |
la función | (C10) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C11-C12). |
Restricciones
- (C1)
element_type(operand) = element_type(source)
. - (C2)
shape(source) = num_windows
, donde:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
. - (C4)
size(window_dimensions) = rank(operand)
. - (C5)
0 < window_dimensions
. - (C6)
size(window_strides) = rank(operand)
. - (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
. - (C9)
select
tiene el tipo(tensor<E>, tensor<E>) -> tensor<i1>
, en el queE = element_type(operand)
. - (C10)
scatter
tiene el tipo(tensor<E>, tensor<E>) -> tensor<E>
, dondeis_promotable(element_type(operand), E)
. - (C11)
shape(operand) = shape(result)
. - (C12)
element_type(result) = E
.
Ejemplos
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
enviar
Semántica
Envía inputs
al canal channel_id
y produce un token result
.
Si is_host_transfer
es true
, la operación transfiere datos al host. De lo contrario, transferirá los datos a otro dispositivo. Esto significa que está definido
por la implementación. Esta marca duplica la información proporcionada en channel_type
, por lo que, en el futuro, planeamos conservar solo una de ellas (#666).
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | inputs |
cantidad variable de tensores o tensores cuantificados | |
(I2). | token |
token |
|
(I3). | channel_id |
constante de tipo si64 |
|
(I4). | channel_type |
enum de DEVICE_TO_DEVICE y DEVICE_TO_HOST |
(C1) |
(I5). | is_host_transfer |
constante de tipo i1 |
(C1) |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Restricciones
- (C1)
channel_type
se define de la siguiente manera:DEVICE_TO_HOST
si esis_host_transfer = true
,- De lo contrario,
DEVICE_TO_DEVICE
.
Ejemplos
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semántica
Realiza una operación de desplazamiento a la izquierda a nivel de elementos en el tensor lhs
según la cantidad de bits rhs
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de tipo de número entero | (C1) |
(I2). | rhs |
tensor de tipo de número entero | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de número entero | (C1) |
Restricciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semántica
Realiza una operación aritmética de desplazamiento a la derecha a nivel de los elementos en el tensor lhs
según la cantidad de bits rhs
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de tipo de número entero | (C1) |
(I2). | rhs |
tensor de tipo de número entero | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de número entero | (C1) |
Restricciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semántica
Realiza una operación lógica de desplazamiento a la derecha a nivel de elementos en el tensor lhs
según la cantidad de bits de rhs
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de tipo de número entero | (C1) |
(I2). | rhs |
tensor de tipo de número entero | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de número entero | (C1) |
Restricciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
igual.
Semántica
Muestra el signo del operand
a nivel de elementos y produce un tensor result
.
De manera más formal, para cada elemento x
, la semántica se puede expresar mediante la sintaxis de Python de la siguiente manera:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Para tipos cuantizados, realiza dequantize_op_quantize(sign, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de número entero con signo, punto flotante o tipo complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de número entero con signo, punto flotante o tipo complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semántica
Realiza una operación de seno a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
sin
de IEEE-754. - Para números complejos: seno complejo.
- Para tipos cuantizados:
dequantize_op_quantize(sine, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
slice
Semántica
Extrae una porción de operand
mediante índices de inicio calculados de forma estática y produce un tensor result
. start_indices
contiene los índices iniciales de la porción para cada dimensión, limit_indices
contiene los índices finales (exclusivos) de la porción para cada dimensión y strides
contiene los segmentos de cada dimensión.
De manera más formal, result[result_index] = operand[operand_index]
, donde operand_index = start_indices + result_index * strides
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o por tensor cuantizado | (C1-C3) o C5 |
(I2). | start_indices |
Constante tensorial unidimensional de tipo si64 |
(C2), (C3) y (C5) |
(I3). | limit_indices |
Constante tensorial unidimensional de tipo si64 |
(C2), (C3) y (C5) |
(I4). | strides |
Constante tensorial unidimensional de tipo si64 |
(C2) y (C4) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o por tensor cuantizado | (C1), (C5) |
Restricciones
- (C1)
element_type(operand) = element_type(result)
. - (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
. - (C3)
0 <= start_indices <= limit_indices <= shape(operand)
. - (C4)
0 < strides
. - (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
.
Ejemplos
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = dense<[1, 2]> : tensor<2xi64>,
limit_indices = dense<[3, 4]> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
sort
Semántica
Ordena porciones unidimensionales de inputs
a lo largo de la dimensión dimension
juntas, de acuerdo con comparator
, y produce results
.
A diferencia de entradas similares en otras operaciones, dimension
permite valores negativos, con la semántica que se describe a continuación. En el futuro, es posible que esto se inhabilite por motivos de coherencia (#1377).
Si is_stable
es verdadero, el orden es estable, es decir, se conserva el orden relativo de los elementos que el comparador considera iguales. Para el caso en el que hay una sola entrada, el comparador considera que dos elementos e1
y e2
son iguales solo si comparator(e1, e2) = comparator(e2, e1) = false
. Consulta la formalización a continuación para ver cómo esto se generaliza a varias entradas.
De manera más formal, para todos los result_index
en index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
, en el queriN
son elementos individuales enresult_index
y:
se inserta enadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- donde
sort
ordena una porción unidimensional en orden no descendente y espera quecomparator_together
muestretrue
si el argumento del lado izquierdo es menor que el segundo argumento de la derecha. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C5). |
(I2). | dimension |
constante de tipo si64 |
(C4) |
(I3). | is_stable |
constante de tipo i1 |
|
(I4). | comparator |
la función | (C5) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C2) y (C3) |
Restricciones
- (C1)
0 < size(inputs)
. - (C2)
type(inputs...) = type(results...)
. - (C3)
same(shape(inputs...) + shape(results...))
. - (C4)
-R <= dimension < R
, dondeR = rank(inputs[0])
. - (C5)
comparator
tiene el tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, en el queEi = element_type(inputs[i])
.
Ejemplos
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
sqrt
Semántica
Realiza una operación de raíz cuadrada a nivel de elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
squareRoot
de IEEE-754. - Para números complejos: raíz cuadrada compleja.
- Para tipos cuantizados:
dequantize_op_quantize(sqrt, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semántica
Realiza la resta a nivel de elementos de dos tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: resta de números enteros.
- Para números de punto flotante:
subtraction
de IEEE-754. - Para números complejos: resta compleja.
- Para tipos cuantizados:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
(I2). | rhs |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
.
Ejemplos
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tanh
Semántica
Realiza la operación tangente hiperbólica a nivel de los elementos en el tensor operand
y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
tanh
de IEEE-754. - Para números complejos: tangente hiperbólica compleja.
- Para tipos cuantizados:
dequantize_op_quantize(tanh, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transponer
Semántica
Permuta las dimensiones del tensor operand
con permutation
y produce un tensor result
. De manera más formal, result[result_index] = operand[operand_index]
, donde result_index[d] = operand_index[permutation[d]]
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor o tensor cuantificado | (C1-C4). |
(I2). | permutation |
Constante tensorial unidimensional de tipo si64 |
(C2-C4). |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1) o (C3-C4) |
Restricciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si es!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
yquantization_dimension(result)
pueden diferir.
- (C2)
permutation
es una permutación derange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
. - (C4) Si es
is_per_axis_quantized(result)
, entoncesquantization_dimension(operand) = permutation(quantization_dimension(result))
.
Ejemplos
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semántica
Resuelve lotes de sistemas de ecuaciones lineales con matrices de coeficientes triangulares inferior o superior.
De manera más formal, con a
y b
, result[i0, ..., iR-3, :, :]
es la solución para op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
cuando left_side
es true
o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
cuando left_side
es false
, lo que resuelve la variable x
donde op(a)
está determinado por transpose_a
, que puede ser una de las siguientes opciones:
NO_TRANSPOSE
: Realiza la operación cona
tal como está.TRANSPOSE
: realiza una operación de transposición dea
.ADJOINT
: Realiza una operación sobre la transposición conjugada dea
.
Los datos de entrada solo se leen desde el triángulo inferior de a
, si lower
es true
o, de lo contrario, el triángulo superior de a
. Los datos de salida se muestran en el mismo triángulo; los valores en el otro triángulo están definidos por la implementación.
Si el valor de unit_diagonal
es verdadero, la implementación puede suponer que los elementos diagonales de a
son iguales a 1. De lo contrario, el comportamiento no está definido.
Para tipos cuantizados, realiza dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | a |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1-C3). |
(I2). | b |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1-C4). |
(I3). | left_side |
constante de tipo i1 |
(C3) |
(I4). | lower |
constante de tipo i1 |
|
(I5). | unit_diagonal |
constante de tipo i1 |
|
(I6). | transpose_a |
Enum de NO_TRANSPOSE , TRANSPOSE y ADJOINT |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de punto flotante o complejo, o tensor cuantificado por tensor | (C1) |
Restricciones
- (C1)
baseline_element_type(a) = baseline_element_type(b)
. - (C2)
2 <= rank(a) = rank(b) = R
. - (C3) La relación entre
shape(a)
yshape(b)
se define de la siguiente manera:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
.
Ejemplos
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semántica
Produce una tupla result
a partir de los valores val
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | val |
número variable de valores | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tuple | (C1) |
Restricciones
- (C1)
result
tiene el tipotuple<E0, ..., EN-1>
, dondeEi = type(val[i])
.
Ejemplos
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semántica
Realiza la conversión a nivel de elementos del tensor cuantificado operand
en un tensor de punto flotante result
de acuerdo con los parámetros de cuantización definidos por el tipo operand
.
Más formalmente, result = dequantize(operand)
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor cuantizado | (C1), (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1), (C2) |
Restricciones
- (C1)
shape(operand) = shape(result)
. - (C2)
element_type(result) = expressed_type(operand)
.
Ejemplos
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semántica
Realiza la conversión a nivel de elementos del tensor de punto flotante o del tensor cuantificado operand
en un tensor cuantificado result
de acuerdo con los parámetros de cuantización definidos por el tipo result
.
Más formalmente,
- Si es
is_float(operand)
:result = quantize(operand, type(result))
.
- Si es
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
tensor de punto flotante o tipo cuantizado | (C1), (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor cuantizado | (C1), (C2) |
Restricciones
- (C1)
shape(operand) = shape(result)
. - (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
.
Ejemplos
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
mientras
Semántica
Produce el resultado de la ejecución de la función body
0 o más veces, mientras que la función cond
genera true
. De manera más formal, la semántica se puede expresar mediante la sintaxis de Python de la siguiente manera:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Por definir el comportamiento de un bucle infinito (#383).
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | operand |
cantidad variable de tensores, tensores cuantificados o tokens | (C1-C3). |
(I2). | cond |
la función | (C1) |
(I3). | body |
la función | (C2) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
results |
cantidad variable de tensores, tensores cuantificados o tokens | (C3) |
Restricciones
- (C1)
cond
tiene el tipo(T0, ..., TN-1) -> tensor<i1>
, en el queTi = type(operand[i])
. - (C2)
body
tiene el tipo(T0, ..., TN-1) -> (T0, ..., TN-1)
, en el queTi = type(operand[i])
. - (C3)
type(results...) = type(operand...)
.
Ejemplos
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semántica
Realiza XOR a nivel de elementos de dos tensores lhs
y rhs
, y produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: XOR lógico.
- Para números enteros: XOR a nivel de bits.
Entradas
Etiqueta | Nombre | Tipo | Restricciones |
---|---|---|---|
(I1). | lhs |
tensor de tipo booleano o número entero | (C1) |
(I2). | rhs |
tensor de tipo booleano o número entero | (C1) |
Salidas
Nombre | Tipo | Restricciones |
---|---|---|
result |
tensor de tipo booleano o número entero | (C1) |
Restricciones
- (C1)
type(lhs) = type(rhs) = type(result)
.
Ejemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Ejecución
Ejecución secuencial
Para ejecutar un programa StableHLO, se proporcionan valores de entrada a la función main
y se calculan los valores de salida. Los valores de salida de una función se calculan mediante la ejecución del grafo de ops con raíces en la op return
correspondiente.
El orden de ejecución se define por la implementación, siempre y cuando esté alineado con el flujo de datos, es decir, si las operaciones se ejecutan antes de sus usos. En StableHLO, todas las operaciones con efectos secundarios consumen un token y producen uno (varios tokens se pueden multiplexar en uno a través de after_all
), por lo que el orden de ejecución de los efectos secundarios también se alinea con el flujo de datos. Los posibles órdenes de ejecución del programa de ejemplo anterior son %0
→ %1
→ %2
→ %3
→ %4
→ return
o %3
→ %0
→ %1
→ %2
→ %4
→ return
.
De manera más formal, un proceso EstableHLO es una combinación de: 1) un programa estable, 2) estados de operación (aún no ejecutado, ya ejecutado) y 3) valores intermedios en los que funciona el proceso.
El proceso comienza con los valores de entrada en la función main
, avanza por el grafo de operaciones que actualizan los estados de las operaciones y los valores intermedios, y finaliza con los valores de salida. Aún no se formaliza el proceso (#484).
Ejecución paralela
Los programas StableHLO se pueden ejecutar en paralelo y se organizan en una cuadrícula de procesos en 2D de num_replicas
por num_partitions
, que ambos tienen el tipo ui32
.
En la cuadrícula de procesos de StableHLO, num_replicas * num_partitions
de los procesos de StableHLO se ejecutan al mismo tiempo. Cada proceso tiene un process_id = (replica_id, partition_id)
único, en el que replica_id
en replica_ids = range(num_replicas)
y partition_id
en partition_ids = range(num_partitions)
, ambos tienen el tipo ui32
.
El tamaño de la cuadrícula de procesos se conoce estáticamente para cada programa (en el futuro, planeamos hacerlo una parte explícita de los programas StableHLO #650), y la posición dentro de la cuadrícula de procesos se conoce estáticamente para cada proceso. Cada proceso tiene acceso a su posición dentro de la cuadrícula de procesos mediante las operaciones replica_id
y partition_id
.
Dentro de la cuadrícula de procesos, todos los programas pueden ser iguales (en el estilo "Programa único, varios datos"), pueden ser todos diferentes (en el estilo "Varios programas, varios datos") o algo intermedio. En el futuro, planeamos agregar compatibilidad con otros modismos de definición de programas StableHLO paralelos, incluido GSPMD (#619).
Dentro de la cuadrícula de procesos, los procesos son mayormente independientes entre sí: tienen estados de operación separados, valores de entrada/intermedio/salida independientes y la mayoría de las operaciones se ejecutan por separado entre procesos, a excepción de una pequeña cantidad de operaciones colectivas que se describen a continuación.
Debido a que la ejecución de la mayoría de las operaciones solo usa valores del mismo proceso, no suele ser ambiguo hacer referencia a estos valores por sus nombres.
Sin embargo, cuando se describe la semántica de ops colectivas, esto no es suficiente y da lugar a la notación name@process_id
para hacer referencia al valor name
dentro de un proceso en particular. (Desde esa perspectiva, se puede ver name
no calificado como una abreviatura de name@(replica_id(), partition_id())
).
El orden de ejecución entre los procesos está definido por la implementación, excepto por la sincronización introducida por la comunicación punto a punto y las operaciones colectivas, como se describe a continuación.
Comunicación punto a punto
Los procesos StableHLO pueden comunicarse entre sí a través de canales estables. Un canal se representa con un ID positivo del tipo si64
. A través de varias operaciones, es posible enviar valores a los canales y recibirlos de ellos.
Aún no se requiere más formalización, p.ej., de dónde provienen estos IDs de canal, cómo los procesos los programas toman conocimiento de ellos y qué tipo de sincronización introducen (#484).
Comunicación en vivo
Cada proceso de StableHLO tiene acceso a dos interfaces de transmisión:
- Entrada que se puede leer.
- Salida en la que se pueden escribir.
A diferencia de los canales, que se usan para la comunicación entre procesos y, por lo tanto, tienen procesos en ambos extremos, las entradas y salidas tienen definida la otra implementación final.
Aún no se define más formalización, p.ej., cómo la comunicación de transmisión influye en el orden de ejecución y qué tipo de sincronización introduce (#484).
Operaciones colectivas
Hay seis operaciones colectivas en StableHLO: all_gather
, all_reduce
, all_to_all
, collective_broadcast
, collective_permute
y reduce_scatter
. Todas estas operaciones dividen los procesos en la cuadrícula de procesos de StableHLO en grupos de procesos de StableHLO y ejecutan un cálculo conjunto dentro de cada grupo de procesos, independientemente de otros grupos de procesos.
Dentro de cada grupo de procesos, las operaciones colectivas pueden introducir una barrera de sincronización. No se requiere más formalización, p.ej., explicar cuándo exactamente se produce esta sincronización, cómo llegan exactamente los procesos a esta barrera y qué sucede si no lo hacen (#484).
Si el grupo de procesos involucra una comunicación de partición cruzada, es decir, hay procesos en el grupo de procesos cuyos ID de partición son diferentes, la ejecución de la operación colectiva necesita un canal, y la operación colectiva debe proporcionar un channel_id
positivo de tipo si64
. La comunicación entre réplicas no necesita canales.
Los cálculos que realizan las operaciones colectivas son específicos de cada operación individual y se describen en las secciones anteriores de las operaciones individuales. Sin embargo, las estrategias mediante las cuales la cuadrícula de procesos se divide en grupos de procesos se comparten entre estas operaciones y se describen en esta sección. De manera más formal, StableHLO admite las siguientes cuatro estrategias.
cross_replica
Solo las comunicaciones entre réplicas ocurren dentro de cada grupo de procesos. Esta estrategia toma replica_groups
, una lista de listas de IDs de réplica, y calcula un producto cartesiano de replica_groups
según partition_ids
. replica_groups
debe tener elementos únicos y abarcar todos los replica_ids
. De manera más formal, con la sintaxis de Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para replica_groups = [[0, 1], [2, 3]]
y num_partitions = 2
, cross_replica
producirá [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
.
cross_partition
Solo las comunicaciones de particiones cruzadas ocurren dentro de cada grupo de procesos. Esta estrategia toma partition_groups
(una lista de listas de ID de partición) y calcula un producto cartesiano de partition_groups
por replica_ids
.
partition_groups
debe tener elementos únicos y abarcar toda la partition_ids
.
De manera más formal, con la sintaxis de Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para partition_groups = [[0, 1]]
y num_replicas = 4
, cross_partition
producirá [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
.
cross_replica_and_partition
Las comunicaciones entre réplicas y particiones pueden ocurrir dentro de cada grupo de procesos. Esta estrategia toma replica_groups
, una lista de listas de IDs de réplica, y calcula los productos cartesianos de cada replica_group
por partition_ids
. replica_groups
debe tener elementos únicos y abarcar todas las replica_ids
. De manera más formal, con la sintaxis de Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para replica_groups = [[0, 1], [2, 3]]
y num_partitions = 2
, cross_replica_and_partition
producirá [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
.
flattened_ids
Esta estrategia toma flattened_id_groups
, una lista de IDs de procesos "compactos" en forma de replica_id * num_partitions + partition_id
, y los convierte en IDs de proceso. flattened_id_groups
debe tener elementos únicos y abarcar todos los process_ids
. De manera más formal, con la sintaxis de Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
, num_replicas = 4
y num_partitions = 2
, flattened_ids
producirá [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
.
Exactitud
Por el momento, StableHLO no proporciona garantías sobre la exactitud numérica, pero esto puede cambiar en el futuro (#1156).
Errores
Los programas StableHLO se validan a través de un amplio conjunto de restricciones para operaciones individuales, que descarta muchas clases de errores antes del tiempo de ejecución. Sin embargo, las condiciones de error aún son posibles, p.ej., a través de desbordamientos de números enteros, accesos fuera de los límites, etc. A menos que se llamen explícitamente, todos estos errores generan un comportamiento definido por la implementación, pero esto puede cambiar en el futuro (#1157).
Como excepción a esta regla, las excepciones de punto flotante en los programas StableHLO tienen un comportamiento bien definido. Las operaciones que generan excepciones definidas por el estándar IEEE-754 (operación no válida, división por cero, desbordamiento, subdesbordamiento o excepciones inexactas) producen resultados predeterminados (como se define en el estándar) y continúan su ejecución sin elevar la marca de estado correspondiente; similar al control de excepciones raiseNoFlag
del estándar. La implementación define las excepciones para las operaciones no estándar (p.ej., las aritméticas complejas y ciertas funciones trascendentales).
Notation
Para describir la sintaxis, en este documento se usa la variante ISO modificada de la sintaxis EBNF (ISO/IEC 14977:1996, Wikipedia), con dos modificaciones: 1) las reglas se definen con ::=
en lugar de =
;
2) la concatenación se expresa mediante la yuxtaposición en lugar de ,
.
Para describir la semántica (es decir, dentro de las secciones “Tipos”, “Constantes” y “Operaciones”), usamos fórmulas que se basan en la sintaxis de Python, ampliada con compatibilidad para expresar operaciones de arreglo de forma concisa, como se describe a continuación. Esto funciona bien para pequeños fragmentos de código, pero en casos excepcionales, cuando se necesitan fragmentos de código más grandes, usamos la sintaxis normal de Python, que siempre se presenta de manera explícita.
Fórmulas
Exploremos cómo funcionan las fórmulas con un ejemplo de la especificación dot_general
. Una de las restricciones de esta operación se ve de la siguiente manera:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
.
Los nombres usados en esta fórmula provienen de dos fuentes: 1) funciones globales, es decir, dim
, y 2) definiciones de miembros del elemento del programa correspondiente, es decir, las entradas lhs
, lhs_batching_dimensions
, rhs
y rhs_batching_dimensions
definidas en la sección "Entradas" de dot_general
.
Como se mencionó antes, la sintaxis de esta fórmula se basa en Python con algunas extensiones orientadas a la brevedad. Para entender la fórmula, vamos a transformarla en la sintaxis normal de Python.
A) En estas fórmulas, usamos =
para representar la igualdad, por lo que el primer paso para obtener la sintaxis de Python es reemplazar =
por ==
, de la siguiente manera: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
.
B) Además, estas fórmulas admiten elipses (...
) que convierten las expresiones escalares en expresiones de tensor. En pocas palabras, f(xs...)
significa "para cada x
escalar en el tensor xs
, calcula un f(x)
escalar y, luego, muestra todos estos resultados escalares juntos como un resultado del tensor". En la sintaxis normal de Python,
nuestra fórmula de ejemplo se convierte en:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
.
Gracias a los puntos suspensivos, a menudo es posible evitar trabajar al nivel de escalares individuales. Sin embargo, en algunos casos difíciles, se puede usar la sintaxis semiinformal de nivel inferior, como en la fórmula start_indices[bi0, ..., :, ..., biN]
de la especificación gather
. Por razones de brevedad, no proporcionamos un formalismo exacto para traducir esa sintaxis al formato clásico de Python, con la esperanza de que sea intuitivamente comprensible según el caso.
Avísanos si algunas fórmulas específicas se ven opacas y trataremos de mejorarlas.
Además, verás que las fórmulas usan elipses para expandir todo tipo de listas, incluidos los tensores, las listas de tensores (que, p.ej., pueden surgir de una cantidad variable de tensores), entre otras. Esta es otra área en la que no proporcionamos un formalismo exacto (p.ej., las listas ni siquiera son parte del sistema de comprensibilidad intuitiva).
C) El último vehículo notacional notable que empleamos es la transmisión implícita. Si bien el opset StableHLO no admite la transmisión implícita, las fórmulas sí lo hacen, para garantizar la concisión. En pocas palabras, si se usa un escalar en un contexto donde se espera un tensor, el escalar se transmite a la forma esperada.
Para continuar con el ejemplo de dot_general
, esta es otra restricción: 0 <= lhs_batching_dimensions < rank(lhs)
. Como se define en la especificación dot_general
, lhs_batching_dimensions
es un tensor; sin embargo, tanto 0
como rank(lhs)
son escalares. Después de aplicar la transmisión implícita, la fórmula se convertirá en [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Cuando se aplica a una operación dot_general
específica, esta fórmula se evaluará como un tensor de valores booleanos. Cuando las fórmulas se usan como restricciones, la restricción se mantiene si la fórmula se evalúa como true
o como un tensor que solo tiene elementos true
.
Nombres
En las fórmulas, el alcance léxico incluye lo siguiente: 1) las funciones globales, 2) las definiciones de los miembros,
3) las definiciones locales. A continuación, se proporciona la lista de funciones globales. La lista de definiciones de elementos depende del elemento del programa al que se aplica la notación:
- Para las operaciones, las definiciones de los miembros incluyen los nombres ingresados en las secciones “Entradas” y “Salidas”.
- Para todo lo demás, las definiciones de los miembros incluyen partes estructurales del elemento del programa, nombradas según los no terminales de EBNF correspondientes. La mayoría de las veces, los nombres de esas partes estructurales se obtienen convirtiendo los nombres de las no terminales en snake case (p.ej.,
IntegerLiteral
=>integer_literal
), pero a veces los nombres se abrevian en el proceso (p.ej.,QuantizationStorageType
=>storage_type
), en cuyo caso los nombres se ingresan de manera explícita de manera similar a las secciones "Inputs" y "Outputs" en las especificaciones de operación. - Además, las definiciones de los miembros siempre incluyen
self
para hacer referencia al elemento del programa correspondiente.
Valores
Cuando se evalúan las fórmulas, funcionan con los siguientes tipos de valores:
1) Value
(valores reales, p.ej., dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
siempre conocen sus tipos),
2) Placeholder
(valores futuros, p.ej., lhs
, rhs
o result
; sus valores
reales aún no se conocen, solo se conocen sus tipos),
3) Type
(tipos definidos en la sección "Tipos")
4) Function
(funciones globales según se definen en la sección "Funciones").
Según el contexto, los nombres pueden referirse a valores diferentes. Específicamente, la sección "Semántica" de las operaciones (y equivalentes de otros elementos del programa) define la lógica del entorno de ejecución, por lo que todas las entradas están disponibles como Value
.
Por el contrario, la sección "Restricciones" de las operaciones (y equivalentes) define la lógica de "tiempo de compilación", es decir, algo que se suele ejecutar antes del tiempo de ejecución, por lo que solo las entradas constantes están disponibles como Value
y las demás solo están disponibles como Placeholder
.
Nombres | En "Semántica" | En “Restricciones” |
---|---|---|
Funciones globales | Function |
Function |
Entradas constantes | Value |
Value |
Entradas no constantes | Value |
Placeholder |
Salidas | Value |
Placeholder |
Definiciones locales | Depende de la definición | Depende de la definición |
Consideremos una operación transpose
de ejemplo:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Para esta operación, permutation
es una constante, por lo que está disponible como Value
tanto en semántica como en restricciones. Por el contrario, operand
y result
están disponibles como Value
en la semántica, pero solo como Placeholder
en restricciones.
Funciones
Construcción de tipos
No hay funciones que se puedan usar para construir tipos. En su lugar, usamos directamente la sintaxis de tipos, ya que suele ser más concisa. P.ej., (tensor<E>, tensor<E>) -> (tensor<E>)
en lugar de function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funciones en tipos
element_type
se define en los tipos de tensor y los tipos de tensor cuantificados y muestra, respectivamente, la parteTensorElementType
oQuantizedTensorElementType
delTensorType
oQuantizedTensorType
correspondientes.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
es un atajo parais_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
es un acceso directo parais_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
verifica si el tipox
se puede ascender al tipoy
. Cuandox
yy
son elementosQuantizedTensorElementType
, la promoción se aplica solo astorage_type
. Actualmente, esta versión específica de la promoción se usa en el contexto del cálculo de reducción (consulta RFC para obtener más detalles).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
es un acceso directo parais_quantized_tensor_element_type(x)
.is_type_name(x: Value | Placeholder | Type) -> Value
. Disponible para todos los tipos. Por ejemplo,is_float(x)
muestratrue
six
es unFloatType
. Six
es un valor o un marcador de posición, esta función es un acceso directo parais_type_name(type(x))
.max_value(x: Type) -> Value
muestra el valor máximo de unTensorElementType
. Six
no es unTensorElementType
, muestraNone
.min_value(x: Type) -> Value
muestra el valor mínimo posible de unTensorElementType
. Six
no es unTensorElementType
, muestraNone
.member_name(x: Value | Placeholder | Type) -> Any
. Disponible para todas las definiciones de miembrosmember_name
de todos los tipos. Por ejemplo,tensor_element_type(x)
muestra la parteTensorElementType
de unTensorType
correspondiente. Six
es un valor o un marcador de posición, esta función es un acceso directo paramember_name(type(x))
. Six
no es un tipo que tenga un miembro adecuado, o un valor o un marcador de posición de ese tipo, muestraNone
.
Construcción de valores
operation_name(*xs: Value | Type) -> Value
. Disponible para todas las operaciones. Por ejemplo,add(lhs, rhs)
toma dos valores de tensorlhs
yrhs
, y muestra el resultado de evaluar la operaciónadd
con estas entradas. Para algunas operaciones, p. ej.,broadcast_in_dim
, los tipos de sus resultados son “de carga”, es decir, necesarios para evaluar una operación. En este caso, la función toma estos tipos como argumentos.
Función en valores
Todos los operadores y las funciones de Python están disponibles. P.ej., las anotaciones de suscripción y segmentación de Python están disponibles para indexarse en tensores, tensores cuantificados y de tuplas.
to_destination_type(x: Value, destination_type: Type) -> Value
se define en los tensores y muestra el valor convertido dex
segúntype(x)
ydestination_type
de la siguiente manera:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Hay un debate anticipado sobre la combinación de operaciones convert
, uniform_quantize
y uniform_dequantize
(#1576).
Después de la combinación, no necesitamos la función anterior y, en su lugar, podemos usar el nombre de la operación para convert
.
is_nan(x: Value) -> Value
se define en los tensores y muestratrue
si todos los elementos dex
sonNaN
o, de lo contrario,false
. Six
no es un tensor, muestraNone
.is_sorted(x: Value) -> Value
se define en los tensores y muestratrue
si los elementos dex
se ordenan de manera ascendente con respecto al orden ascendente lexicográfico de sus índices o, en caso contrario,false
. Six
no es un tensor, muestraNone
.is_unique(x: Value) -> Value
se define en los tensores y muestratrue
six
no tiene elementos duplicados ofalse
de lo contrario. Six
no es un tensor, muestraNone
.member_name(x: Value) -> Any
se define para todas las definiciones de miembrosmember_name
de todos los valores. Por ejemplo,real_part(x)
muestra la parteRealPart
de unComplexConstant
correspondiente. Six
no es un valor que tenga un miembro apropiado, muestraNone
.same(x: Value) -> Value
se define en los tensores y muestratrue
si todos los elementos dex
son iguales entre sí o, de lo contrario, muestrafalse
. Si el tensor no tiene elementos, se cuenta como "todos son iguales entre sí", es decir, la función muestratrue
. Six
no es un tensor, muestraNone
.split(x: Value, num_results: Value, axis: Value) -> Value
se define en los tensores y muestra porcionesnum_results
dex
a lo largo del ejeaxis
. Six
no es un tensor odim(x, axis) % num_results != 0
, muestraNone
.
Cálculos de formas
axes(x: Value | Placeholder | Type) -> Value
es un acceso directo pararange(rank(x))
.dim(x: Value | Placeholder | Type, axis: Value) -> Value
es un acceso directo parashape(x)[axis]
.dims(x: Value | Placeholder | Type, axes: List) -> List
es un acceso directo paralist(map(lambda axis: dim(x, axis), axes))
.index_space(x: Value | Placeholder | Type) -> Value
se define en los tensores y muestra índicessize(x)
para elTensorType
correspondiente ordenado en orden lexicográfico ascendente, es decir,[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
. Six
no es un tipo de tensor, un tipo de tensor cuantificado, un valor o un marcador de posición de uno de estos tipos, muestraNone
.rank(x: Value | Placeholder | Type) -> Value
es un acceso directo parasize(shape(x))
.shape(x: Value | Placeholder | Type) -> Value
se define en la sección "Funciones en tipos" mediantemember_name
.size(x: Value | Placeholder | Type) -> Value
es un acceso directo parareduce(lambda x, y: x * y, shape(x))
.
Cálculos de cuantización
def baseline_element_type(x: Value | Placeholder | Type) -> Type
es un acceso directo paraelement_type(baseline_type(x))
.baseline_type
se define en los tipos de tensores y de tipos de tensores cuantificados y los transforma en un "modelo de referencia", es decir, un tipo con la misma forma, pero con los parámetros de cuantización del tipo de elemento restablecidos a los valores predeterminados. Esto se usa como un truco útil para comparar los tipos de tensor y los cuantizados de manera uniforme, lo que es necesario con bastante frecuencia. Para los tipos cuantizados, esto permite comparar tipos que ignoran los parámetros de cuantización, es decir,shape
,storage_type
,expressed_type
,storage_min
,storage_max
yquantization_dimension
(para el tipo cuantizado por eje) deben coincidir, peroscales
yzero points
pueden diferir.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
se define en tipos de tensores cuantizados y los convierte en tipos de tensores de punto flotante. Esto sucede a través de la conversión de elementos cuantizados que representan valores enteros del tipo de almacenamiento en valores de punto flotante correspondientes del tipo expresado mediante el punto cero y la escala asociados con el tipo de elemento cuantificado.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
se define en tipos de tensores de punto flotante y los convierte en tipos de tensores cuantizados. Esto sucede cuando se convierten los valores de punto flotante del tipo expresado en valores de números enteros correspondientes del tipo de almacenamiento con el punto cero y la escala asociados con el tipo de elemento cuantificado.
def quantize(x: Value, type: Type) -> Value:
assert is_float(x) and is_quantized(type)
x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
return bitcast_convert(x_storage, type)
dequantize_op_quantize
se usa para especificar cálculos a nivel de elementos en tensores cuantificados. Simplifica, es decir, convierte los elementos cuantificados en sus tipos expresados, realiza una operación y, luego, cuantiza los resultados, es decir, convierte los resultados en sus tipos de almacenamiento. Por el momento, esta función solo funciona con la cuantización por tensor. La cuantización por eje es un trabajo en curso (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
Cálculos de cuadrícula
cross_partition(replica_groups: Value) -> Value
. Consulta la sección "cross_replica" anterior.cross_replica(replica_groups: Value) -> Value
. Consulta la sección "cross_replica" anterior.cross_replica_and_partition(replica_groups: Value) -> Value
. Consulta la sección "cross_replica_and_partition" más arriba.flattened_ids(replica_groups: Value) -> Value
. Consulta la sección "Flated_ids" anterior.