StableHLO es un conjunto de operaciones de alto nivel (HLO) en máquinas de aprendizaje automático (AA). StableHLO funciona como una capa de portabilidad entre diferentes Frameworks y compiladores de AA: Frameworks de AA que producen programas StableHLO y son compatibles con compiladores de AA que consumen programas StableHLO.
Nuestro objetivo es simplificar y acelerar el desarrollo del AA creando más entre diversos frameworks de AA (como TensorFlow, JAX y PyTorch) y compiladores de AA (como IREE y XLA). Con ese fin, este proporciona una especificación para el lenguaje de programación StableHLO.
Esta especificación contiene tres secciones principales. En primer lugar, La sección Programas describe la estructura de los programas de StableHLO. que consisten en funciones StableHLO que, a su vez, consisten en operaciones de StableHLO. Dentro de esa estructura, la sección Operaciones especifica la semántica de operaciones individuales. La sección Ejecución proporciona semántica para todos estas ops se ejecutan juntas dentro de un programa. Por último, la En la sección Notación, se analiza la notación utilizada en todo el especificación.
Para ver las especificaciones de una versión anterior de StableHLO, abre el repo en versión etiquetada de interés. Por ejemplo, la especificación de StableHLO v0.19.0. Para ver los cambios que se produjeron en cada cambio de versión secundario de StableHLO, consulta el registro de la versión en VhloDialect.td.
Programas
Program ::= {Func}
Los programas StableHLO constan de una cantidad arbitraria de funciones StableHLO.
A continuación, se muestra un programa de ejemplo con una función @main
que tiene 3 entradas
(%image
, %weights
y %bias
) y 1 resultado. El cuerpo de la función
tiene 6 operaciones.
func.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() {value = dense<0.0> : tensor<1x10xf32>} : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
Funciones
Func ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput ::= ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput ::= ValueType
FuncBody ::= {Op}
Las funciones estables (que también se denominan funciones con nombre) tienen un identificador, entradas/salidas y un cuerpo. En el futuro, planeamos ingresar metadatos adicionales para las funciones a fin de lograr una mejor compatibilidad con HLO (#425, n.o 626, #740, #744).
Identificadores
FuncId ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
| '%' letter {letter | digit}
letter ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit ::= '0' | ... | '9'
Los identificadores de StableHLO son similares a los identificadores en muchos entornos idiomas, con dos peculiaridades: 1) todos los identificadores tienen sigilos que distinguir distintos tipos de identificadores, 2) los identificadores de valor se pueden completamente numéricos para simplificar la generación de programas StableHLO.
Tipos
Type ::= ValueType | NonValueType
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
Los tipos de StableHLO se clasifican en tipos de valores (también llamados tipos de primera clase), que representan valores de StableHLO y tipos sin valores que describen otros elementos del programa. Los tipos StableHLO son similares a los tipos en tiene muchos lenguajes de programación, y su principal peculiaridad es el de StableHLO. naturaleza específica del dominio que da como resultado algunos resultados inusuales (por ejemplo, tipos escalares no son tipos de valores).
TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit} | '?'
Los tipos de tensores representan tensores, es decir, arrays multidimensionales. Tienen un
forma y un tipo de elemento, donde una forma representa un valor no negativo o
tamaños de dimensión desconocidos en orden ascendente de los
dimensiones (también llamadas ejes) numeradas del 0
al R-1
. El
la cantidad de dimensiones R
se denomina clasificación. Por ejemplo, tensor<2x3xf32>
es
Un tipo de tensor con forma 2x3
y tipo de elemento f32
. Tiene dos dimensiones
(o, en otras palabras, dos ejes): 0.a dimensión y 1.a dimensión, cuyos tamaños
son 2 y 3. Su clasificación es 2.
Las formas pueden ser parcialmente o totalmente desconocidas (dinámicas), p.ej., tensor<?x2xf64>
es parcialmente desconocido y tensor<?x?xf64>
es completamente desconocido. Dinámico
los tamaños de las dimensiones se representan con un ?
. No es posible anular la clasificación de las formas.
En el futuro, planeamos explorar la extensión de los tipos de tensores más allá tamaños de dimensiones y tipos de elementos, por ejemplo, para incluir diseños (#629) y dispersión (#1078).
QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
QuantizationStorageType
['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
':' QuantizationExpressedType
[':' QuantizationDimension]
',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
| '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nombre | Tipo | Limitaciones |
---|---|---|
storage_type |
tipo de número entero | (C1-C3), (C8) |
storage_min |
constante de número entero | (C1), (C3) y (C7) |
storage_max |
constante de número entero | (C2), (C3) y (C7) |
expressed_type |
tipo de punto flotante | (C4) |
quantization_dimension |
constante de número entero opcional | (C10-C12) |
scales |
número variádico de constantes de punto flotante | (C4-C6), (C9), (C10) y (C13) |
zero_points |
cantidad variádica de constantes de número entero | (C7-C9) |
Los tipos de elementos cuantificados representan valores enteros de un tipo de almacenamiento en
el rango de storage_min
a storage_max
(inclusive) que corresponden a
Valores de punto flotante de un tipo expresado. Para un número entero i
determinado,
el valor de punto flotante correspondiente f
se puede calcular como
f = (i - zero_point) * scale
, donde se llama a scale
y zero_point
parámetros de cuantización. storage_min
y storage_max
son opcionales
en la gramática, pero los valores predeterminados son min_value(storage_type)
y
max_value(storage_type)
respectivamente. Los tipos de elementos cuantificados tienen las siguientes características:
las siguientes restricciones:
- (C1)
type(storage_min) = storage_type
- (C2)
type(storage_max) = storage_type
- (C3)
min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type)
- (C4)
type(scales...) = expressed_type
- (C5)
0 < scales
- (C6)
is_finite(scales...)
- (C7)
storage_min <= zero_points <= storage_max
. - (C8)
type(zero_points...) = storage_type
- (C9)
size(scales) = size(zero_points)
. - (C10) Si es
is_empty(quantization_dimension)
, entoncessize(scales) = 1
. - (C11)
0 <= quantization_dimension
Por el momento, QuantizationScale
es una constante de punto flotante, pero hay
gran interés en las escalas basadas en números enteros, representados con multiplicadores y
cambios. Tenemos planificado explorar esta función en un futuro cercano.
(#1404).
Hay un debate en curso sobre la semántica de QuantizationZeroPoint
,
incluidos el tipo, los valores y si puede haber
posiblemente múltiples puntos cero
en un tipo de tensor cuantificado. Según el
resultados de esta discusión, la especificación alrededor de cero puntos puede cambiar
en el futuro (#1405).
Otro debate en curso involucra la semántica de QuantizationStorageMin
.
y QuantizationStorageMax
para determinar si se debe aplicar alguna restricción
impuesto a estos valores y a los valores de los tensores cuantificados
(#1406).
Por último, planeamos explorar la representación de escalas desconocidas y cero de manera similar a cómo planeamos explorar la representación tamaños de dimensiones (#1407).
Los tipos de tensores cuantificados representan tensores con elementos cuantificados. Estos son exactamente los mismos que los regulares, con la excepción de que sus elementos tienen tipos de elementos cuantizados, en lugar de tipos de elementos regulares.
En tensores cuantificados, la cuantización puede ser por tensor, es decir, tiene
una scale
y una zero_point
para todo el tensor o pueden ser por eje,
es decir, tener varios scales
y zero_points
, un par por porción de
una dimensión en particular quantization_dimension
De manera más formal, en un tensor t
con la cuantización por eje, hay segmentos dim(t, quantization_dimension)
de la quantization_dimension
: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :]
etc. Todos los elementos de la i
a porción usan scales[i]
y zero_points[i]
, como
sus parámetros de cuantización. Los tipos de tensores cuantificados tienen las siguientes características:
limitaciones:
- Para la cuantización por tensor:
- Sin restricciones adicionales.
- Para la cuantización por eje:
- (C12)
quantization_dimension < rank(self)
. - (C13)
dim(self, quantization_dimension) = size(scales)
.
- (C12)
TokenType ::= 'token'
Los tipos de token representan tokens, es decir, valores opacos producidos y consumidos debido a algunas operaciones. Los tokens se usan para imponer un orden de ejecución a las operaciones como se describe en la sección Ejecución.
TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]
Los tipos de tuplas representan tuplas, es decir, listas heterogéneas. Las tuplas son un legado
que solo existe para compatibilidad con HLO. En HLO, las tuplas son
que se usan para representar entradas y salidas variables. En StableHLO, las entradas variádicas y
de salida son compatibles de forma nativa, y el único uso de tuplas en StableHLO es
representan de manera exhaustiva la ABI HLO donde, p.ej., T
, tuple<T>
y
tuple<tuple<T>>
puede ser sustancialmente diferente según un
para implementarlos. En el futuro, planeamos realizar cambios en la ABI de HLO
lo que nos permite quitar tipos de tuplas de StableHLO.
(#598).
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Los tipos de elementos representan elementos de tipos de tensores. A diferencia de lo que ocurre
lenguajes, estos tipos no son de primera clase en StableHLO. Esto significa que
Los programas StableHLO no pueden representar directamente valores de estos tipos (por ello,
es idiomático representar valores escalares de tipo T
con tensor de 0 dimensiones
valores de tipo tensor<T>
).
- El tipo booleano representa los valores booleanos
true
yfalse
. - Los tipos de números enteros pueden ser con firma (
si
) o sin firma (ui
), y tener uno de los anchos de bits admitidos (2
,4
,8
,16
,32
o64
). Los tipossiN
con firma representan valores de números enteros de-2^(N-1)
a2^(N-1)-1
los tiposuiN
inclusivos y sin firma representan valores enteros de0
a2^N-1
inclusive. - Los tipos de punto flotante pueden ser uno de los siguientes:
f8E4M3FN
yf8E5M2
que corresponden a la CodificacionesE4M3
yE5M2
del formato FP8 descrito en Formatos FP8 para el aprendizaje profundo.- Tipos
f8E4M3FNUZ
yf8E5M2FNUZ
que corresponden aE4M3
yE5M2
de los formatos FP8 descritos en Formatos numéricos de 8 bits para redes neuronales profundas. - Tipo
f8E4M3B11FNUZ
correspondiente a la codificaciónE4M3
de los formatos FP8 descritos en Inferencia y entrenamiento de punto flotante híbrido de 8 bits (HFP8) para redes neuronales profundas. - Tipo de
bf16
correspondiente al formatobfloat16
descrito en BFloat16: El secreto para un alto rendimiento en Cloud TPU. - Los tipos
f16
,f32
yf64
corresponden a respectivamentebinary16
("precisión media"),binary32
("precisión simple") y Formatosbinary64
("precisión doble") descritos en el estándar IEEE 754. - El tipo
tf32
corresponde al formato TensorFloat32. y tiene compatibilidad limitada en StableHLO.
- Los tipos complejos representan valores complejos que tienen una parte real.
y una parte imaginaria del mismo tipo de elemento. Complejo admitido
los tipos son
complex<f32>
(ambas partes son del tipof32
) ycomplex<f64>
(ambas partes son del tipof64
).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]
Los tipos de funciones representan funciones con nombre y anónimas. Tienen
tipos de entrada (la lista de tipos en el lado izquierdo de ->
) y tipos de salida
(la lista de tipos a la derecha de ->
). En muchos casos de programación,
lenguajes, los tipos de funciones son de primera clase, pero no están en StableHLO.
StringType ::= 'string'
El tipo de string representa secuencias de bytes. A diferencia de lo que ocurre en varios idiomas, el tipo de cadena no es la primera clase en StableHLO y solo se usa para especificar metadatos estáticos para elementos del programa.
Operaciones
Las operaciones StableHLO (que también se denominan ops) representan un conjunto cerrado de las operaciones de alto nivel en los modelos de aprendizaje automático. Como se mencionó anteriormente, La sintaxis del StableHLO está inspirada en gran medida en MLIR, que no es el alternativa ergonómica, pero podría ser la mejor opción para el objetivo de StableHLO de lo que crea más interoperabilidad entre frameworks y compiladores de AA.
Op ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic ::= 'abs' | 'add' | ...
Las operaciones StableHLO (que también se denominan ops) tienen un nombre,
las entradas y salidas
y una firma. El nombre consta del prefijo stablehlo.
y
un mnemónico que identifica de forma exclusiva una de las operaciones admitidas. Consulta a continuación para
una lista completa de todas las operaciones admitidas.
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue ::= ValueId
OpInputFuncs ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs ::= [OpOutput {',' OpOutput} '=']
OpOutput ::= ValueId
Las operaciones consumen entradas y producen salidas. Las entradas se categorizan
valores de entrada (calculados durante la ejecución), funciones de entrada (proporcionadas
de forma estática, porque en StableHLO las funciones no son valores de primera clase) y
atributos de entrada (también se proporcionan estáticamente). El tipo de entradas y salidas
consumidos y producidos por una op dependen de su mnemotecnia. Por ejemplo, add
.
op consume 2 valores de entrada y produce 1 valor de salida. En comparación,
La op select_and_scatter
consume 3 valores de entrada, 2 funciones de entrada, y
3 atributos de entrada.
OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused ::= '^' digit {digit}
| '^' letter {letter | digit}
Las funciones de entrada (también llamadas funciones anónimas) son muy
similares a las funciones con nombre, excepto que 1) no tienen identificador (por lo tanto,
el nombre "anónimo"), 2) no declaran los tipos de salida (los tipos de salida son
inferido de la op return
dentro de la función).
La sintaxis de las funciones de entrada incluye una parte que no se usa actualmente (consulta la
la producción de Unused
anterior) que es compatible con MLIR. En MLIR,
hay un concepto más general de "regiones" que pueden tener varios “bloqueos”
de ops conectadas entre sí a través de jump ops. Estos bloques tienen IDs que corresponden
a la producción de Unused
para que se distingan entre sí.
StableHLO no tiene operaciones de salto, por lo que la parte correspondiente de la sintaxis de MLIR es
sin usar (pero aún están ahí).
OpInputAttr ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName ::= letter {letter | digit}
OpInputAttrValue ::= Constant
Los atributos de entrada tienen un nombre y un valor que es uno de los admitidos
constantes. Son la forma principal de especificar metadatos estáticos para programas
o de terceros. Por ejemplo, la op concatenate
usa el atributo dimension
para
especifica la dimensión en la que se concatenan sus valores de entrada. De forma similar,
La op slice
usa varios atributos, como start_indices
y limit_indices
para especificar los límites que se usan para dividir el valor de entrada.
Por el momento, los programas de StableHLO en el entorno a veces contienen atributos que no se describen en este documento. En el futuro, planeamos ya sea absorban estos atributos en el conjunto de operaciones StableHLO o les prohíben que aparece en los programas de StableHLO. Mientras tanto, te presentamos la lista de estos atributos:
layout
(#629).mhlo.frontend_attributes
(#628).mhlo.sharding
(#619).output_operand_aliases
(#740).- Metadatos de ubicación (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'
La firma de operaciones consta de los tipos de todos los valores de entrada (la lista de tipos en
el lado izquierdo de ->
) y los tipos de todos los valores de salida (la lista de
a la derecha de ->
). En sentido estricto, los tipos de entrada
redundantes, y los tipos de salida casi siempre también son redundantes (porque para
la mayoría de las operaciones de StableHLO, los tipos de salida se pueden inferir de las entradas). No obstante, op
es parte de la sintaxis de StableHLO para brindar compatibilidad con MLIR.
A continuación, se muestra un ejemplo de una op cuyo nombre nemotécnico es select_and_scatter
. Consume 3
valores de entrada (%operand
, %source
y %init_value
), 2 funciones de entrada
y 3 atributos de entrada (window_dimensions
, window_strides
y padding
).
Observa que la firma de la operación solo incluye los tipos de sus valores de entrada.
(pero no los tipos de atributos y funciones de entrada que se proporcionan en línea).
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[3, 1]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>
Constantes
Constant ::= BooleanConstant
| IntegerConstant
| FloatConstant
| ComplexConstant
| TensorConstant
| QuantizedTensorConstant
| StringConstant
| EnumConstant
Las constantes StableHLO tienen un literal y un tipo que, en conjunto, representan
un valor StableHLO. En general, el tipo es parte de la sintaxis de la constante, excepto
cuando no es ambigua (p.ej., una constante booleana inequívocamente tiene el tipo i1
,
mientras que una constante de número entero puede tener varios tipos posibles).
BooleanConstant ::= BooleanLiteral
BooleanLiteral ::= 'true' | 'false'
Las constantes booleanas representan valores booleanos true
y false
. Booleano
Las constantes tienen el tipo i1
.
IntegerConstant ::= IntegerLiteral ':' IntegerType
IntegerLiteral ::= ['-' | '+'] DecimalDigits
| ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit ::= '0' | ... | '9'
hexadecimalDigit ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'
Las constantes de números enteros representan valores de números enteros mediante cadenas que usan decimales o la notación hexadecimal. Otras bases, p.ej. octal o binaria, no son compatibles. Las constantes de número entero tienen las siguientes restricciones:
- (C1)
is_wellformed(integer_literal, integer_type)
FloatConstant ::= FloatLiteral ':' FloatType
FloatLiteral ::= SignPart IntegerPart FractionalPart ScientificPart
| '0x' [HexadecimalDigits]
SignPart ::= ['-' | '+']
IntegerPart ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]
Las constantes de punto flotante representan valores de punto flotante a través de cadenas que usar notación decimal o científica. Además, la notación hexadecimal puede ser para especificar directamente los bits subyacentes en el formato de punto flotante de el tipo correspondiente. Las constantes de punto flotante tienen las siguientes restricciones:
- (C1) Si se usa notación no hexadecimal,
is_wellformed(float_literal, float_type)
- (C2) Si se usa la notación hexadecimal,
size(hexadecimal_digits) = num_bits(float_type) / 4
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral ::= '(' RealPart ',' ImaginaryPart ')'
RealPart ::= FloatLiteral
ImaginaryPart ::= FloatLiteral
Las constantes complejas representan valores complejos usando listas de una parte real
(va primero) y una parte imaginaria (va en segundo). Por ejemplo:
(1.0, 0.0) : complex<f32>
representa 1.0 + 0.0i
.
(0.0, 1.0) : complex<f32>
representa 0.0 + 1.0i
. El orden en que estas
se almacenan en la memoria y se define en la implementación. Constantes complejas
tienen las siguientes restricciones:
- (C1)
is_wellformed(real_part, complex_element_type(complex_type))
- (C2)
is_wellformed(imaginary_part, complex_element_type(complex_type))
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral
Las constantes de tensor representan valores de tensor con listas anidadas especificadas a través de
Notación NumPy. Por ejemplo, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
.
representa un valor de tensor con la siguiente asignación de índices a elementos:
{0, 0} => 1
, {0, 1} => 2
, {0, 2} => 3
, {1, 0} => 4
y {1, 1} => 5
,
{1, 2} => 6
El orden en que estos elementos se almacenan en la memoria es
definido por la implementación. Las constantes de tensor tienen las siguientes restricciones:
- (C1)
has_syntax(tensor_literal, element_type(tensor_type))
, donde:has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type)
.has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type)
.
- (C2)
has_shape(tensor_literal, shape(tensor_type))
, donde:has_shape(element_literal: Syntax, []) = true
.has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:])
.- de lo contrario,
false
.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
Las constantes de tensor cuantificadas representan valores de tensores cuantificados usando las mismas como constantes de tensor, con elementos especificados como constantes de su de almacenamiento. Las constantes del tensor cuantificadas tienen las siguientes restricciones:
- (C1)
has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type))
- (C2)
has_shape(quantized_tensor_literal, shape(quantized_tensor_type))
StringConstant ::= StringLiteral
StringLiteral ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))
Los literales de cadena consisten en bytes especificados con caracteres ASCII y
secuencias de escape. Son agnósticas a la codificación, por lo que la interpretación de estos
bytes está definido por la implementación. Los literales de string tienen el tipo string
.
Ops
abdominales
Semántica
Realiza una operación abs a nivel de elementos en el tensor operand
y produce un result
tensor. Según el tipo de elemento, hace lo siguiente:
- Para números enteros firmados: módulo de números enteros
- Para números de punto flotante:
abs
de IEEE-754. - Para números complejos: módulo complejo.
- Para tipos cuantizados:
dequantize_op_quantize(abs, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de número entero con firma, de punto flotante o de tipo complejo, o tensor cuantificado por tensor | (C1-C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero con firma o de tipo de punto flotante, o tensor cuantificado por tensor | (C1-C2) |
Limitaciones
- (C1)
shape(result) = shape(operand)
- (C2)
baseline_element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
si esis_complex(operand)
.- De lo contrario,
baseline_element_type(operand)
.
Ejemplos
// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]
add
Semántica
Realiza la adición a nivel de elementos de dos tensores, lhs
y rhs
, y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: suma de números enteros.
- Para números de punto flotante:
addition
de IEEE-754. - Para números complejos: suma compleja.
- Para tipos cuantizados:
dequantize_op_quantize(add, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado | (C1-C6) |
(I2) | rhs |
tensor o tensor cuantificado | (C1-C5), (C7) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C7) |
Limitaciones
- Si la operación usa tensores no cuantificados:
- (C1)
type(lhs) = type(rhs) = type(result)
- (C1)
- Si la operación usa tensores cuantificados:
- (C2)
is_quantized(lhs) and is_quantized(rhs) and is_quantized(result)
- (C3)
storage_type(lhs) = storage_type(rhs) = storage_type(result)
- (C4)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
- (C5)
(is_per_axis_quantized(lhs) or is_per_axis_quantized(rhs)) = is_per_axis_quantized(result)
- (C6) Si es
is_per_axis_quantized(lhs)
, entoncesquantization_dimension(lhs) = quantization_dimension(result)
. - (C7) Si es
is_per_axis_quantized(rhs)
, entoncesquantization_dimension(rhs) = quantization_dimension(result)
.
- (C2)
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]
after_all
Semántica
Garantiza que las operaciones que producen el inputs
se ejecuten antes que cualquier
las operaciones que dependen de result
. La ejecución de esta operación no tiene ningún efecto,
Solo existe para establecer dependencias de datos de result
a inputs
.
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | inputs |
cantidad variádica de token |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Ejemplos
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
all_gather
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, concatena los valores
de los tensores operands
de cada proceso a lo largo de all_gather_dim
y produce
results
.
La operación divide la cuadrícula de procesos StableHLO en process_groups
, que tiene el siguiente valor:
se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
si eschannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
si eschannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
operands...@receiver = [operand@sender for sender in process_group]
para todosreceiver
enprocess_group
.results...@process = concatenate(operands...@process, all_gather_dim)
para todosprocess
enprocess_group
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operands |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1), (C6) |
(I2) | all_gather_dim |
constante de tipo si64 |
(C1), (C6) |
(I3) | replica_groups |
Constante de tensor bidimensional de tipo si64 |
(C2-C4) |
(I4) | channel_id |
constante de tipo si64 |
C5 |
(I5) | use_global_device_ids |
constante de tipo i1 |
C5 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C6) |
Limitaciones
- (C1)
0 <= all_gather_dim < rank(operands...)
- (C2)
is_unique(replica_groups)
- (C3)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C4)
0 <= replica_groups < size(replica_groups)
- (C5) Si es
use_global_device_ids = true
, entonceschannel_id > 0
. - (C6)
type(results...) = type(operands...)
, excepto:dim(results..., all_gather_dim) = dim(operands..., all_gather_dim) * dim(process_groups, 1)
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
all_reduce
Semántica
Dentro de cada grupo de procesos de la cuadrícula de procesos StableHLO, aplica una reducción
función computation
con los valores de los tensores operands
de cada proceso
y produce tensores results
.
La operación divide la cuadrícula de procesos StableHLO en process_groups
, que tiene el siguiente valor:
se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
si eschannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
si eschannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
results...@process[result_index] = exec(schedule)
para un árbol binarioschedule
donde:exec(node)
=computation(exec(node.left), exec(node.right))
.exec(leaf)
=leaf.value
.
schedule
es un árbol binario definido por la implementación cuyo orden El recorrido esto_destination_type(operands...@process_group...[result_index], type(func_inputs(computation)[0]))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operands |
cantidad variádica de tensores o tensores cuantificados por tensor | (C5), (C6) |
(I2) | replica_groups |
Número variádico de constantes tensores unidimensionales del tipo si64 |
(C1-C3) |
(I3) | channel_id |
constante de tipo si64 |
(C4) |
(I4) | use_global_device_ids |
constante de tipo i1 |
(C4) |
(I5) | computation |
función | C5 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C6-C7) |
Limitaciones
- (C1)
is_unique(replica_groups)
- (C2)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C3)
0 <= replica_groups < size(replica_groups)
- (C4) Si es
use_global_device_ids = true
, entonceschannel_id > 0
. - (C5)
computation
tiene el tipo(tensor<E>, tensor<E>) -> (tensor<E>)
, dondeis_promotable(element_type(operand), E)
- (C6)
shape(results...) = shape(operands...)
- (C7)
element_type(results...) = E
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
all_to_all
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, divide los valores de
los tensores operands
a lo largo de split_dimension
en partes, dispersa la división
partes entre los procesos, concatena las partes dispersas junto
concat_dimension
y produce tensores results
.
La operación divide la cuadrícula de procesos StableHLO en process_groups
, que tiene el siguiente valor:
se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0
.cross_partition(replica_groups)
si eschannel_id > 0
.
Luego, dentro de cada process_group
, haz lo siguiente:
split_parts...@sender = split(operands...@sender, split_count, split_dimension)
para todos lossender
deprocess_group
.scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
dondereceiver_index = process_group.index(receiver)
results...@process = concatenate(scattered_parts...@process, concat_dimension)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operands |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C3), (C9) |
(I2) | split_dimension |
constante de tipo si64 |
(C1), (C2), (C9) |
(I3) | concat_dimension |
constante de tipo si64 |
(C3) y (C9) |
(I4) | split_count |
constante de tipo si64 |
(C2), (C4), (C8) y (C9) |
(I5) | replica_groups |
Constante de tensor bidimensional de tipo si64 |
(C5-C8) |
(I6) | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C9) |
Limitaciones
- (C1)
0 <= split_dimension < rank(operands...)
- (C2)
dim(operands..., split_dimension) % split_count = 0
- (C3)
0 <= concat_dimension < rank(operands...)
- (C4)
0 < split_count
- (C5)
is_unique(replica_groups)
- (C6)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C7)
0 <= replica_groups < size(replica_groups)
. - (C8)
dim(replica_groups, 1) = split_count
- (C9)
type(results...) = type(operands...)
, excepto si essplit_dimension != concat_dimension
:dim(results..., split_dimension) = dim(operands..., split_dimension) / split_count
.dim(results..., concat_dimension) = dim(operands..., concat_dimension) * split_count
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand1@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand1@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
// %operand2@(0, 0): [[17, 18, 19, 20],
// [21, 22, 23, 24]]
// %operand2@(1, 0): [[25, 26, 27, 28],
// [29, 30, 31, 32]]
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
// %result#1@(0, 0): [[17, 18], [21, 22], [25, 26], [29, 30]]
// %result#1@(1, 0): [[19, 20], [23, 24], [27, 28], [31, 32]]
y
Semántica
Realiza el operador AND a nivel de los elementos de dos tensores, lhs
y rhs
, y produce un result
.
tensor. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: lógico AND.
- Para números enteros: AND bit a bit.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo booleano o de número entero | C1 |
(I2) | rhs |
tensor de tipo booleano o de número entero | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo booleano o de número entero | C1 |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]
atan2
Semántica
Realiza la operación atan2 en cuanto a elementos en los tensores lhs
y rhs
, y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
atan2
de IEEE-754. - Para números complejos: complejo atan2.
- Para tipos cuantizados:
dequantize_op_quantize(atan2, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
(I2) | rhs |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
Ejemplos
// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]
batch_norm_grad
Semántica
Calcula gradientes de varias entradas de propagación inversa de batch_norm_training
.
de grad_output
y produce grad_operand
, grad_scale
y grad_offset
tensores. Más formalmente, esta operación puede expresarse como una descomposición en
operaciones StableHLO existentes con la sintaxis de Python de la siguiente manera:
def compute_sum(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
return sum
def compute_mean(operand, feature_index):
sum = compute_sum(operand, feature_index)
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
# Broadcast inputs to type(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance`
# Intermediate values will be useful for computing gradients
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
# Use the implementation from batchnorm_expander.cc in XLA
# Temporary variables have exactly the same names as in the C++ code
elements_per_feature = broadcast_in_dim(
constant(divide(size(operand), dim(operand, feature_index)),
element_type(grad_output)),
[], type(operand))
i1 = multiply(grad_output, elements_per_feature)
i2 = broadcast_in_dim(
compute_sum(grad_output, feature_index), [feature_index], type(operand))
i3 = broadcast_in_dim(
compute_sum(multiply(grad_output, centered_operand), feature_index),
[feature_index], type(operand))
i4 = multiply(i3, centered_operand)
i5 = divide(i4, add(variance_bcast, epsilon_bcast))
i6 = subtract(subtract(i1, i2), i5)
grad_operand =
multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
grad_scale =
compute_sum(multiply(grad_output, normalized_operand), feature_index)
grad_offset = compute_sum(grad_output, feature_index)
return grad_operand, grad_scale, grad_offset
Para los tipos cuantizados, realiza
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean,
variance, grad_output: batch_norm_grad(operand, scale, mean, variance,
grad_output, epsilon, feature_index), operand, scale, mean, variance,
grad_output, type(grad_operand), type(grad_scale), type(feature_index))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1-C3), (C5) |
(I2) | scale |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2), (C4), (C5) |
(I3) | mean |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
(I4) | variance |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
(I5) | grad_output |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2) y (C3) |
(I6) | epsilon |
constante de tipo f32 |
|
(I7) | feature_index |
constante de tipo si64 |
(C1), (C5) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
grad_operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2) y (C3) |
grad_scale |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
grad_offset |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand)
- (C2)
operand
,scale
,mean
,variance
,grad_output
,grad_operand
,grad_scale
ygrad_offset
tienen el mismobaseline_element_type
. - (C3)
operand
,grad_output
ygrad_operand
tienen la misma forma. - (C4)
scale
,mean
,variance
,grad_scale
ygrad_offset
tienen la misma forma. - (C5)
size(scale) = dim(operand, feature_index)
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
// [[0.1, 0.1], [0.1, 0.1]],
// [[0.1, 0.1], [0.1, 0.1]]
// ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
// [[0.0, 0.0], [0.0, 0.0]],
// [[0.0, 0.0], [0.0, 0.0]]
// ]
// %grad_scale: [0.0, 0.0]
// %grad_offset: [0.4, 0.4]
batch_norm_inference
Semántica
Normaliza el tensor operand
en todas las dimensiones excepto en el
feature_index
y produce un tensor result
. Más formalmente, esto
la operación se puede expresar como una descomposición de las operaciones StableHLO existentes
con la sintaxis de Python de la siguiente manera:
def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
# Broadcast inputs to shape(operand)
scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
type(operand))
# Perform normalization using the provided `mean` and `variance` instead of
# computing them like `batch_norm_training` does.
centered_operand = subtract(operand, mean_bcast)
stddev = sqrt(add(variance_bcast, epsilon_bcast))
normalized_operand = divide(centered_operand, stddev)
return add(multiply(scale_bcast, normalized_operand), offset_bcast)
Para los tipos cuantizados, realiza
dequantize_op_quantize(lambda operand, scale, offset, mean, variance:
batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index), operand, scale, offset, mean, variance, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C1-C7) |
(I2) | scale |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C3) |
(I3) | offset |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C4) |
(I4) | mean |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | C5 |
(I5) | variance |
Tensor unidimensional de punto flotante o tipo cuantificado por tensor | (C2) y (C6) |
(I6) | epsilon |
constante de tipo f32 |
|
(I7) | feature_index |
constante de tipo si64 |
(C1), (C3-C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C2) y (C7) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand)
- (C2)
operand
,scale
,offset
,mean
,variance
yresult
tienen la mismobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
- (C4)
size(offset) = dim(operand, feature_index)
- (C5)
size(mean) = dim(operand, feature_index)
- (C6)
size(variance) = dim(operand, feature_index)
- (C7)
baseline_type(operand) = baseline_type(result)
.
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
batch_norm_training
Semántica
Calcula la media y la varianza en todas las dimensiones, excepto la feature_index
.
y normaliza el tensor operand
, lo que produce output
, batch_mean
y batch_var
. Más formalmente, esta operación puede expresarse como
descomposición en operaciones StableHLO existentes con la sintaxis de Python como
sigue:
def compute_mean(operand, feature_index):
(sum,) = reduce(
inputs=[operand],
init_values=[constant(0, element_type(operand))],
dimensions=[i for i in range(rank(operand)) if i != feature_index],
body=lambda x, y: add(x, y))
divisor = constant(size(operand) / dim(operand, feature_index),
element_type(operand))
divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
return divide(sum, divisor_bcast)
def compute_variance(operand, feature_index):
mean = compute_mean(operand, feature_index)
mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
centered_operand = subtract(operand, mean_bcast)
return compute_mean(mul(centered_operand, centered_operand), feature_index)
def batch_norm_training(operand, scale, offset, epsilon, feature_index):
mean = compute_mean(operand, feature_index)
variance = compute_variance(operand, feature_index)
return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
feature_index),
mean, variance
Para los tipos cuantizados, realiza
dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset:
batch_norm_training(operand, scale, offset, epsilon, feature_index), operand,
scale, offset, type(output), type(batch_mean), type(batch_var))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
(I2) | scale |
Tensor unidimensional de punto flotante o por tensor cuantificado | (C2) y (C3) |
(I3) | offset |
Tensor unidimensional de punto flotante o por tensor cuantificado | (C2) y (C4) |
(I4) | epsilon |
constante de tipo f32 |
(C1), (C3-C6) |
(I5) | feature_index |
constante de tipo si64 |
(C1), (C3-C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output |
tensor de tipo de punto flotante o tensor cuantificado por tensor | (C7) |
batch_mean |
Tensor unidimensional de punto flotante o por tensor cuantificado | (C2) y (C5) |
batch_var |
Tensor unidimensional de punto flotante o por tensor cuantificado | (C2) y (C6) |
Limitaciones
- (C1)
0 <= feature_index < rank(operand)
- (C2)
operand
,scale
,offset
,batch_mean
,batch_var
youtput
tienen el mismobaseline_element_type
. - (C3)
size(scale) = dim(operand, feature_index)
- (C4)
size(offset) = dim(operand, feature_index)
- (C5)
size(batch_mean) = dim(operand, feature_index)
- (C6)
size(batch_var) = dim(operand, feature_index)
- (C7)
baseline_type(output) = baseline_type(operand)
.
Ejemplos
// %operand: [
// [[1.0, 2.0], [3.0, 4.0]],
// [[3.0, 4.0], [1.0, 2.0]]
// ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
epsilon = 0.0 : f32,
feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
(tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
// [[0.0, 0.0], [2.0, 2.0]],
// [[2.0, 2.0], [0.0, 0.0]]
// ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]
bitcast_convert
Semántica
Realiza una operación de transmisión de bits en el tensor operand
y produce un tensor result
en el que los bits del tensor operand
completo se reinterpretan mediante el parámetro
del tensor result
.
Más formalmente, dado E = element_type(operand)
, E' = element_type(result)
,
y R = rank(operand)
:
- Si es
num_bits(E') < num_bits(E)
,bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])
- Si es
num_bits(E') > num_bits(E)
,bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])
- Si es
num_bits(E') = num_bits(E)
,bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])
bits
muestra una representación en la memoria de un valor determinado y su comportamiento
está definido por la implementación porque la representación exacta de los tensores está
definido por la implementación, y la representación exacta de los tipos de elementos es
qué es la implementación.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C2) |
Limitaciones
- (C1) Dados
E = is_quantized(operand) ? storage_type(operand) : element_type(operand)
,E' = is_quantized(result) ? storage_type(result) : element_type(result)
yR = rank(operand)
:- Si es
num_bits(E') = num_bits(E)
,shape(result) = shape(operand)
. - Si es
num_bits(E') < num_bits(E)
: rank(result) = R + 1
.dim(result, i) = dim(operand, i)
para todos los0 <= i < R
.dim(result, R) * num_bits(E') = num_bits(E)
.- Si es
num_bits(E') > num_bits(E)
: rank(result) = R - 1
.dim(result, i) = dim(operand, i)
para todos los0 <= i < R
.dim(operand, R - 1) * num_bits(E) = num_bits(E')
.
- Si es
- (C2) Si es
is_complex(operand) or is_complex(result)
, entoncesis_complex(operand) and is_complex(result)
Ejemplos
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
broadcast_in_dim
Semántica
Expande las dimensiones o la clasificación de un tensor de entrada mediante la duplicación de los datos.
en el tensor operand
y produce un tensor result
. Más formalmente,
result[result_index] = operand[operand_index]
donde para todos los d
en
axes(operand)
:
operand_index[d] = 0
si esdim(operand, d) = 1
.- De lo contrario,
operand_index[d] = result_index[broadcast_dimensions[d]]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C2), (C5-C6) |
(I2) | broadcast_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C2-C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3), (C5-C6) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si es!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
,scales(operand)
yzero_points(operand)
pueden diferir dequantization_dimension(result)
,scales(result)
yzero_points(result)
de lo contrario.
- (C2)
size(broadcast_dimensions) = rank(operand)
- (C3)
0 <= broadcast_dimensions < rank(result)
- (C4)
is_unique(broadcast_dimensions)
- (C5) Para todos los
d
deaxes(operand)
:dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si es
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si es
dim(operand, quantization_dimension(operand)) = 1
, entoncesscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
Ejemplos
// %operand: [
// [1, 2, 3]
// ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
caso
Semántica
Produce el resultado de la ejecución de exactamente una función de branches
.
según el valor de index
. Más formalmente, result = selected_branch()
En el ejemplo anterior, se ilustra lo siguiente:
selected_branch = branches[index]
si es0 <= index < size(branches)
.- De lo contrario,
selected_branch = branches[-1]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | index |
Tensor de dimensión 0 de tipo si32 |
|
(I2) | branches |
número variádico de funciones | (C1-C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
número variádico de tensores, tensores cuantificados o tokens | (C4) |
Limitaciones
- (C1)
0 < size(branches)
- (C2)
input_types(branches...) = []
- (C3)
same(output_types(branches...))
- (C4)
type(results...) = output_types(branches[0])
Ejemplos
// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]
CTR
Semántica
Realiza la operación de raíz cúbica a nivel de elementos en el tensor operand
y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
rootn(x, 3)
de IEEE-754. - Para números complejos: raíz cúbica compleja.
- Para tipos cuantizados:
dequantize_op_quantize(cbrt, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]
Celeste
Semántica
Realiza el bloqueo a nivel de elementos del tensor operand
y produce un tensor result
.
Implementa la operación roundToIntegralTowardPositive
del estándar IEEE-754.
especificación. Para los tipos cuantizados, realiza
dequantize_op_quantize(ceil, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]
collesky
Semántica
Calcula la descomposición de Cholesky de un lote de matrices.
De manera más formal, para todos los i
de index_space(result)
,
result[i0, ..., iR-3, :, :]
es una descomposición de Cholesky
a[i0, ..., iR-3, :, :]
, en forma de un elemento triangular inferior
(si lower
es true
) o una matriz triangular superior (si lower
es false
).
Los valores de salida en el triángulo opuesto, es decir, el triángulo superior estricto o
del triángulo inferior estrictos, están definidos por la implementación.
Si existe i
en la que la matriz de entrada no es un valor definido positivo de Hermitian
matriz, entonces el comportamiento es indefinido.
Para los tipos cuantizados, realiza
dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | a |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | (C1-C3) |
(I2) | lower |
Constante de tensor de dimensión 0 de tipo i1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(a) = baseline_type(result)
- (C2)
2 <= rank(a)
- (C3)
dim(a, -2) = dim(a, -1)
Ejemplos
// %a: [
// [1.0, 2.0, 3.0],
// [2.0, 20.0, 26.0],
// [3.0, 26.0, 70.0]
// ]
%result = "stablehlo.cholesky"(%a) {
lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
restringir
Semántica
Fija cada elemento del tensor operand
entre un mínimo y un máximo.
de salida y produce un tensor result
. Más formalmente, result[result_index] =
minimum(maximum(operand[result_index], min_element), max_element)
,
donde min_element = rank(min) = 0 ? min[] : min[result_index]
,
max_element = rank(max) = 0 ? max[] : max[result_index]
Para los tipos cuantizados,
realiza dequantize_op_quantize(clamp, min, operand, max, type(result))
.
Imponer un orden en números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | min |
tensor o tensor cuantificado por tensor | (C1) y (C3) |
(I2) | operand |
tensor o tensor cuantificado por tensor | (C1-C4) |
(I3) | max |
tensor o tensor cuantificado por tensor | (C2) y (C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C4) |
Limitaciones
- (C1)
rank(min) = 0 or shape(min) = shape(operand)
- (C2)
rank(max) = 0 or shape(max) = shape(operand)
- (C3)
baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max)
- (C4)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]
collective_broadcast
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, envía el valor del
operand
desde el proceso de origen hasta los procesos de destino y produce un
Tensor result
.
La operación divide la cuadrícula de procesos StableHLO en process_groups
, que tiene el siguiente valor:
se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0
.cross_partition(replica_groups)
si eschannel_id > 0
.
Luego, result@process
se obtiene de la siguiente manera:
operand@process_groups[i, 0]
si existe uni
que hace que el proceso se enprocess_groups[i]
.broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
de lo contrario.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C3) |
(I2) | replica_groups |
Número variádico de constantes tensores unidimensionales del tipo si64 |
(C1) y (C2) |
(I3) | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3) |
Limitaciones
- (C1)
is_unique(replica_groups)
- (C2)
0 <= replica_groups < N
, en el queN
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C3)
type(result) = type(operand)
Ejemplos
// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]
collective_permute
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, envía el valor del
operand
desde el proceso de origen al proceso de destino y produce un
Tensor result
.
La operación divide la cuadrícula de procesos StableHLO en process_groups
, que tiene el siguiente valor:
se define de la siguiente manera:
cross_replica(source_target_pairs)
si eschannel_id <= 0
.cross_partition(source_target_pairs)
si eschannel_id > 0
.
Luego, result@process
se obtiene de la siguiente manera:
operand@process_groups[i, 0]
, si existe uni
queprocess_groups[i, 1] = process
broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result))
de lo contrario.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | C5 |
(I2) | source_target_pairs |
Constante de tensor bidimensional de tipo si64 |
(C1-C4) |
(I3) | channel_id |
constante de tipo si64 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
dim(source_target_pairs, 1) = 2
- (C2)
is_unique(source_target_pairs[:, 0])
- (C3)
is_unique(source_target_pairs[:, 1])
- (C4)
0 <= source_target_pairs < N
, dondeN
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_partitions
si se usacross_partition
.
- (C5)
type(result) = type(operand)
Ejemplos
// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]
comparar
Semántica
Realiza una comparación a nivel de elementos de los tensores lhs
y rhs
de acuerdo con
comparison_direction
y compare_type
, y produce un tensor result
.
Los valores de comparison_direction
y compare_type
tienen lo siguiente:
semántica:
Para los tipos de elementos booleanos y de número entero:
EQ
:lhs = rhs
NE
:lhs != rhs
GE
:lhs >= rhs
GT
:lhs > rhs
LE
:lhs <= rhs
LT
:lhs < rhs
Para los tipos de elementos de punto flotante con compare_type = FLOAT
, la operación implementa
las siguientes operaciones de IEEE-754:
EQ
:compareQuietEqual
NE
:compareQuietNotEqual
GE
:compareQuietGreaterEqual
GT
:compareQuietGreater
LE
:compareQuietLessEqual
LT
:compareQuietLess
Para los elementos de punto flotante con compare_type = TOTALORDER
, la op
usa la combinación de operaciones totalOrder
y compareQuietEqual
de
IEEE-754.
Para los tipos de elementos complejos, la comparación lexicográfica de los pares (real, imag)
es
realizar con los comparison_direction
y compare_type
proporcionados.
Imponer un orden en números complejos implica una semántica sorprendente,
por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos
cuando comparison_direction
es GE
, GT
, LE
o LT
(#560).
Para tipos cuantizados. realiza dequantize_compare(lhs, rhs,
comparison_direction)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1-C3) |
(I2) | rhs |
tensor o tensor cuantificado por tensor | (C1-C2) |
(I3) | comparison_direction |
enum de EQ , NE , GE , GT , LE y LT |
|
(I4) | compare_type |
enum de FLOAT , TOTALORDER , SIGNED y UNSIGNED |
(C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo booleano | (C2) |
Limitaciones
- (C1)
baseline_element_type(lhs) = baseline_element_type(rhs)
- (C2)
shape(lhs) = shape(rhs) = shape(result)
- (C3)
compare_type
se define de la siguiente manera:SIGNED
si esis_signed_integer(element_type(lhs))
.UNSIGNED
si esis_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs))
.FLOAT
oTOTALORDER
si esis_float(element_type(lhs))
.FLOAT
si esis_complex(element_type(lhs))
.
Ejemplos
// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
comparison_direction = #stablehlo<comparison_direction LT>,
compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]
complejo
Semántica
Realiza una conversión por elementos a un valor complejo a partir de un par de valores reales y
valores imaginarios, lhs
y rhs
, y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo f32 o f64 |
(C1-C3) |
(I2) | rhs |
tensor de tipo f32 o f64 |
C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo complejo | (C2) y (C3) |
Limitaciones
- (C1)
type(lhs) = type(rhs)
- (C2)
shape(result) = shape(lhs)
- (C3)
element_type(result)
tiene el tipocomplex<E>
, dondeE = element_type(lhs)
Ejemplos
// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]
compuesto
Semántica
Encapsula una operación compuesta (compuesta) por otras operaciones StableHLO.
Tomando inputs
y composite_attributes
, y produciendo results
. El
la semántica de la op se implementa con el atributo decomposition
. El
La op composite
se puede reemplazar por su descomposición sin cambiar el programa
semántica. En los casos en que el intercalado de la descomposición no proporcione la misma
semántica de op, es preferible usar custom_call
.
El campo version
(el valor predeterminado es 0
) se utiliza para denotar cuando el valor de
cambios semánticos.
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | inputs |
cantidad variádica de valores |
(I2) | name |
constante de tipo string |
(I3) | composite_attributes |
diccionario de atributos |
(I4) | decomposition |
constante de tipo string |
(I5) | version |
constante de tipo si32 |
Salidas
Nombre | Tipo |
---|---|
results |
cantidad variádica de valores |
Limitaciones
is_namespaced_op_name(name)
(C1)- C2:
is_defined_in_parent_scope(decomposition)
- C3:
types(inputs...) == input_types(decomposition)
- C4:
types(results...) == output_types(decomposition)
Ejemplos
%results = "stablehlo.composite"(%input0, %input1) {
name = "my_namespace.my_op",
composite_attributes = {
my_attribute = "my_value"
},
decomposition = @my_op,
version = 1 : i32
} : (tensor<f32>, tensor<f32>) -> tensor<f32>
concatenate
Semántica
Concatena inputs
en la dimensión dimension
en el mismo orden que el dado.
y produce un tensor result
. Más formalmente,
result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1]
, donde:
id = d0 + ... + dk-1 + kd
.d
equivale adimension
, yd0
, ... son tamaños ded
.a dimensión deinputs
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C6) |
(I2) | dimension |
constante de tipo si64 |
(C2), (C4), (C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5-C6) |
Limitaciones
- (C1)
same(element_type(inputs...))
- (C2)
same(shape(inputs...))
, excepto pordim(inputs..., dimension)
. - (C3)
0 < size(inputs)
- (C4)
0 <= dimension < rank(inputs[0])
- (C5)
element_type(result) = element_type(inputs[0])
- (C6)
shape(result) = shape(inputs[0])
, excepto por:dim(result, dimension) = dim(inputs[0], dimension) + ...
.
Ejemplos
// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]
constante
Semántica
Produce un tensor output
a partir de una constante value
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | value |
constante | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output |
tensor o tensor cuantificado | C1 |
Limitaciones
- (C1)
type(value) = type(output)
Ejemplos
%output = "stablehlo.constant"() {
value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]
generar una conversión
Semántica
Realiza una conversión en los elementos de un tipo de elemento a otro en
operand
y produce un tensor result
.
Para las conversiones de tipo boolean-to-any-supported-type, el valor false
es
convierte a cero y el valor true
se convierte en uno. Para
any-supported-type-to-boolean, un valor cero se convierte en
false
y los valores que no son cero se convierten en true
. Vea a continuación cómo esto
funcionan para tipos complejos.
Para conversiones que involucran integer-to-integer, integer-to-floating-point o floating-point-to-floating-point, si el valor de origen puede ser exactamente representados en el tipo de destino, el valor del resultado es que para la representación de los datos. De lo contrario, el comportamiento está por definir. (#180).
Para las conversiones que implican floating-point-to-integer, la parte fraccionaria es truncados. Si el valor truncado no se puede representar en el tipo de destino, el comportamiento está por definir (#180).
Las conversiones de complejo a complejo siguen el mismo comportamiento que floating-point-to-floating-point para convertir conversiones reales en partes imaginarias.
Para las conversiones complex-to-any-other-type y de complex-to-any-other-type, el valor imaginario de origen se ignora o el valor imaginario de destino se en cero, respectivamente. La conversión de la parte real sigue la las conversiones de punto flotante.
En principio, esta operación podría expresar una descuantización (conversión de
de tensores cuantificados a tensores regulares), la cuantización (conversión de tensores regulares a
de tensores a tensores cuantificados) y la recuantización (conversión entre valores
tensores), pero, por el momento, tenemos operaciones dedicadas para eso.
uniform_dequantize
para el primer caso de uso y uniform_quantize
para el
segundo y tercer caso de uso. En el futuro, estas dos operaciones podrían combinarse
a convert
(#1576).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor | C1 |
Limitaciones
- (C1)
shape(operand) = shape(result)
Ejemplos
// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]
convolución
Semántica
Calcula productos punto entre ventanas de lhs
y porciones de rhs
y produce
result
En el siguiente diagrama, se muestra cómo se calculan los elementos de result
a partir de
lhs
y rhs
con un ejemplo concreto.
De manera más formal, considera la siguiente reformulación de las entradas en términos de lhs
.
para poder expresar ventanas de lhs
:
lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension))
.lhs_window_strides = lhs_shape(1, window_strides, 1)
.lhs_padding = lhs_shape([0, 0], padding, [0, 0])
.lhs_base_dilations = lhs_shape(1, lhs_dilation, 1)
.lhs_window_dilations = lhs_shape(1, rhs_dilation, 1)
.
Este reencuadre utiliza las siguientes funciones auxiliares:
lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])
.result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])
.permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1]
dondej[d] = i[permutation[d]]
.
Si es feature_group_count = 1
y batch_group_count = 1
, entonces para todos
output_spatial_index
en index_space(dim(result, output_spatial_dimensions...))
,
result[result_shape(:, output_spatial_index, :)] = dot_product
donde:
padding_value = constant(0, element_type(lhs))
.padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1)
.lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides
.lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations)
.reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true])
Parece que esta función no se usa, por lo que planeamos quitarla en el futuro. (#1181).dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension])
.
Si es feature_group_count > 1
:
lhses = split(lhs, feature_group_count, input_feature_dimension)
.rhses = split(rhs, feature_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
.
Si es batch_group_count > 1
:
lhses = split(lhs, batch_group_count, input_batch_dimension)
.rhses = split(rhs, batch_group_count, kernel_output_feature_dimension)
.results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...)
.result = concatenate(results, output_feature_dimension)
Para los tipos cuantizados, realiza dequantize_op_quantize(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))
.
Para tipos híbridos cuantizados, realiza hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension,
input_feature_dimension, input_spatial_dimensions,
kernel_input_feature_dimension, kernel_output_feature_dimension,
kernel_spatial_dimensions, output_batch_dimension,
output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1), (C10-C11), (C14) (C25), (C27-C28), (C31-C32) y (C34) |
(I2) | rhs |
tensor o tensor cuantificado | (C1), (C14-C16), (C25), (C27-C29), (C31-C34) |
(I3) | window_strides |
Constante de tensor unidimensional de tipo si64 |
(C2-C3), (C25) |
(I4) | padding |
Constante de tensor bidimensional de tipo si64 |
(C4) y (C25) |
(I5) | lhs_dilation |
Constante de tensor unidimensional de tipo si64 |
(C5-C6), (C25) |
(I6) | rhs_dilation |
Constante de tensor unidimensional de tipo si64 |
(C7-C8), (C25) |
(I7) | window_reversal |
Constante de tensor unidimensional de tipo i1 |
(C9) |
(I8) | input_batch_dimension |
constante de tipo si64 |
(C10), (C13), (C25) |
(I9) | input_feature_dimension |
constante de tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C12), (C13) y (C25) |
(I11) | kernel_input_feature_dimension |
constante de tipo si64 |
(C14) y (C18) |
(I12) | kernel_output_feature_dimension |
constante de tipo si64 |
(C15-C16), (C18), (C25) y (C29) |
(I13) | kernel_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C17-C18), (C25) |
(I14) | output_batch_dimension |
constante de tipo si64 |
(C20), (C25) |
(I15) | output_feature_dimension |
constante de tipo si64 |
(C20), (C25), (C30) |
(I16) | output_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C19-C20), (C25) |
(I17) | feature_group_count |
constante de tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
constante de tipo si64 |
(C10), (C15), (C22), (C23), (C25) |
(I19) | precision_config |
Cantidad variádica de enumeraciones de DEFAULT , HIGH y HIGHEST |
(C24) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C25-C28), (C30), (C32-34) |
Limitaciones
- (C1)
N = rank(lhs) = rank(rhs)
- (C2)
size(window_strides) = N - 2
- (C3)
0 < window_strides
- (C4)
shape(padding) = [N - 2, 2]
- (C5)
size(lhs_dilation) = N - 2
- (C6)
0 < lhs_dilation
- (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
- (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
- (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
se define de la siguiente manera:dim(lhs, input_batch_dimension) / batch_group_count
si esresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
si esresult_dim = output_feature_dimension
.num_windows
de lo contrario, donde:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si la operación usa tensores no cuantificados:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si la operación usa tensores cuantificados:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Si es
is_per_axis_quantized(rhs)
, luego,quantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Si
is_per_axis_quantized(result)
, entoncesquantization_dimension(result) = output_feature_dimension
- Si es
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
- (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Si
is_per_tensor_quantized(rhs)
, entoncesis_per_tensor_quantized(result)
- Si es
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Ejemplos
// %lhs: [[
// [
// [1], [2], [5], [6]
// ],
// [
// [3], [4], [7], [8]
// ],
// [
// [10], [11], [14], [15]
// ],
// [
// [12], [13], [16], [17]
// ]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
window_strides = array<i64: 4, 4>,
padding = dense<0> : tensor<2x2xi64>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
// In the StableHLO dialect, dimension numbers are encoded via:
// `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
// "b" is batch dimension, "f" is feature dimension,
// "i" is input feature dimension, "o" is output feature dimension,
// "0/1/etc" are spatial dimensions.
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[10], [26]],
// [[46], [62]]
// ]]
coseno
Semántica
Realiza una operación de coseno en relación con los elementos en el tensor operand
y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
cos
de IEEE-754. - Para números complejos: coseno complejo.
- Para tipos cuantizados:
dequantize_op_quantize(cosine, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]
count_leading_zeros
Semántica
Realiza un recuento a nivel de elementos de la cantidad de bits cero iniciales en el operand
.
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de número entero | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de número entero | C1 |
Limitaciones
- (C1)
type(operand) = type(result)
Ejemplos
// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]
custom_call
Semántica
Encapsula una operación definida por la implementación call_target_name
que toma
inputs
y called_computations
, y produce results
. has_side_effect
,
backend_config
y api_version
se pueden usar para proporcionar
los metadatos definidos por la implementación.
Por el momento, esta operación contiene un conjunto bastante desorganizado de metadatos que reflejan la evolución orgánica del funcionamiento de su contraparte en el compilador XLA. En el futuro, planeamos unificar estos metadatos (#741).
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | inputs |
cantidad variádica de valores |
(I2) | call_target_name |
constante de tipo string |
(I3) | has_side_effect |
constante de tipo i1 |
(I4) | backend_config |
constante de tipo string o diccionario de atributos |
(I5) | api_version |
constante de tipo si32 |
(I6) | called_computations |
cantidad variádica de constantes del tipo string |
Salidas
Nombre | Tipo |
---|---|
results |
cantidad variádica de valores |
Ejemplos
%results = "stablehlo.custom_call"(%input0) {
call_target_name = "foo",
has_side_effect = false,
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
dividir
Semántica
Realiza la división en elementos de los tensores de dividendo lhs
y rhs
del divisor, y
produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: la división de enteros que produce el cociente algebraico con cualquier y se descartó una parte fraccionaria.
- Para números de punto flotante:
division
de IEEE-754. - Para números complejos: división compleja
- Para tipos cuantizados:
dequantize_op_quantize(divide, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
(I2) | rhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
Ejemplos
// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]
dot_general
Semántica
Calcula productos punteados entre porciones de lhs
y porciones de rhs
, y produce un
Tensor result
.
Más formalmente, result[result_index] = dot_product
, donde:
lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions]
.rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions]
.result_batching_index + result_lhs_index + result_rhs_index = result_index
dondesize(result_batching_index) = size(lhs_batching_dimensions)
,size(result_lhs_index) = size(lhs_result_dimensions)
ysize(result_rhs_index) = size(rhs_result_dimensions)
transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions)
.transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :])
.reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions))
.transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions)
.transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :])
.reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions))
.dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))
Para los tipos cuantizados, realiza dequantize_op_quantize(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs, type(result))
.
Para tipos híbridos cuantizados, realiza hybrid_dequantize_then_op(
lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions,
rhs_batching_dimensions, lhs_contracting_dimensions,
rhs_contracting_dimensions, precision_config), lhs, rhs)
.
precision_config
controla el equilibrio entre la velocidad y la precisión de
procesamientos en backends del acelerador. Puede ser una de las siguientes opciones (en el
momento, la semántica de estos valores enum está poco especificada, pero estamos
planean abordar esto en
#755):
DEFAULT
: Es el cálculo más rápido, pero la aproximación menos precisa al valor de número original.HIGH
: Es un cálculo más lento, pero una aproximación más precisa al valor. número original.HIGHEST
: Es el cálculo más lento, pero la aproximación más precisa al valor. número original.
Un DotAlgorithm
define las propiedades principales del algoritmo que se usa para implementar
la operación punto, que también define la precisión. Si el atributo del algoritmo
estén configurados, el precision_config
debe ser DEFAULT
. DotAlgorithms
no tienen un valor predeterminado, ya que los parámetros predeterminados son implementación
definido. Por lo tanto, todos los campos del algoritmo de puntos se pueden configurar en None
para especificar un
algoritmo de punto vacío, que usará el valor precision_config
.
Los campos DotAlgorithm
incluyen lo siguiente:
lhs_precision_type
yrhs_precision_type
, las precisiónes que el LHS y El lado derecho de la operación se redondea. Los tipos de precisión son independientes del y tipos de almacenamiento de las entradas y salidas.accumulation_type
es la precisión que se usa para la acumulación.lhs_component_count
,rhs_component_count
ynum_primitive_operations
cuando realizamos un algoritmo que descompone el LHS o el RHS en varios componentes y realiza múltiples operaciones de puntos en esos valores, generalmente para emular una precisión más alta (p.ej., Aprovechar el tipo de datos de inteligencia artificial bfloat16 para cálculos de mayor precisión: bf16_6x tf32_3x, etcétera). En el caso de los algoritmos sin descomposición, estos valores se debe establecer en1
.allow_imprecise_accumulation
para especificar si la acumulación se debe realizar con menor precisión se permite para algunos pasos (p.ej.,CUBLASLT_MATMUL_DESC_FAST_ACCUM
).
Atributos DotAlgorithm
de ejemplo:
// Inputs are casted to tf32, and then accumulated in f32:
{lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false}
// bf16_6x: each input is decomposed to 3 bf16 components, then 6 dot operations are done on those components, and the result is accumulated in f32.
{lhs_precision_type = bf16,
rhs_precision_type = bf16,
accumulation_type = f32,
lhs_component_count = 3,
rhs_component_count = 3,
num_primitive_operations = 6,
allow_imprecise_accumulation = false}
// Inputs are (casted to) f8e5m2, and we accumulate in f32, but for some steps we may accumulate in lower precision.
{lhs_precision_type = f8e5m2,
rhs_precision_type = f8e5m2,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = true}
Depende de las implementaciones decidir qué combinaciones son compatibles. En general, no se garantiza que cada algoritmo sea compatible con cada de acelerador por el consumidor del StableHLO. Si un algoritmo dado no es se debe generar un error en lugar de recurrir alternativa. La verificación de StableHLO proporcionará la mejor verificación, lo que evita que los algoritmos desconocidos sean compatibles con ningún hardware.
Consulta xla_data.proto > Algorithm
para algunos valores de algoritmos admitidos. El ticket #2483 captura el plan para crear una
documento centralizado sobre algoritmos compatibles con backend.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C5-C6), (C9-C10), (C12-C14), (C17-C18), (C20) |
(I2) | rhs |
tensor o tensor cuantificado | (C7-C10), (C12-C20) |
(I3) | lhs_batching_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C1), (C3), (C5), (C9), (C12) |
(I4) | rhs_batching_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C1), (C4), (C7) y (C9) |
(I5) | lhs_contracting_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C2), (C3), (C6), (C10) |
(I6) | rhs_contracting_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C2), (C4), (C8), (C10), (C16) |
(I7) | precision_config |
Cantidad variádica de enumeraciones de DEFAULT , HIGH y HIGHEST |
(C11) y (C21) |
(I8) | lhs_precision_type |
FloatType o TensorFloat32 | C21 |
(I9) | rhs_precision_type |
FloatType o TensorFloat32 | C21 |
(I10) | accumulation_type |
FloatType o TensorFloat32 | C21 |
(I11) | lhs_component_count |
constante de tipo si32 |
(C21) y (C22) |
(I12) | rhs_component_count |
constante de tipo si32 |
(C21) y (C23) |
(I13) | num_primitive_operations |
constante de tipo si32 |
(C21) y (C24) |
(I14) | allow_imprecise_accumulation |
constante de tipo bool |
C21 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C12), (C14), (C18-C20) |
Limitaciones
- (C1)
size(lhs_batching_dimensions) = size(rhs_batching_dimensions)
- (C2)
size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions)
- (C3)
is_unique(lhs_batching_dimensions + lhs_contracting_dimensions)
- (C4)
is_unique(rhs_batching_dimensions + rhs_contracting_dimensions)
- (C5)
0 <= lhs_batching_dimensions < rank(lhs)
- (C6)
0 <= lhs_contracting_dimensions < rank(lhs)
- (C7)
0 <= rhs_batching_dimensions < rank(rhs)
. - (C8)
0 <= rhs_contracting_dimensions < rank(rhs)
- (C9)
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
. - (C10)
dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...)
. - (C11)
size(precision_config) = 2
- (C12)
shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions)
. - Si la operación usa tensores no cuantificados:
- (C13)
element_type(lhs) = element_type(rhs)
.
- (C13)
- Si la operación usa tensores cuantificados:
- (C14)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C15)
zero_points(rhs) = 0
. - (C16) Si
is_per_axis_quantized(rhs)
, entoncesquantization_dimension(rhs)
no está enrhs_contracting_dimensions
. - Si es
is_quantized(lhs)
: - (C17)
storage_type(lhs) = storage_type(rhs)
. - (C18)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C19) Si
is_per_tensor_quantized(rhs)
, entoncesis_per_tensor_quantized(result)
- Si es
!is_quantized(lhs)
: - (C20)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C14)
- Si es
!is_empty_algorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count, rhs_component_count, num_primitive_operations allow_imprecise_accumulation)
:- (C21)
precision_config... = DEFAULT
. - (C22)
0 < lhs_component_count
. - (C23)
0 < rhs_component_count
. - (C24)
0 < num_primitive_operations
.
- (C21)
Ejemplos
// %lhs: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
// %rhs: [
// [[1, 0],
// [0, 1]],
// [[1, 0],
// [0, 1]]
// ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [2],
rhs_contracting_dimensions = [1]
>,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
algorithm = #stablehlo.dot_algorithm<
lhs_precision_type = tf32,
rhs_precision_type = tf32,
accumulation_type = f32,
lhs_component_count = 1,
rhs_component_count = 1,
num_primitive_operations = 1,
allow_imprecise_accumulation = false
>
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
// [[1, 2],
// [3, 4]],
// [[5, 6],
// [7, 8]]
// ]
dynamic_broadcast_in_dim
Semántica
Esta operación es funcionalmente idéntica a
broadcast_in_dim
op, pero la forma del resultado se especifica de forma dinámica a través de output_dimensions
.
La operación también acepta atributos opcionales known_expanding_dimensions
, known_non_expanding_dimensions
para expresar conocimiento estático sobre el comportamiento de expansión de las dimensiones.
Si no se especifican, se supone que es posible que todas las dimensiones se expandan.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C2), (C5-C6), (C9) |
(I2) | output_dimensions |
Tensor unidimensional de tipo de número entero | (C7) |
(I3) | broadcast_dimensions |
Tensor constante unidimensional de tipo de número entero | (C2-C6) |
(I4) | known_expanding_dimensions |
Tensor constante unidimensional de tipo de número entero | (C8-C9) |
(I5) | known_non_expanding_dimensions |
Tensor constante unidimensional de tipo de número entero | (C8-C9) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3), (C5-C7) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si es!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
,scales(operand)
yzero_points(operand)
pueden diferir dequantization_dimension(result)
,scales(result)
yzero_points(result)
de lo contrario.
- (C2)
size(broadcast_dimensions) = rank(operand)
- (C3)
0 <= broadcast_dimensions < rank(result)
- (C4)
is_unique(broadcast_dimensions)
- (C5) Para todos los
d
deaxes(operand)
:dim(operand, d) = 1
odim(operand, d) = dim(result, broadcast_dimensions[d])
.
- (C6) Si es
is_per_axis_quantized(result)
:quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)]
.- Si es
dim(operand, quantization_dimension(operand)) = 1
, entoncesscales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result)))
- (C7)
size(output_dimensions) = rank(result)
. - (C8)
is_unique(known_expanding_dimensions + known_non_expanding_dimensions)
- (C9)
0 <= known_expanding_dimensions < rank(operand)
. - (C10)
0 <= known_non_expanding_dimensions < rank(operand)
.
Ejemplos
// %operand: [
// [1, 2, 3]
// ]
%operand = stablehlo.constant dense<[[1, 2, 3]]> : tensor<1x3xi64>
%output_dimensions = stablehlo.constant dense<[2, 3, 2]> : tensor<3xi64>
%result = "stablehlo.dynamic_broadcast_in_dim"(%operand, %output_dimensions) {
broadcast_dimensions = array<i64: 2, 1>,
known_expanding_dimensions = array<i64: 0>,
known_non_expanding_dimensions = array<i64: 1>
} : (tensor<1x3xi64>, tensor<3xi64>) -> tensor<2x3x2xi64>
// %result: [
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ],
// [
// [1, 1],
// [2, 2],
// [3, 3]
// ]
// ]
dynamic_conv
Semántica
Esta operación es funcionalmente idéntica a
convolución
op, pero el padding se especifica de forma dinámica a través de padding
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | (C1), (C10-C11), (C14) (C25), (C26-C27), (C30-C31) y (C33) |
(I2) | rhs |
tensor o tensor cuantificado | (C1), (C14-C16), (C26-C28), (C30-C33) |
(I3) | padding |
Tensor bidimensional de tipo de número entero | (C4) |
(I4) | window_strides |
Constante de tensor unidimensional de tipo si64 |
(C2-C3) |
(I5) | lhs_dilation |
Constante de tensor unidimensional de tipo si64 |
(C5-C6) |
(I6) | rhs_dilation |
Constante de tensor unidimensional de tipo si64 |
(C7-C8) |
(I7) | window_reversal |
Constante de tensor unidimensional de tipo i1 |
(C9) |
(I8) | input_batch_dimension |
constante de tipo si64 |
(C10), (C13) |
(I9) | input_feature_dimension |
constante de tipo si64 |
(C11), (C13-C14) |
(I10) | input_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C12) y (C13) |
(I11) | kernel_input_feature_dimension |
constante de tipo si64 |
(C14) y (C18) |
(I12) | kernel_output_feature_dimension |
constante de tipo si64 |
(C15-C16), (C18) y (C28) |
(I13) | kernel_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C17-C18) |
(I14) | output_batch_dimension |
constante de tipo si64 |
C20 |
(I15) | output_feature_dimension |
constante de tipo si64 |
(C20), (C29) |
(I16) | output_spatial_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C19-C20) |
(I17) | feature_group_count |
constante de tipo si64 |
(C11), (C14), (C16), (C21), (C23) |
(I18) | batch_group_count |
constante de tipo si64 |
(C10), (C15), (C22), (C23) |
(I19) | precision_config |
Cantidad variádica de enumeraciones de DEFAULT , HIGH y HIGHEST |
(C24) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C25-C27), (C29), (C31-C33) |
Limitaciones
- (C1)
N = rank(lhs) = rank(rhs)
- (C2)
size(window_strides) = N - 2
- (C3)
0 < window_strides
- (C4)
shape(padding) = [N - 2, 2]
- (C5)
size(lhs_dilation) = N - 2
- (C6)
0 < lhs_dilation
- (C7)
size(rhs_dilation) = N - 2
. - (C8)
0 < rhs_dilation
- (C9)
size(window_reversal) = N - 2
. - (C10)
dim(lhs, input_batch_dimension) % batch_group_count = 0
. - (C11)
dim(lhs, input_feature_dimension) % feature_group_count = 0
- (C12)
size(input_spatial_dimensions) = N - 2
. - (C13) Dado
input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]
:is_unique(input_dimensions)
.0 <= input_dimensions < N
.
- (C14)
dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count
. - (C15)
dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0
. - (C16)
dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0
. - (C17)
size(kernel_spatial_dimensions) = N - 2
. - (C18) Dado
kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]
:is_unique(kernel_dimensions)
.0 <= kernel_dimensions < N
.
- (C19)
size(output_spatial_dimensions) = N - 2
. - (C20) Dado
output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]
:is_unique(output_dimensions)
.0 <= output_dimensions < N
.
- (C21)
0 < feature_group_count
. - (C22)
0 < batch_group_count
. - (C23)
feature_group_count = 1 or batch_group_count = 1
. - (C24)
size(precision_config) = 2
. - (C25)
dim(result, result_dim)
se define de la siguiente manera:dim(lhs, input_batch_dimension) / batch_group_count
si esresult_dim = output_batch_dimension
.dim(rhs, kernel_output_feature_dimension)
si esresult_dim = output_feature_dimension
.num_windows
de lo contrario, donde:output_spatial_dimensions[spatial_dim] = result_dim
.lhs_dim = input_spatial_dimensions[spatial_dim]
.rhs_dim = kernel_spatial_dimensions[spatial_dim]
.dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1
.padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]
.dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1
.is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]
.num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1
.
- (C26)
rank(result) = N
. - Si la operación usa tensores no cuantificados:
- (C27)
element_type(lhs) = element_type(rhs) = element_type(result)
.
- (C27)
- Si la operación usa tensores cuantificados:
- (C28)
is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)
. - (C29) Si es
is_per_axis_quantized(rhs)
, luego,quantization_dimension(rhs) = kernel_output_feature_dimension
. - (C30) Si
is_per_axis_quantized(result)
, entoncesquantization_dimension(result) = output_feature_dimension
- Si es
is_quantized(lhs)
: - (C31)
storage_type(lhs) = storage_type(rhs)
- (C32)
expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)
. - (C33) Si
is_per_tensor_quantized(rhs)
, entoncesis_per_tensor_quantized(result)
- Si es
!is_quantized(lhs)
: - (C34)
element_type(lhs) = expressed_type(rhs) = element_type(result)
.
- (C28)
Ejemplos
// %lhs: [[
// [[1], [2], [5], [6]],
// [[3], [4], [7], [8]],
// [[10], [11], [14], [15]],
// [[12], [13], [16], [17]]
// ]]
//
// %rhs: [
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]],
// [[[1]], [[1]], [[1]]]
// ]
// %padding: [[1, 1],
// [1, 1]]
%result = "stablehlo.dynamic_conv"(%lhs, %rhs, %padding) {
window_strides = array<i64: 4, 4>,
lhs_dilation = array<i64: 2, 2>,
rhs_dilation = array<i64: 1, 1>,
window_reversal = array<i1: false, false>,
dimension_numbers = #stablehlo.conv<raw
input_batch_dimension = 0,
input_feature_dimension = 3,
input_spatial_dimensions = [0, 1],
kernel_input_feature_dimension = 2,
kernel_output_feature_dimension = 3,
kernel_spatial_dimensions = [0, 1],
output_batch_dimension = 0,
output_feature_dimension = 3,
output_spatial_dimensions = [1, 2]
>,
feature_group_count = 1 : i64,
batch_group_count = 1 : i64,
precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi64>, tensor<3x3x1x1xi64>, tensor<2x2xi64>) -> tensor<1x2x2x1xi64>
// %result: [[
// [[1], [5]],
// [[10], [14]]
// ]]
dynamic_gather
Semántica
Esta operación es funcionalmente idéntica a
recopilación
op, con el slice_sizes
especificado de forma dinámica como un valor.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C7), (C10-C12), (C14) |
(I2) | start_indices |
tensor de tipo de número entero | (C2), (C3) y (C13) |
(I3) | slice_sizes |
Tensor unidimensional de tipo de número entero | (C8), (C11-C13) |
(I4) | offset_dims |
Constante de tensor unidimensional de tipo si64 |
(C1), (C4-C5), (C13) |
(I5) | collapsed_slice_dims |
Constante de tensor unidimensional de tipo si64 |
(C1), (C6-C8), (C13) |
(I6) | start_index_map |
Constante de tensor unidimensional de tipo si64 |
(C3), (C9), (C10) |
(I7) | index_vector_dim |
constante de tipo si64 |
(C2), (C3) y (C13) |
(I8) | indices_are_sorted |
constante de tipo i1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5), (C13-C14) |
Limitaciones
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims)
- (C2)
0 <= index_vector_dim <= rank(start_indices)
- (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
- (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
- (C5)
0 <= offset_dims < rank(result)
- (C6)
is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims)
- (C7)
0 <= collapsed_slice_dims < rank(operand)
. - (C8)
slice_sizes[collapsed_slice_dims...] <= 1
- (C9)
is_unique(start_index_map)
. - (C10)
0 <= start_index_map < rank(operand)
. - (C11)
size(slice_sizes) = rank(operand)
- (C12)
0 <= slice_sizes <= shape(operand)
. - (C13)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
donde:batch_dim_sizes = shape(start_indices)
, excepto que el tamaño de la dimensión destart_indices
correspondientes aindex_vector_dim
no está incluido.offset_dim_sizes = shape(slice_sizes)
, excepto que los tamaños de las dimensiones enslice_sizes
correspondientes acollapsed_slice_dims
no se incluyen.combine
colocabatch_dim_sizes
en los ejes correspondientes abatch_dims
yoffset_dim_sizes
en los ejes correspondientes aoffset_dims
.
- (C14)
element_type(operand) = element_type(result)
.
Ejemplos
// %operand: [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ]
// %start_indices: [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 2]]
// ]
// %slize_sizes: [1, 2, 2]
%result = "stablehlo.dynamic_gather"(%operand, %start_indices, %slize_sizes) {
dimension_numbers = #stablehlo.gather<
offset_dims = [2, 3],
collapsed_slice_dims = [0],
start_index_map = [1, 0],
index_vector_dim = 2>,
indices_are_sorted = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<3xi64>) -> tensor<2x3x2x2xi64>
// %result: [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[9, 10], [11, 12]],
// [[11, 12], [13, 14]],
// [[17, 18], [19, 20]]
// ]
// ]
dynamic_iota
Semántica
Esta operación es funcionalmente idéntica a
iota
op, pero la forma del resultado se especifica de forma dinámica a través de output_shape
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | output_shape |
Tensor unidimensional de tipo de número entero | (C1) y (C2) |
(I2) | iota_dimension |
si64 |
C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | (C2) |
Limitaciones
- (C1)
0 <= iota_dimension < size(output_shape)
- (C2)
rank(result) = size(output_shape)
Ejemplos
%output_shape = stablehlo.constant dense<[4, 5]> : tensor<2xi64>
%result = "stablehlo.dynamic_iota"(%output_shape) {
iota_dimension = 0 : i64
} : (tensor<2xi64>) -> tensor<4x5xi64>
// %result: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
dynamic_pad
Semántica
Esta operación es funcionalmente idéntica a
almohadilla
op, pero con edge_padding_low
, edge_padding_high
y interior_padding
se especifican de forma dinámica como valores.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor de dimensión 0 o tensor cuantificado por tensor | C1 |
(I3) | edge_padding_low |
Tensor unidimensional de tipo de número entero | (C1) y (C4) |
(I4) | edge_padding_high |
Tensor unidimensional de tipo de número entero | (C1) y (C4) |
(I5) | interior_padding |
Tensor unidimensional de tipo de número entero | (C2-C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3-C6) |
Limitaciones
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
- (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
- (C3)
0 <= interior_padding
- (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
Ejemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
// %edge_padding_low: [0, 1]
// %edge_padding_high: [2, 1]
// %interior_padding: [1, 2]
%result = "stablehlo.dynamic_pad"(%operand, %padding_value,
%edge_padding_low, %edge_padding_high, %interior_padding
) : (tensor<2x3xi64>, tensor<i64>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<5x9xi64>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
dynamic_reshape
Semántica
Esta operación es funcionalmente idéntica a
redimensionar
op, pero la forma del resultado se especifica de forma dinámica a través de output_shape
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C3) |
(I2) | output_shape |
Tensor unidimensional de tipo de número entero | (C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C4) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si es!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
y De lo contrario,quantization_dimension(result)
podría variar.
- (C2)
size(operand) = size(result)
- (C3) Si es
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
- (C4)
size(output_shape) = rank(result)
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
// %output_shape: [3, 2]
%result = "stablehlo.dynamic_reshape"(%operand, %output_shape) : (tensor<2x3xi64>, tensor<2xi64>) -> tensor<3x2xi64>
// %result: [[1, 2], [3, 4], [5, 6]]
dynamic_slice
Semántica
Extrae una porción de operand
mediante índices de inicio calculados de forma dinámica.
y produce un tensor result
. start_indices
contienen los índices iniciales de
la porción para cada dimensión sujeta a posibles ajustes, y slice_sizes
contienen los tamaños de la porción para cada dimensión. Más formalmente,
result[result_index] = operand[operand_index]
donde:
adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes)
.operand_index = adjusted_start_indices + result_index
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
(I2) | start_indices |
número variádico de tensores 0 dimensiones de tipo de número entero | (C2) y (C3) |
(I3) | slice_sizes |
Constante de tensor unidimensional de tipo si64 |
(C2), (C4), (C5) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1), (C5) |
Limitaciones
- (C1)
element_type(operand) = element_type(result)
- (C2)
size(start_indices) = size(slice_sizes) = rank(operand)
- (C3)
same(type(start_indices...))
- (C4)
0 <= slice_sizes <= shape(operand)
- (C5)
shape(result) = slice_sizes
Ejemplos
// %operand: [
// [0, 0, 1, 1],
// [0, 0, 1, 1],
// [0, 0, 0, 0],
// [0, 0, 0, 0]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
slice_sizes = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
// [1, 1],
// [1, 1]
// ]
dynamic_update_slice
Semántica
Produce un tensor result
que es igual al tensor operand
, excepto que
La porción que comienza en start_indices
se actualiza con los valores en update
.
De manera más formal, result[result_index]
se define de la siguiente manera:
update[update_index]
si es0 <= update_index < shape(update)
donde:adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update))
.update_index = result_index - adjusted_start_indices
.
- De lo contrario,
operand[result_index]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1-C4), (C6) |
(I2) | update |
tensor o tensor cuantificado por tensor | (C2), (C3) y (C6) |
(I3) | start_indices |
número variádico de tensores 0 dimensiones de tipo de número entero | (C4) y (C5) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
type(operand) = type(result)
- (C2)
element_type(update) = element_type(operand)
- (C3)
rank(update) = rank(operand)
- (C4)
size(start_indices) = rank(operand)
- (C5)
same(type(start_indices...))
- (C6)
0 <= shape(update) <= shape(operand)
Ejemplos
// %operand: [
// [1, 1, 0, 0],
// [1, 1, 0, 0],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
// %update: [
// [1, 1],
// [1, 1]
// ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
: (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1],
// [1, 1, 1, 1]
// ]
exponencial
Semántica
Realiza una operación exponencial por elementos en el tensor operand
y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
exp
de IEEE-754. - Para números complejos: exponencial compleja.
- Para tipos cuantizados:
dequantize_op_quantize(exponential, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]
exponential_minus_one
Semántica
Realiza una operación exponencial a nivel de elementos menos una operación en el tensor operand
y
produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
expm1
de IEEE-754. - Para números complejos: exponencial compleja menos uno.
- Para tipos cuantizados:
dequantize_op_quantize(exponential_minus_one, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]
final
Semántica
Realiza las transformaciones de Fourier hacia adelante e inversas para modelos reales y complejos de entrada y salida.
fft_type
es una de las siguientes opciones:
FFT
: Reenvía FFT complejo a complejo.IFFT
: FFT inverso de complejo a complejo.RFFT
: Reenvía FFT reales a complejos.IRFFT
: FFT inverso de real a complejo (es decir, toma complejo y muestra real).
Más formalmente, dada la función fft
, que toma tensores unidimensionales de
tipos complejos como entrada, produce tensores unidimensionales de los mismos tipos que
y calcula la transformación discreta de Fourier:
Para fft_type = FFT
, result
se define como el resultado final de una serie de L.
en los que L = size(fft_length)
. Por ejemplo, para L = 3
:
result1[i0, ..., :] = fft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Además, dada la función ifft
, que tiene el mismo tipo de firma y
calcula la inversa de fft
:
Para fft_type = IFFT
, result
se define como la inversa de los cálculos
para fft_type = FFT
. Por ejemplo, para L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = ifft(result2[i0, ..., :])
.
Además, dada la función rfft
, que toma tensores unidimensionales de
tipos de punto flotante, produce tensores unidimensionales de tipos complejos de la
misma semántica de punto flotante y funciona de la siguiente manera:
rfft(real_operand) = truncated_result
dondecomplex_operand... = (real_operand..., 0.0)
.complex_result = fft(complex_operand)
.truncated_result = complex_result[:(rank(complex_result) / 2 + 1)]
.
(Cuando la transformación discreta de Fourier se calcula para operandos reales, la primera
Los elementos N/2 + 1
del resultado definen de manera clara el resto del resultado.
por lo que el resultado de rfft
se trunca para evitar el procesamiento de elementos redundantes).
Para fft_type = RFFT
, result
se define como el resultado final de una serie de L.
en los que L = size(fft_length)
. Por ejemplo, para L = 3
:
result1[i0, ..., :] = rfft(operand[i0, ..., :])
.result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1])
.result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1])
.
Por último, dada la función irfft
, que tiene el mismo tipo de firma y
calcula la inversa de rfft
:
Para fft_type = IRFFT
, result
se define como la inversa de los cálculos
para fft_type = RFFT
. Por ejemplo, para L = 3
:
result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1])
.result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1])
.result[i0, ..., :] = irfft(result2[i0, ..., :])
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo | (C1), (C2), (C4) y (C5) |
(I2) | fft_type |
enum de FFT , IFFT , RFFT y IRFFT |
(C2) y (C5) |
(I3) | fft_length |
Constante de tensor unidimensional de tipo si64 |
(C1), (C3) y (C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo | (C2), (C4), (C5) |
Limitaciones
- (C1)
size(fft_length) <= rank(operand)
- (C2) La relación entre los tipos de elementos
operand
yresult
varía:- Si es
fft_type = FFT
,element_type(operand)
yelement_type(result)
tienen el mismo tipo complejo. - Si es
fft_type = IFFT
,element_type(operand)
yelement_type(result)
tienen el mismo tipo complejo. - Si es
fft_type = RFFT
,element_type(operand)
es un tipo de punto flotante.element_type(result)
es un tipo complejo del mismo punto flotante semántica. - Si es
fft_type = IRFFT
,element_type(operand)
es un tipo complejo.element_type(result)
es un tipo de punto flotante del mismo punto flotante. semántica.
- Si es
- (C3)
1 <= size(fft_length) <= 3
- (C4) Si se encuentra entre
operand
yresult
, hay un tensorreal
de un tipo de punto flotante y, luego,shape(real)[-size(fft_length):] = fft_length
. - (C5)
shape(result) = shape(operand)
, excepto por:- Si es
fft_type = RFFT
,dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1
- Si es
fft_type = IRFFT
,dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1
- Si es
Ejemplos
// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
fft_type = #stablehlo<fft_type FFT>,
fft_length = array<i64: 4>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]
piso
Semántica
Realiza el precio mínimo de elementos del tensor operand
y produce un tensor result
.
Implementa la operación roundToIntegralTowardNegative
del estándar IEEE-754.
especificación. Para los tipos cuantizados, realiza
dequantize_op_quantize(floor, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]
reunir
Semántica
Recopila segmentos del tensor operand
a partir de los desplazamientos especificados en start_indices
y produce un tensor result
.
En el siguiente diagrama, se muestra cómo los elementos de result
se asignan a los elementos de
operand
con un ejemplo concreto. En el diagrama, se seleccionan algunos result
de ejemplo
y explica en detalle a qué índices de operand
corresponden.
Más formalmente, result[result_index] = operand[operand_index]
, donde:
batch_dims = [d for d in axes(result) and d not in offset_dims]
.batch_index = result_index[batch_dims...]
.start_index
se define de la siguiente manera:start_indices[bi0, ..., :, ..., biN]
dondebi
son elementos individuales enbatch_index
y:
se insertan en el índiceindex_vector_dim
, siindex_vector_dim
<rank(start_indices)
- De lo contrario,
[start_indices[batch_index]]
.
- Para
d_operand
enaxes(operand)
,full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand])
si esd_operand = start_index_map[d_start]
.- De lo contrario,
full_start_index[d_operand] = 0
.
- Para
d_operand
enaxes(operand)
,full_batching_index[d_operand] = batch_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sid_operand = operand_batching_dims[i_batching]
yd_start = start_indices_batching_dims[i_batching]
- De lo contrario,
full_batching_index[d_operand] = 0
.
offset_index = result_index[offset_dims...]
.full_offset_index = [oi0, ..., 0, ..., oiN]
en el queoi
son individuales elementos enoffset_index
y0
se inserta en los índices decollapsed_slice_dims
yoperand_batching_dims
.operand_index = full_start_index + full_batching_index + full_offset_index
Si indices_are_sorted
es true
, la implementación puede suponer que
start_indices
se ordenan con respecto a start_index_map
; de lo contrario, la
el comportamiento no está definido. Más formalmente, para todos los i1 < i2
de indices(result)
,
full_start_index(i1) <= full_start_index(i2)
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C8), (C11), (C17), (C19-C21), (C23) |
(I2) | start_indices |
tensor de tipo de número entero | (C2-C3), (C14), (C17), (C22) |
(I3) | offset_dims |
Constante de tensor unidimensional de tipo si64 |
(C1), (C4-C5), (C22) |
(I4) | collapsed_slice_dims |
Constante de tensor unidimensional de tipo si64 |
(C1), (C6-C9), (C22) |
(I5) | operand_batching_dims |
Constante de tensor unidimensional de tipo si64 |
(C1), (C6), (C10-C12), (C16-C18), (C22) |
(I6) | start_indices_batching_dims |
Constante de tensor unidimensional de tipo si64 |
(C13-C17) |
(I7) | start_index_map |
Constante de tensor unidimensional de tipo si64 |
(C3), (C18-C19) |
(I8) | index_vector_dim |
constante de tipo si64 |
(C2-C3), (C15) y (C22) |
(I9) | slice_sizes |
Constante de tensor unidimensional de tipo si64 |
(C9), (C12), (C20-C22) |
(I10) | indices_are_sorted |
constante de tipo i1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C5), (C22-C23) |
Limitaciones
- (C1)
rank(operand) = size(offset_dims) + size(collapsed_slice_dims) + size(operand_batching_dims)
- (C2)
0 <= index_vector_dim <= rank(start_indices)
- (C3)
size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1
- (C4)
is_unique(offset_dims) and is_sorted(offset_dims)
- (C5)
0 <= offset_dims < rank(result)
- C6:
is_unique(concatenate(collapsed_slice_dims, operand_batching_dims))
- (C7)
is_sorted(collapsed_slice_dims)
. - (C8)
0 <= collapsed_slice_dims < rank(operand)
- (C9)
slice_sizes[collapsed_slice_dims...] <= 1
. - (C10)
is_sorted(operand_batching_dims)
. - (C11)
0 <= operand_batching_dims < rank(operand)
- (C12)
slice_sizes[operand_batching_dims...] <= 1
. - (C13)
is_unique(start_indices_batching_dims)
. - (C14)
0 <= start_indices_batching_dims < rank(start_indices)
. - (C15)
index_vector_dim not in start_indices_batching_dims
. - (C16)
size(operand_batching_dims) == size(start_indices_batching_dims)
. - (C17)
dim(operand, operand_batching_dims...) = dim(start_indices, start_indices_batching_dims...)
. - (C18)
is_unique(concatenate(start_index_map, operand_batching_dims))
. - (C19)
0 <= start_index_map < rank(operand)
. - (C20)
size(slice_sizes) = rank(operand)
. - (C21)
0 <= slice_sizes <= shape(operand)
. - (C22)
shape(result) = combine(batch_dim_sizes, offset_dim_sizes)
donde:batch_dim_sizes = shape(start_indices)
, excepto que el tamaño de la dimensión destart_indices
correspondientes aindex_vector_dim
no está incluido.offset_dim_sizes = slice_sizes
, excepto que los tamaños de las dimensionesslice_sizes
, que corresponde acollapsed_slice_dims
yoperand_batching_dims
no están incluidas.combine
colocabatch_dim_sizes
en los ejes correspondientes abatch_dims
yoffset_dim_sizes
en los ejes correspondientes aoffset_dims
.
- (C23)
element_type(operand) = element_type(result)
.
Ejemplos
// %operand: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %start_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
%result = "stablehlo.gather"(%operand, %start_indices) {
dimension_numbers = #stablehlo.gather<
offset_dims = [3, 4],
collapsed_slice_dims = [1],
operand_batching_dims = [0],
start_indices_batching_dims = [1],
start_index_map = [2, 1],
index_vector_dim = 3>,
slice_sizes = array<i64: 1, 1, 2, 2>,
indices_are_sorted = false
} : (tensor<2x3x4x2xi32>, tensor<2x2x3x2xi64>) -> tensor<2x2x3x2x2xi32>
// %result: [
// [
// [
// [[1, 2], [3, 4]],
// [[3, 4], [5, 6]],
// [[13, 14], [15, 16]]
// ],
// [
// [[33, 34], [35, 36]],
// [[35, 36], [37, 38]],
// [[41, 42], [43, 44]]
// ]
// ],
// [
// [
// [[1, 2], [3, 4]],
// [[13, 14], [15, 16]],
// [[21, 22], [23, 24]]
// ],
// [
// [[43, 44], [45, 46]],
// [[33, 34], [35, 36]],
// [[27, 28], [29, 30]]
// ]
// ]
// ]
get_dimension_size
Semántica
Produce el tamaño del dimension
determinado de operand
. Más formalmente,
result = dim(operand, dimension)
La semántica se relaciona solo con la forma
componente del tipo. El tipo de elemento puede ser cualquier cosa.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | C1 |
(I2) | dimension |
constante de tipo si64 |
C1 |
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de dimensión 0 de tipo si32 |
Limitaciones
- (C1)
0 <= dimension < rank(operand)
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3
get_tuple_element
Semántica
Extrae el elemento en la posición index
de la tupla operand
y produce un
result
Más formalmente, result = operand[index]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tuple | (C1) y (C2) |
(I2) | index |
constante de tipo si32 |
(C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
Cualquier tipo compatible | (C2) |
Limitaciones
- (C1)
0 <= index < size(operand)
- (C2)
type(result) = tuple_element_types(operand)[index]
Ejemplos
// %operand: ([1.0, 2.0], (3))
index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
si
Semántica
Produce el resultado de la ejecución de exactamente una función de true_branch
.
false_branch
, según el valor de pred
. Más formalmente, result =
pred ? true_branch() : false_branch()
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | pred |
Tensor de dimensión 0 de tipo i1 |
|
(I2) | true_branch |
función | (C1-C3) |
(I3) | false_branch |
función | (C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
número variádico de tensores, tensores cuantificados o tokens | (C3) |
Limitaciones
- (C1)
input_types(true_branch) = input_types(false_branch) = []
- (C2)
output_types(true_branch) = output_types(false_branch)
- (C3)
type(results...) = output_types(true_branch)
Ejemplos
// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
"stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
"stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10
imagen
Semántica
Extrae la parte imaginaria, a nivel de los elementos, del operand
y produce un
Tensor result
. De manera más formal, para cada elemento x
:
imag(x) = is_complex(x) ? imaginary_part(x) :
constant(0, element_type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo | (C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1) y (C2) |
Limitaciones
- (C1)
shape(result) = shape(operand)
- (C2)
element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
si esis_complex(operand)
.- De lo contrario,
element_type(operand)
.
Ejemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]
entrada
Semántica
Lee datos de la entrada y genera results
.
La semántica de infeed_config
está definida por la implementación.
results
consisten en valores de carga útil que van primero y un token que aparece
último. En el futuro, planeamos dividir la carga útil y el token en dos
salidas independientes para mejorar la claridad
(#670).
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | token |
token |
(I2) | infeed_config |
constante de tipo string |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
número variádico de tensores, tensores cuantificados o tokens | (C1-C3) |
Limitaciones
- (C1)
0 < size(results)
- (C2)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C3)
is_token(type(results[-1]))
Ejemplos
// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]
iota
Semántica
Completa un tensor output
con valores en orden ascendente a partir de cero
junto con la dimensión iota_dimension
. Más formalmente,
output[output_index] = constant(is_quantized(output) ?
quantize(output_index[iota_dimension], element_type(output)) :
output_index[iota_dimension], element_type(output))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | iota_dimension |
si64 |
C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
0 <= iota_dimension < rank(output)
Ejemplos
%output = "stablehlo.iota"() {
iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3]
// ]
%output = "stablehlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4],
// [0, 1, 2, 3, 4]
// ]
is_finite
Semántica
Realiza la verificación a nivel de elementos si el valor de x
es finito (es decir, no es
+Inf, -Inf o NaN) y produce un tensor y
. Implementa isFinite
.
de la especificación IEEE-754. Para los tipos cuantizados, el resultado es
siempre es true
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | x |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
y |
tensor de tipo booleano | C1 |
Limitaciones
- (C1)
shape(x) = shape(y)
Ejemplos
// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]
log
Semántica
Realiza la operación de logaritmo en relación con los elementos en el tensor operand
y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
log
de IEEE-754. - Para números complejos: logaritmo complejo.
- Para tipos cuantizados:
dequantize_op_quantize(log, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]
log_plus_one
Semántica
Realiza el logaritmo por elementos más una operación en el tensor operand
y
produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
logp1
de IEEE-754. - Para números complejos: logaritmo complejo más uno.
- Para tipos cuantizados:
dequantize_op_quantize(log_plus_one, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]
logística
Semántica
Realiza una operación logística en relación con los elementos en el tensor operand
y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
division(1, addition(1, exp(-x)))
de IEEE-754. - Para números complejos: logística compleja.
- Para tipos cuantizados:
dequantize_op_quantize(logistic, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]
mapa
Semántica
Aplica una función de asignación computation
a inputs
junto con dimensions
y
produce un tensor result
.
Más formalmente, result[result_index] = computation(inputs...[result_index])
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C4) |
(I2) | dimensions |
Constante de tensor unidimensional de tipo si64 |
(C3) |
(I3) | computation |
función | (C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) y (C4) |
Limitaciones
- (C1)
shape(inputs...) = shape(result)
- (C2)
0 < size(inputs) = N
- (C3)
dimensions = range(rank(inputs[0]))
- (C4)
computation
tiene el tipo(tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>
dondeEi = element_type(inputs[i])
yE' = element_type(result)
.
Ejemplos
// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
stablehlo.return %0 : tensor<i64>
}) {
dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]
máxima
Semántica
Realiza la operación máxima a nivel de elementos en los tensores lhs
y rhs
, y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: número máximo de números enteros.
- Para números de punto flotante:
maximum
de IEEE-754. - Para números complejos: máximo lexicográfico para el par
(real, imaginary)
. Imponer un orden en números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560). - Para tipos cuantizados:
dequantize_op_quantize(maximum, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | C1 |
(I2) | rhs |
tensor o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
Ejemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]
mínima
Semántica
Realiza una operación mínima a nivel de elementos en los tensores lhs
y rhs
, y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: lógico AND.
- Para números enteros: número entero mínimo.
- Para números de punto flotante:
minimum
de IEEE-754. - Para números complejos: mínimo lexicográfico para el par
(real, imaginary)
. Imponer un orden en números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con los números complejos para esta operación (#560). - Para tipos cuantizados:
dequantize_op_quantize(minimum, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | C1 |
(I2) | rhs |
tensor o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
Ejemplos
// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]
multiplicar
Semántica
Realiza el producto en elementos de dos tensores, lhs
y rhs
, y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: lógico AND.
- Para números enteros: multiplicación de números enteros
- Para números de punto flotante:
multiplication
de IEEE-754. - Para números complejos: multiplicación compleja.
- Para tipos cuantizados:
dequantize_op_quantize(multiply, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor o tensor cuantificado por tensor | C1 |
(I2) | rhs |
tensor o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]
negativo
Semántica
Realiza la negación a nivel de elementos del tensor operand
y produce un result
.
tensor. Según el tipo de elemento, hace lo siguiente:
- Para números enteros firmados: negación de números enteros
- Para números enteros sin firma: conversión de bits a número entero firmado, negación de números enteros, bitcast de vuelta a un número entero sin firma.
- Para números de punto flotante:
negate
de IEEE-754. - Para números complejos: negación compleja
- Para tipos cuantizados:
dequantize_op_quantize(negate, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]
// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]
no
Semántica
Realiza NO a nivel de elementos del tensor operand
y produce un tensor result
.
Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: lógico NOT.
- Para números enteros: NOT a nivel de bits.
Argumentos
Nombre | Tipo | Limitaciones |
---|---|---|
operand |
tensor de tipo booleano o de número entero | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo booleano o de número entero | C1 |
Limitaciones
- (C1)
type(operand) = type(result)
Ejemplos
// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]
// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]
optimization_barrier
Semántica
Garantiza que las operaciones que producen el operand
se ejecuten antes que cualquier
Operaciones que dependen de result
y evitan las transformaciones del compilador
de mover operaciones a través de la barrera. Aparte de eso, la operación es
una identidad, es decir, result = operand
Argumentos
Nombre | Tipo | Limitaciones |
---|---|---|
operand |
cantidad variable de tensores, tensores o tokens cuantificados por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
cantidad variable de tensores, tensores o tokens cuantificados por tensor | C1 |
Limitaciones
- (C1)
type(operand...) = type(result...)
Ejemplos
// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0
o
Semántica
Realiza OR a nivel de elementos de dos tensores, lhs
y rhs
, y produce un result
.
tensor. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: OR lógico.
- Para números enteros: OR bit a bit.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo de número entero o booleano | C1 |
(I2) | rhs |
tensor de tipo de número entero o booleano | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de número entero o booleano | C1 |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
Ejemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]
salida
Semántica
Escribe inputs
en la salida y produce un token result
.
La semántica de outfeed_config
está definida por la implementación.
Entradas
Etiqueta | Nombre | Tipo |
---|---|---|
(I1) | inputs |
número variádico de tensores o tensores cuantificados |
(I2) | token |
token |
(I3) | outfeed_config |
constante de tipo string |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Ejemplos
%result = "stablehlo.outfeed"(%input0, %token) {
outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
almohadilla
Semántica
Expande operand
con el relleno alrededor del tensor y entre los elementos.
del tensor con la padding_value
especificada.
edge_padding_low
y edge_padding_high
especifican la cantidad de padding agregado.
el extremo bajo (junto al índice 0) y el extremo alto (junto al índice más alto) de
para cada dimensión respectivamente. La cantidad de padding puede ser negativa, en la que el
El valor absoluto del padding negativo indica la cantidad de elementos que se deben quitar.
de la dimensión especificada.
interior_padding
especifica la cantidad de padding que se agrega entre dos propiedades.
elementos en cada dimensión que pueden no ser negativos. El interior tiene padding.
antes del relleno del borde, de modo que este quite los elementos
el operando con relleno interior.
De manera más formal, result[result_index]
se define de la siguiente manera:
operand[operand_index]
siresult_index = edge_padding_low + operand_index * (interior_padding + 1)
- De lo contrario,
padding_value
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C4) |
(I2) | padding_value |
Tensor de dimensión 0 o tensor cuantificado por tensor | C1 |
(I3) | edge_padding_low |
Constante de tensor unidimensional de tipo si64 |
(C1) y (C4) |
(I4) | edge_padding_high |
Constante de tensor unidimensional de tipo si64 |
(C1) y (C4) |
(I5) | interior_padding |
Constante de tensor unidimensional de tipo si64 |
(C2-C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C3-C6) |
Limitaciones
- (C1)
element_type(operand) = element_type(padding_value) = element_type(result)
- (C2)
size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand)
- (C3)
0 <= interior_padding
- (C4)
shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high
Ejemplos
// %operand: [
// [1, 2, 3],
// [4, 5, 6]
// ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
edge_padding_low = array<i64: 0, 1>,
edge_padding_high = array<i64: 2, 1>,
interior_padding = array<i64: 1, 2>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
// [0, 1, 0, 0, 2, 0, 0, 3, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 4, 0, 0, 5, 0, 0, 6, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0],
// [0, 0, 0, 0, 0, 0, 0, 0, 0]
// ]
partition_id
Semántica
Produce partition_id
del proceso actual.
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de dimensión 0 de tipo ui32 |
Ejemplos
%result = "stablehlo.partition_id"() : () -> tensor<ui32>
población
Semántica
Realiza el recuento a nivel de elementos de la cantidad de bits establecidos en el tensor operand
.
y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de número entero | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de número entero | C1 |
Limitaciones
- (C1)
type(operand) = type(result)
Ejemplos
// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]
energía
Semántica
Realiza la exponenciación a nivel de elementos del tensor lhs
mediante el tensor rhs
y
produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números enteros: exponenciación de números enteros
- Para números de punto flotante:
pow
de IEEE-754. - Para números complejos: exponenciación compleja.
- Para tipos cuantizados:
dequantize_op_quantize(power, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
(I2) | rhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]
real
Semántica
Extrae la parte real, en términos de elementos, de operand
y produce un result
.
tensor. De manera más formal, para cada elemento x
:
real(x) = is_complex(x) ? real_part(x) : x
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo | (C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1) y (C2) |
Limitaciones
- (C1)
shape(result) = shape(operand)
- (C2)
element_type(result)
se define de la siguiente manera:complex_element_type(element_type(operand))
si esis_complex(operand)
.- De lo contrario,
element_type(operand)
.
Ejemplos
// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]
recv
Semántica
Recibe datos de un canal con channel_id
y produce results
.
Si is_host_transfer
es true
, la operación transfiere datos desde el
host. De lo contrario, transfiere los datos desde otro dispositivo. Esto significa que
definido por la implementación. Esta marca duplica la información proporcionada en
channel_type
, así que en el futuro solo planeamos conservar uno de ellos
(#666).
results
consisten en valores de carga útil que van primero y un token que aparece
último. En el futuro, planeamos dividir la carga útil y el token en dos
salidas independientes para mejorar la claridad
(#670).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | token |
token |
(C4) |
(I2) | channel_id |
constante de tipo si64 |
|
(I3) | channel_type |
enum de DEVICE_TO_DEVICE y HOST_TO_DEVICE |
C1 |
(I4) | is_host_transfer |
constante de tipo i1 |
C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
número variádico de tensores, tensores cuantificados o tokens | (C2-C4) |
Limitaciones
- (C1)
channel_type
se define de la siguiente manera:HOST_TO_DEVICE
siis_host_transfer = true
,- De lo contrario,
DEVICE_TO_DEVICE
.
- (C2)
0 < size(results)
- (C3)
is_empty(result[:-1])
ois_tensor(type(results[:-1]))
. - (C4)
is_token(type(results[-1]))
Ejemplos
%results0, %results1 = "stablehlo.recv"(%token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
Reducir
Semántica
Aplica una función de reducción body
a inputs
y init_values
junto con
dimensions
y produce tensores results
.
El orden de las reducciones está definido por la implementación, lo que significa que body
y
init_values
debe formar un monoide para garantizar que la operación produzca el
los mismos resultados en todas las entradas de todas las implementaciones. Sin embargo, esta condición
no es suficiente para muchas
reducciones populares. P.ej., suma de punto flotante para
En realidad, body
y cero para init_values
no forman un monoide porque
la suma de punto flotante no es asociativa.
Más formalmente, results...[j0, ..., jR-1] = reduce(input_slices_converted)
, donde:
input_slices = inputs...[j0, ..., :, ..., jR-1]
, donde se insertan:
endimensions
.input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)
.init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)
.reduce(input_slices_converted) = exec(schedule)
para un árbol binarioschedule
donde:exec(node) = body(exec(node.left), exec(node.right))
.exec(leaf) = leaf.value
.
schedule
es un árbol binario completo definido por la implementación cuyo orden El recorrido consta de lo siguiente:- Valores
input_slices_converted...[index]
, para todos losindex
deindex_space(input_slices_converted)
en orden lexicográfico ascendente deindex
. - Intercalado con una cantidad definida por la implementación de
init_values_converted
en posiciones definidas por la implementación.
- Valores
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C4), (C6) y (C7) |
(I2) | init_values |
número variádico de tensores de 0 dimensiones o tensores cuantificados por tensor | (C2) y (C3) |
(I3) | dimensions |
Constante de tensor unidimensional de tipo si64 |
(C4), (C5), (C7) |
(I4) | body |
función | (C6) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C3), (C7), (C8) |
Limitaciones
- (C1)
same(shape(inputs...))
- (C2)
element_type(inputs...) = element_type(init_values...)
- (C3)
0 < size(inputs) = size(init_values) = size(results) = N
- (C4)
0 <= dimensions < rank(inputs[0])
- (C5)
is_unique(dimensions)
- (C6)
body
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, dondeis_promotable(element_type(inputs[i]), Ei)
- (C7)
shape(results...) = shape(inputs...)
, excepto que la dimensión No se incluyen los tamaños deinputs...
correspondientes adimensions
. - (C8)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]
reduce_precision
Semántica
Realiza la conversión en elementos de operand
a otro tipo de punto flotante.
que use exponent_bits
y mantissa_bits
, y vuelva a la versión original
de punto flotante y produce un tensor output
.
Más formalmente:
- Los fragmentos de la mantisa del valor original se actualizan para redondear el original
valor al valor más cercano representable con
mantissa_bits
mediante Semántica deroundToIntegralTiesToEven
. - Entonces, si
mantissa_bits
es menor que la cantidad de bits de la mantisa de el valor original, los bits de la mantisa se truncan amantissa_bits
. - Entonces, si los bits exponentes del resultado intermedio no caben en el
proporcionado por
exponent_bits
, el resultado intermedio se desborda infinito con el signo original o subdesbordamiento a cero con el letrero original. - Para los tipos cuantizados, realiza
dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
(I2) | exponent_bits |
constante de tipo si32 |
(C2) |
(I3) | mantissa_bits |
constante de tipo si32 |
(C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(output)
- (C2)
1 <= exponent_bits
- (C3)
0 <= mantissa_bits
Ejemplos
// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
exponent_bits = 5 : i32,
mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]
reduce_scatter
Semántica
Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, realiza la reducción,
con computations
, sobre los valores del tensor operand
de cada proceso,
divide el resultado de la reducción a lo largo de scatter_dimension
en partes y dispersa
las partes divididas entre los procesos para producir el result
.
La operación divide la cuadrícula de procesos StableHLO en process_groups
, que tiene el siguiente valor:
se define de la siguiente manera:
cross_replica(replica_groups)
si eschannel_id <= 0 and use_global_device_ids = false
.cross_replica_and_partition(replica_groups)
si eschannel_id > 0 and use_global_device_ids = false
.flattened_ids(replica_groups)
si eschannel_id > 0 and use_global_device_ids = true
.
Luego, dentro de cada process_group
, haz lo siguiente:
reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation)
.parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension)
.result@receiver = parts@sender[receiver_index]
para todos lossender
deprocess_group
, dondereceiver_index = process_group.index(receiver)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1), (C2), (C7) y (C8) |
(I2) | scatter_dimension |
constante de tipo si64 |
(C1), (C2) y (C8) |
(I3) | replica_groups |
Constante de tensor bidimensional de tipo si64 |
(C3-C5) |
(I4) | channel_id |
constante de tipo si64 |
(C6) |
(I5) | use_global_device_ids |
constante de tipo i1 |
(C6) |
(I6) | computation |
función | (C7) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C8-C9) |
Limitaciones
- (C1)
dim(operand, scatter_dimension) % dim(process_groups, 1) = 0
- (C2)
0 <= scatter_dimension < rank(operand)
- (C3)
is_unique(replica_groups)
- (C4)
size(replica_groups)
se define de la siguiente manera:num_replicas
si se usacross_replica
.num_replicas
si se usacross_replica_and_partition
.num_processes
si se usaflattened_ids
.
- (C5)
0 <= replica_groups < size(replica_groups)
- (C6) Si es
use_global_device_ids = true
, entonceschannel_id > 0
. - (C7)
computation
tiene el tipo(tensor<E>, tensor<E>) -> (tensor<E>)
, dondeis_promotable(element_type(operand), E)
- (C8)
shape(result) = shape(operand)
, excepto:dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)
.
- (C9)
element_type(result) = E
.
Ejemplos
// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
// [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
// [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
// [18, 20]]
// %result@(1, 0): [[14, 16],
// [22, 24]]
reduce_window
Semántica
Aplica una función de reducción body
a las ventanas de inputs
y init_values
y produce results
.
En el siguiente diagrama, se muestra cómo se calculan los elementos de results...
a partir de
inputs...
con un ejemplo concreto.
Más formalmente,
results...[result_index] = reduce(windows, init_values, axes(inputs...), body)
(consulta Reduce) donde:
padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1)
.window_start = result_index * window_strides
.window_end = window_start + (window_dimensions - 1) * window_dilations + 1
.windows = slice(padded_inputs..., window_start, window_end, window_dilations)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C4), (C6), (C8), (C10), (C12), (C13) y (C15) |
(I2) | init_values |
número variádico de tensores de 0 dimensiones o tensores cuantificados por tensor | (C1) y (C13) |
(I3) | window_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C4), (C5), (C15) |
(I4) | window_strides |
Constante de tensor unidimensional de tipo si64 |
(C6), (C7), (C15) |
(I5) | base_dilations |
Constante de tensor unidimensional de tipo si64 |
(C8), (C9), (C15) |
(I6) | window_dilations |
Constante de tensor unidimensional de tipo si64 |
(C10), (C11), (C15) |
(I7) | padding |
Constante de tensor bidimensional de tipo si64 |
(C12) y (C15) |
(I8) | body |
función | (C13) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1), (C14-C16) |
Limitaciones
- (C1)
0 < size(inputs) = size(init_values) = size(results) = N
- (C2)
same(shape(inputs...))
- (C3)
element_type(inputs...) = element_type(init_values...)
- (C4)
size(window_dimensions) = rank(inputs[0])
- (C5)
0 < window_dimensions
- (C6)
size(window_strides) = rank(inputs[0])
- (C7)
0 < window_strides
. - (C8)
size(base_dilations) = rank(inputs[0])
- (C9)
0 < base_dilations
. - (C10)
size(window_dilations) = rank(inputs[0])
. - (C11)
0 < window_dilations
- (C12)
shape(padding) = [rank(inputs[0]), 2]
. - (C13)
body
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,
tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, dondeis_promotable(element_type(inputs[i]), Ei)
- (C14)
same(shape(results...))
. - (C15)
shape(results[0]) = num_windows
donde:dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1
.padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]
.dilated_window_shape = (window_dimensions - 1) * window_dilations + 1
.is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape
.num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1
.
- (C16)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 2, 1>,
window_strides = array<i64: 4, 1>,
base_dilations = array<i64: 2, 1>,
window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]
resto
Semántica
Realiza el resto de los elementos de los tensores de dividendo lhs
y rhs
, y
produce un tensor result
.
De manera más formal, el signo del resultado se toma del dividendo, y el
el valor absoluto del resultado siempre es menor que el valor absoluto del divisor.
El resto se calcula como lhs - d * rhs
, en el que d
se obtiene de la siguiente manera:
- Para números enteros:
stablehlo.divide(lhs, rhs)
. - Para números de punto flotante:
division(lhs, rhs)
de IEEE-754 con atributo de redondeoroundTowardZero
- Para números complejos: Por definir (#997).
- Para tipos cuantizados:
dequantize_op_quantize(remainder, lhs, rhs, type(result))
.
Para los tipos de elementos de punto flotante, esta operación contrasta con el
Operación remainder
de la especificación IEEE-754 en la que d
es un valor integral
que se aproxima al valor exacto de lhs/rhs
, empatizado con el par.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
(I2) | rhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]
replica_id
Semántica
Produce replica_id
del proceso actual.
Salidas
Nombre | Tipo |
---|---|
result |
Tensor de dimensión 0 de tipo ui32 |
Ejemplos
%result = "stablehlo.replica_id"() : () -> tensor<ui32>
redimensionar
Semántica
Realiza la reforma del tensor operand
en un tensor result
. Conceptualmente,
implica mantener la misma representación canónica, pero
la forma, p.ej., de tensor<2x3xf32>
a tensor<3x2xf32>
o tensor<6xf32>
.
Más formalmente, result[result_index] = operand[operand_index]
donde
result_index
y operand_index
tienen la misma posición en el lexicográfico
ordenado de index_space(result)
y index_space(operand)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1-C3) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si es!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
y De lo contrario,quantization_dimension(result)
podría variar.
- (C2)
size(operand) = size(result)
- (C3) Si es
is_per_axis_quantized(operand)
:reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result))
.reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y)
.
Ejemplos
// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]
revertir
Semántica
Invierte el orden de los elementos en operand
a lo largo del dimensions
especificado
y produce un tensor result
. Más formalmente,
result[result_index] = operand[operand_index]
donde:
operand_index[d] = dim(result, d) - result_index[d] - 1
sid
endimensions
.- De lo contrario,
operand_index[d] = result_index[d]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1) y (C3) |
(I2) | dimensions |
Constante de tensor unidimensional de tipo si64 |
(C2) y (C3) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1) y (C3) |
Limitaciones
- (C1)
type(operand) = type(result)
- (C2)
is_unique(dimensions)
- (C3)
0 <= dimensions < rank(result)
Ejemplos
// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
dimensions = array<i64: 1>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]
rng
Semántica
Genera números aleatorios con el algoritmo rng_distribution
y produce un
Tensor result
de una forma determinada shape
.
Si es rng_distribution = UNIFORM
, se generan números al azar.
siguiendo la distribución uniforme durante el intervalo [a, b)
Si es a >= b
,
el comportamiento es indefinido.
Si es rng_distribution = NORMAL
, se generan números al azar.
Se sigue la distribución normal con una media = a
y una desviación estándar = b
.
Si es b < 0
, el comportamiento no está definido.
La implementación define la forma exacta en que se generan los números aleatorios. Para ejemplo, pueden o no ser deterministas, y pueden o no usar estado oculto.
En conversaciones con muchos interesados, esta op resultó tan eficaz dado que es obsoleto, por lo que en el futuro planeamos quitarlo (#597).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | a |
Tensor 0 dimensiones de número entero, booleano o tipo de punto flotante | (C1) y (C2) |
(I2) | b |
Tensor 0 dimensiones de número entero, booleano o tipo de punto flotante | (C1) y (C2) |
(I3) | shape |
Constante de tensor unidimensional de tipo si64 |
(C3) |
(I4) | rng_distribution |
enum de UNIFORM y NORMAL |
(C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero, booleano o de punto flotante | (C1-C3) |
Limitaciones
- (C1)
element_type(a) = element_type(b) = element_type(result)
- (C2) Si es
rng_distribution = NORMAL
, entoncesis_float(a)
. - (C3)
shape(result) = shape
Ejemplos
// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
// [1, 0, 1],
// [1, 1, 1],
// [0, 0, 0]
// ]
rng_bit_generator
Semántica
Devuelve un output
lleno de bits aleatorios uniformes y un estado de salida actualizado.
output_state
con el algoritmo de generación de números pseudoaleatorios rng_algorithm
dado un estado inicial initial_state
. Se garantiza que el resultado
función determinista de initial_state
, pero no se garantiza que
son deterministas entre las implementaciones.
rng_algorithm
es una de las siguientes opciones:
DEFAULT
: Es el algoritmo definido por la implementación.THREE_FRY
: Variante definida por la implementación del algoritmo de Threefry.*PHILOX
: Variante definida por la implementación del algoritmo Philox.*
* Consulta: Salmon et al. SC 2011. Números aleatorios paralelos: tan fáciles como 1, 2, 3
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | rng_algorithm |
enum de DEFAULT , THREE_FRY y PHILOX |
(C2) |
(I2) | initial_state |
Tensor unidimensional de tipo ui64 |
(C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
output_state |
Tensor unidimensional de tipo ui64 |
C1 |
output |
tensor de número entero o de tipo de punto flotante |
Limitaciones
- (C1)
type(initial_state) = type(output_state)
- (C2)
size(initial_state)
se define de la siguiente manera:- define la implementación si es
rng_algorithm = DEFAULT
. 2
si esrng_algorithm = THREE_FRY
.2
o3
si esrng_algorithm = PHILOX
.
- define la implementación si es
Ejemplos
// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
// [9236835810183407956, 16087790271692313299],
// [18212823393184779219, 2658481902456610144]
// ]
round_nearest_afz
Semántica
Realiza el redondeo de elementos hacia el número entero más cercano y rompe los empates.
desde cero, en el tensor operand
, y produce un tensor result
. Implementa
la operación roundToIntegralTiesToAway
de la especificación IEEE-754. Para
tipos cuantizados, realiza
dequantize_op_quantize(round_nearest_afz, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]
round_nearest_even
Semántica
Realiza el redondeo de elementos hacia el número entero más cercano y rompe los empates.
hacia el número entero par, en el tensor operand
, y produce un result
tensor. Implementa la operación roundToIntegralTiesToEven
del estándar IEEE-754.
especificación. Para los tipos cuantizados, realiza
dequantize_op_quantize(round_nearest_even, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
rsqrt
Semántica
Realiza la operación de raíz cuadrada recíproca a nivel de los elementos en el tensor operand
y
produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
rSqrt
de IEEE-754. - Para números complejos: raíz cuadrada recíproca compleja.
- Para tipos cuantizados:
dequantize_op_quantize(rsqrt, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]
dispersión
Semántica
Produce tensores results
que son iguales a inputs
, excepto que
varias porciones especificadas por scatter_indices
se actualizan con los valores
updates
usando update_computation
.
En el siguiente diagrama, se muestra cómo los elementos de updates...
se asignan a los elementos de
results...
con un ejemplo concreto. En el diagrama se eligen algunos ejemplos
updates...
indexa y explica en detalle qué results...
indexa
a los que corresponden.
De manera más formal, para todos los update_index
en index_space(updates[0])
:
update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims]
.update_scatter_index = update_index[update_scatter_dims...]
.start_index
se define de la siguiente manera:scatter_indices[si0, ..., :, ..., siN]
en el quesi
son individuales los elementos enupdate_scatter_index
y:
se insertan en Índiceindex_vector_dim
, siindex_vector_dim
<rank(scatter_indices)
- De lo contrario,
[scatter_indices[update_scatter_index]]
.
- Para
d_input
enaxes(inputs[0])
,full_start_index[d_input] = start_index[d_start]
sid_input = scatter_dims_to_operand_dims[d_start]
- De lo contrario,
full_start_index[d_input] = 0
.
- Para
d_input
enaxes(inputs[0])
,full_batching_index[d_input] = update_scatter_index[d_start - (d_start < index_vector_dim ? 0 : 1)]
sid_input = input_batching_dims[i_batching]
yd_start = scatter_indices_batching_dims[i_batching]
- De lo contrario,
full_batching_index[d_input] = 0
.
update_window_index = update_index[update_window_dims...]
.full_window_index = [wi0, ..., 0, ..., wiN]
en el quewi
son individuales elementos enupdate_window_index
y0
se inserta en los índices deinserted_window_dims
yinput_batching_dims
.result_index = full_start_index + full_batching_index + full_window_index
.
Dado eso, results = exec(schedule, inputs)
, donde:
schedule
es una permutación definida por la implementación deindex_space(updates[0])
exec([update_index, ...], results) = exec([...], updated_results)
donde:- Si
result_index
está dentro de los límites deshape(results...)
updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
updated_values = update_computation(results...[result_index], updates_converted)
updated_results
es una copia deresults
conresults...[result_index]
se define enupdated_values...
.- De lo contrario
updated_results = results
.
- Si
exec([], results) = results
.
Si indices_are_sorted
es true
, la implementación puede suponer que
scatter_indices
se ordenan con respecto a scatter_dims_to_operand_dims
,
de lo contrario, el comportamiento será indefinido. Más formalmente, para todos los i1 < i2
de
indices(result)
, full_start_index(i1)
<= full_start_index(i2)
.
Si unique_indices
es true
, la implementación puede suponer que todo
Los índices result_index
que se dispersan son únicos. Si unique_indices
es
true
pero los índices que se dispersan no son únicos, por lo tanto, el comportamiento es
indefinido.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1), (C2), (C4-C6), (C11), (C13), (C18), (C21), (C23-C24) |
(I2) | scatter_indices |
tensor de tipo de número entero | (C4), (C15), (C19), (C22) |
(I3) | updates |
cantidad variádica de tensores o tensores cuantificados por tensor | (C3-C6), (C8) |
(I4) | update_window_dims |
Constante de tensor unidimensional de tipo si64 |
(C2), (C4), (C7-C8) |
(I5) | inserted_window_dims |
Constante de tensor unidimensional de tipo si64 |
(C2), (C4), (C9-C11) |
(I6) | input_batching_dims |
Constante de tensor unidimensional de tipo si64 |
(C2), (C4), (C9), (C12-13), (C17-18), (C20) |
(I7) | scatter_indices_batching_dims |
Constante de tensor unidimensional de tipo si64 |
(C14-C18) |
(I8) | scatter_dims_to_operand_dims |
Constante de tensor unidimensional de tipo si64 |
(C19-C21) |
(I9) | index_vector_dim |
constante de tipo si64 |
(C4), (C16), (C19), (C22) |
(I10) | indices_are_sorted |
constante de tipo i1 |
|
(I11) | unique_indices |
constante de tipo i1 |
|
(I12) | update_computation |
función | (C23) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C24-C25) |
Limitaciones
- (C1)
same(shape(inputs...))
- (C2) `rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims)
- size(input_batching_dims)`.
- (C3)
same(shape(updates...))
- (C4)
shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes)
donde:update_scatter_dim_sizes = shape(scatter_indices)
, excepto que el tamaño de dimensión descatter_indices
correspondiente a No se incluyeindex_vector_dim
.update_window_dim_sizes <= shape(inputs[0])
, excepto que los tamaños de las dimensiones eninputs[0]
correspondientes ainserted_window_dims
niinput_batching_dims
.combine
colocaupdate_scatter_dim_sizes
en los ejes correspondientes aupdate_scatter_dims
yupdate_window_dim_sizes
en los ejes correspondientes aupdate_window_dims
.
- (C5)
0 < size(inputs) = size(updates) = N
- (C6)
element_type(updates...) = element_type(inputs...)
- (C7)
is_unique(update_window_dims) and is_sorted(update_window_dims)
. - (C8)
0 <= update_window_dims < rank(updates[0])
is_unique(concatenate(inserted_window_dims, input_batching_dims))
(C9)- (C10)
is_sorted(inserted_window_dims)
. - (C11)
0 <= inserted_window_dims < rank(inputs[0])
- (C12)
is_sorted(input_batching_dims)
. - (C13)
0 <= input_batching_dims < rank(inputs[0]))
. - (C14)
is_unique(scatter_indices_batching_dims)
. - (C15)
0 <= scatter_indices_batching_dims < rank(scatter_indices)
. - (C16)
index_vector_dim not in scatter_indices_batching_dims
. - (C17)
size(input_batching_dims) == size(scatter_indices_batching_dims)
. - (C18)
dim(inputs[0], input_batching_dims...) = dim(scatter_indices, scatter_indices_batching_dims...)
. - (C19)
size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1
. - (C20)
is_unique(concatenate(scatter_dims_to_operand_dims, input_batching_dims))
. - (C21)
0 <= scatter_dims_to_operand_dims < rank(inputs[0])
. - (C22)
0 <= index_vector_dim <= rank(scatter_indices)
. - (C23)
update_computation
tiene el tipo(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)
, dondeis_promotable(element_type(inputs[i]), Ei)
. - (C24)
shape(inputs...) = shape(results...)
. - (C25)
element_type(results[i]) = Ei
para todos losi
en[0,N)
.
Ejemplos
// %input: [
// [
// [[1, 2], [3, 4], [5, 6], [7, 8]],
// [[9, 10],[11, 12], [13, 14], [15, 16]],
// [[17, 18], [19, 20], [21, 22], [23, 24]]
// ],
// [
// [[25, 26], [27, 28], [29, 30], [31, 32]],
// [[33, 34], [35, 36], [37, 38], [39, 40]],
// [[41, 42], [43, 44], [45, 46], [47, 48]]
// ]
// ]
// %scatter_indices: [
// [
// [[0, 0], [1, 0], [2, 1]],
// [[0, 1], [1, 1], [0, 9]]
// ],
// [
// [[0, 0], [2, 1], [2, 2]],
// [[1, 2], [0, 1], [1, 0]]
// ]
// ]
// %update: [
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ],
// [
// [[1, 1], [1, 1], [1, 1]],
// [[1, 1], [1, 1], [1, 1]]
// ]
// ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
scatter_dimension_numbers = #stablehlo.scatter<
update_window_dims = [3, 4],
inserted_window_dims = [1],
input_batching_dims = [0],
scatter_indices_batching_dims = [1],
scatter_dims_to_operand_dims = [2, 1],
index_vector_dim = 3>,
indices_are_sorted = false,
unique_indices = false
} : (tensor<2x3x4x2xi64>, tensor<2x2x3x2xi64>, tensor<2x2x3x2x2xi64>) -> tensor<2x3x4x2xi64>
// %result: [
// [
// [[3, 4], [6, 7], [6, 7], [7, 8]],
// [[9, 10],[11, 12], [15, 16], [17, 18]],
// [[17, 18], [19, 20], [22, 23], [24, 25]]
// ],
// [
// [[25, 26], [28, 29], [30, 31], [31, 32]],
// [[35, 36], [38, 39], [38, 39], [39, 40]],
// [[41, 42], [44, 45], [46, 47], [47, 48]]
// ]
// ]
select
Semántica
Produce un tensor result
en el que cada elemento se selecciona desde on_true
.
El tensor on_false
basado en el valor del elemento correspondiente de pred
.
Más formalmente, result[result_index] = pred_element ? on_true[result_index] :
on_false[result_index]
, donde pred_element = rank(pred) = 0 ? pred[] :
pred[result_index]
. Para los tipos cuantizados, realiza
dequantize_select_quantize(pred, on_true, on_false, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | pred |
tensor de tipo i1 |
C1 |
(I2) | on_true |
tensor o tensor cuantificado por tensor | (C1-C2) |
(I3) | on_false |
tensor o tensor cuantificado por tensor | (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C2) |
Limitaciones
- (C1)
rank(pred) = 0 or shape(pred) = shape(on_true)
- (C2)
baseline_type(on_true) = baseline_type(on_false) = baseline_type(result)
Ejemplos
// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]
select_and_scatter
Semántica
Dispersa los valores del tensor source
con scatter
basado en la
resultado de reduce_window
del tensor input
con select
y produce
Un tensor result
.
En el siguiente diagrama, se muestra cómo se calculan los elementos de result
a partir de
operand
y source
con un ejemplo concreto.
Más formalmente:
selected_values = reduce_window_without_init(...)
con las siguientes entradas:inputs = [operand].
window_dimensions
,window_strides
ypadding
, que se usan tal cual.base_dilations = windows_dilations = 1
.body
se define de la siguiente manera:
def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>: return select(arg0, arg1) ? arg0 : arg1;
donde
E = element_type(operand)
yreduce_window_without_init
funcionan exactamente comoreduce_window
, excepto que laschedule
de lareduce
(consulta reduce) no incluye valores init. Actualmente es sin especificar qué sucede si la ventana correspondiente no tiene valores (#731).result[result_index] = reduce([source_values], [init_value], [0], scatter)
En el ejemplo anterior, se ilustra lo siguiente:source_values = [source[source_index] for source_index in source_indices]
.selected_index(source_index) = operand_index
siselected_values[source_index]
tiene el elementooperand
desdeoperand_index
.source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1-C4), (C6), (C8-C11) |
(I2) | source |
tensor o tensor cuantificado por tensor | (C1) y (C2) |
(I3) | init_value |
Tensor de dimensión 0 o tensor cuantificado por tensor | (C3) |
(I4) | window_dimensions |
Constante de tensor unidimensional de tipo si64 |
(C2), (C4), (C5) |
(I5) | window_strides |
Constante de tensor unidimensional de tipo si64 |
(C2), (C6), (C7) |
(I6) | padding |
Constante de tensor bidimensional de tipo si64 |
(C2) y (C8) |
(I7) | select |
función | (C9) |
(I8) | scatter |
función | C10 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C11-C12) |
Limitaciones
- (C1)
element_type(operand) = element_type(source)
- (C2)
shape(source) = num_windows
donde:padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1]
.is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape
.num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1
.
- (C3)
element_type(init_value) = element_type(operand)
- (C4)
size(window_dimensions) = rank(operand)
- (C5)
0 < window_dimensions
- (C6)
size(window_strides) = rank(operand)
- (C7)
0 < window_strides
. - (C8)
shape(padding) = [rank(operand), 2]
- (C9)
select
tiene el tipo(tensor<E>, tensor<E>) -> tensor<i1>
, dondeE = element_type(operand)
- (C10)
scatter
tiene el tipo(tensor<E>, tensor<E>) -> tensor<E>
, dondeis_promotable(element_type(operand), E)
- (C11)
shape(operand) = shape(result)
- (C12)
element_type(result) = E
.
Ejemplos
// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GE>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
window_dimensions = array<i64: 3, 1>,
window_strides = array<i64: 2, 1>,
padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]
enviar
Semántica
Envía inputs
a un canal channel_id
y produce un token result
.
Si is_host_transfer
es true
, la operación transfiere los datos al
host. De lo contrario, transfiere los datos a otro dispositivo. Esto significa que
definido por la implementación. Esta marca duplica la información proporcionada en
channel_type
, así que en el futuro solo planeamos conservar uno de ellos
(#666).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
número variádico de tensores o tensores cuantificados | |
(I2) | token |
token |
|
(I3) | channel_id |
constante de tipo si64 |
|
(I4) | channel_type |
enum de DEVICE_TO_DEVICE y DEVICE_TO_HOST |
C1 |
(I5) | is_host_transfer |
constante de tipo i1 |
C1 |
Salidas
Nombre | Tipo |
---|---|
result |
token |
Limitaciones
- (C1)
channel_type
se define de la siguiente manera:DEVICE_TO_HOST
siis_host_transfer = true
,- De lo contrario,
DEVICE_TO_DEVICE
.
Ejemplos
%result = "stablehlo.send"(%operand, %token) {
channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token
shift_left
Semántica
Realiza la operación de desplazamiento a la izquierda en el tensor lhs
con un número de rhs
.
de bits y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo de número entero | C1 |
(I2) | rhs |
tensor de tipo de número entero | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de número entero | C1 |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
Ejemplos
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
shift_right_arithmetic
Semántica
Realiza una operación aritmética de desplazamiento hacia la derecha en relación con los elementos en el tensor lhs
. Para ello, haz lo siguiente:
rhs
de bits y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo de número entero | C1 |
(I2) | rhs |
tensor de tipo de número entero | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de número entero | C1 |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
Ejemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]
shift_right_logical
Semántica
Realiza la operación de cambio lógico a la derecha de los elementos en el tensor lhs
con rhs
.
cantidad de bits y produce un tensor result
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo de número entero | C1 |
(I2) | rhs |
tensor de tipo de número entero | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de número entero | C1 |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
Ejemplos
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
firmar
Semántica
Muestra el signo del operand
a nivel de elementos y produce un tensor result
.
De manera más formal, para cada elemento x
, la semántica se puede expresar mediante
Sintaxis de Python de la siguiente manera:
def sign(x):
if is_integer(x):
if compare(x, 0, LT, SIGNED): return -1
if compare(x, 0, EQ, SIGNED): return 0
return 1
elif is_float(x):
if is_nan(x): return NaN
if compare(x, -0.0, EQ, FLOAT): return -0.0
if compare(x, +0.0, EQ, FLOAT): return +0.0
if compare(x, 0.0, LT, FLOAT): return -1.0
return 1.0
elif is_complex(x):
if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
return divide(x, convert(abs(x), type(x)))
Para los tipos cuantizados, realiza
dequantize_op_quantize(sign, operand, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de número entero con firma, de punto flotante o de tipo complejo, o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero con firma, de punto flotante o de tipo complejo, o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
seno
Semántica
Realiza una operación de seno en elementos en el tensor operand
y produce un result
.
tensor. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
sin
de IEEE-754. - Para números complejos: seno complejo.
- Para tipos cuantizados:
dequantize_op_quantize(sine, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]
porción
Semántica
Extrae una porción de operand
mediante índices de inicio calculados de forma estática
y produce un tensor result
. start_indices
contienen los índices iniciales de
la porción de cada dimensión, limit_indices
contiene los índices finales.
(exclusivo) para la porción de cada dimensión, y strides
contiene las zancadas.
para cada dimensión.
Más formalmente, result[result_index] = operand[operand_index]
donde
operand_index = start_indices + result_index * strides
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado por tensor | (C1-C3), (C5) |
(I2) | start_indices |
Constante de tensor unidimensional de tipo si64 |
(C2), (C3) y (C5) |
(I3) | limit_indices |
Constante de tensor unidimensional de tipo si64 |
(C2), (C3) y (C5) |
(I4) | strides |
Constante de tensor unidimensional de tipo si64 |
(C2) y (C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado por tensor | (C1), (C5) |
Limitaciones
- (C1)
element_type(operand) = element_type(result)
- (C2)
size(start_indices) = size(limit_indices) = size(strides) = rank(operand)
- (C3)
0 <= start_indices <= limit_indices <= shape(operand)
- (C4)
0 < strides
- (C5)
shape(result) = ceil((limit_indices - start_indices) / strides)
Ejemplos
// %operand: [
// [0, 0, 0, 0],
// [0, 0, 1, 1],
// [0, 0, 1, 1]
// ]
%result = "stablehlo.slice"(%operand) {
start_indices = array<i64: 1, 2>,
limit_indices = array<i64: 3, 4>,
strides = array<i64: 1, 1>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
// [1, 1],
// [1, 1]
// ]
ordenar
Semántica
Ordena juntas porciones de 1 dimensión de inputs
a lo largo de la dimensión dimension
.
de acuerdo con un comparator
y produce results
.
A diferencia de entradas similares en otras operaciones, dimension
permite valores negativos.
con la semántica que se describe a continuación. Es posible que esto no esté permitido en el futuro
por motivos de coherencia
(#1377).
Si is_stable
es verdadero, el orden es estable, es decir, el orden relativo de
elementos considerados iguales según el comparador. Para el caso
donde hay una sola entrada, dos elementos e1
y e2
se consideran
iguales por el comparador si y solo si
comparator(e1, e2) = comparator(e2, e1) = false
Consulta la formalización a continuación
sobre cómo esto se generaliza a varias entradas.
De manera más formal, para todos los result_index
en index_space(results[0])
:
adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension
.result_slice = [ri0, ..., :, ..., riR-1]
en el queriN
son individuales elementos enresult_index
, y:
se inserta enadjusted_dimension
.inputs_together = (inputs[0]..., ..., inputs[N-1]...)
.results_together[result_slice] = sort(inputs_together[result_slice], comparator_together)
.- donde
sort
ordena una porción de 1 dimensión en orden no descendente y se espera quecomparator_together
muestratrue
si el argumento del lado izquierdo es menor que el segundo argumento de la derecha. def comparator_together(lhs_together, rhs_together): args = [] for (lhs_el, rhs_el) in zip(lhs_together, rhs_together): args.append(lhs_el) args.append(rhs_el) return comparator(*args)
(results[0]..., ..., results[N-1]...) = results_together
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | inputs |
cantidad variádica de tensores o tensores cuantificados por tensor | (C1-C5) |
(I2) | dimension |
constante de tipo si64 |
(C4) |
(I3) | is_stable |
constante de tipo i1 |
|
(I4) | comparator |
función | C5 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
cantidad variádica de tensores o tensores cuantificados por tensor | (C2) y (C3) |
Limitaciones
- (C1)
0 < size(inputs)
- (C2)
type(inputs...) = type(results...)
- (C3)
same(shape(inputs...) + shape(results...))
- (C4)
-R <= dimension < R
, dondeR = rank(inputs[0])
. - (C5)
comparator
tiene el tipo(tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>
, dondeEi = element_type(inputs[i])
.
Ejemplos
// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
%predicate = "stablehlo.compare"(%arg0, %arg1) {
comparison_direction = #stablehlo<comparison_direction GT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
dimension = 0 : i64,
is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]
rq
Semántica
Realiza la operación de raíz cuadrada de elementos en el tensor operand
y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
squareRoot
de IEEE-754. - Para números complejos: raíz cuadrada compleja
- Para tipos cuantizados:
dequantize_op_quantize(sqrt, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]
subtract
Semántica
Realiza la resta a nivel de elementos de dos tensores, lhs
y rhs
, y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para los números enteros: resta de números enteros.
- Para números de punto flotante:
subtraction
de IEEE-754. - Para números complejos: resta compleja.
- Para tipos cuantizados:
dequantize_op_quantize(subtract, lhs, rhs, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
(I2) | rhs |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de número entero, de punto flotante o complejo, o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(lhs) = baseline_type(rhs) = baseline_type(result)
Ejemplos
// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]
tan
Semántica
Realiza la operación tangente en cuanto a elementos en el tensor operand
y produce un
Tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
tan
de IEEE-754. - Para números complejos: tangente compleja.
- Para tipos cuantizados:
dequantize_op_quantize(tan, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [
// [0.0, 1.57079632], // [0, pi/2]
// [3.14159265, 4.71238898] // [pi, 3pi/2]
// ]
%result = "stablehlo.tan"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [
// [0.0, 1.63312e+16],
// [0.0, 5.44375e+15]
// ]
tanh
Semántica
Realiza la operación tangente hiperbólica a nivel de elementos en el tensor operand
y
produce un tensor result
. Según el tipo de elemento, hace lo siguiente:
- Para números de punto flotante:
tanh
de IEEE-754. - Para números complejos: tangente hiperbólica compleja
- Para tipos cuantizados:
dequantize_op_quantize(tanh, operand, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_type(operand) = baseline_type(result)
Ejemplos
// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]
transponer
Semántica
Permuta las dimensiones del tensor operand
con permutation
y produce un
Tensor result
. Más formalmente, result[result_index] = operand[operand_index]
donde result_index[d] = operand_index[permutation[d]]
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor o tensor cuantificado | (C1-C4) |
(I2) | permutation |
Constante de tensor unidimensional de tipo si64 |
(C2-C4) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor o tensor cuantificado | (C1), (C3-C4) |
Limitaciones
- (C1)
element_type(result)
se obtiene de la siguiente manera:element_type(operand)
, si es!is_per_axis_quantized(operand)
.element_type(operand)
, excepto quequantization_dimension(operand)
y De lo contrario,quantization_dimension(result)
podría variar.
- (C2)
permutation
es una permutación derange(rank(operand))
. - (C3)
shape(result) = dim(operand, permutation...)
- (C4) Si es
is_per_axis_quantized(result)
, entoncesquantization_dimension(operand) = permutation(quantization_dimension(result))
Ejemplos
// %operand: [
// [[1,2], [3,4], [5,6]],
// [[7,8], [9,10], [11,12]]
// ]
%result = "stablehlo.transpose"(%operand) {
permutation = array<i64: 2, 1, 0>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
// [[1,7], [3,9], [5,11]],
// [[2,8], [4,10], [6,12]]
// ]
triangular_solve
Semántica
Resolver lotes de sistemas de ecuaciones lineales con un triángulo triangular inferior o superior matrices de coeficientes.
Más formalmente, con a
y b
, result[i0, ..., iR-3, :, :]
es la solución.
a op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :]
cuando left_side
es
true
o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :]
cuando
left_side
es false
, y resuelve la variable x
en la que se determina op(a)
.
antes del transpose_a
, que puede ser una de las siguientes opciones:
NO_TRANSPOSE
: Realiza la operación cona
tal como está.TRANSPOSE
: Realiza una operación en la transposición dea
.ADJOINT
: Realiza una operación en la transposición conjugada dea
.
Los datos de entrada solo se leen desde el triángulo inferior de a
si lower
es true
o
de lo contrario, el triángulo superior es a
. Los datos de salida se devuelven en el mismo triángulo.
Los valores del otro triángulo están definidos por la implementación.
Si unit_diagonal
es verdadero, la implementación puede suponer que la línea diagonal
elementos de a
son iguales a 1; de lo contrario, el comportamiento no estará definido.
Para los tipos cuantizados, realiza
dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower,
unit_diagonal, transpose_a), a, b, type(result))
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | a |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | (C1-C3) |
(I2) | b |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | (C1-C4) |
(I3) | left_side |
constante de tipo i1 |
(C3) |
(I4) | lower |
constante de tipo i1 |
|
(I5) | unit_diagonal |
constante de tipo i1 |
|
(I6) | transpose_a |
enum de NO_TRANSPOSE , TRANSPOSE y ADJOINT |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de punto flotante o tipo complejo o tensor cuantificado por tensor | C1 |
Limitaciones
- (C1)
baseline_element_type(a) = baseline_element_type(b)
- (C2)
2 <= rank(a) = rank(b) = R
- (C3) La relación entre
shape(a)
yshape(b)
se define de la siguiente manera:shape(a)[:-3] = shape(b)[:-3]
.dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1)
.
- (C4)
baseline_type(b) = baseline_type(result)
Ejemplos
// %a = [
// [1.0, 0.0, 0.0],
// [2.0, 4.0, 0.0],
// [3.0, 5.0, 6.0]
// ]
// %b = [
// [2.0, 0.0, 0.0],
// [4.0, 8.0, 0.0],
// [6.0, 10.0, 12.0]
// ]
%result = "stablehlo.triangular_solve"(%a, %b) {
left_side = true,
lower = true,
unit_diagonal = false,
transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
// [2.0, 0.0, 0.0],
// [0.0, 2.0, 0.0],
// [0.0, 0.0, 2.0]
// ]
tuple
Semántica
Produce una tupla result
a partir de los valores val
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | val |
cantidad variádica de valores | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tuple | C1 |
Limitaciones
- (C1)
result
tiene el tipotuple<E0, ..., EN-1>
, en el queEi = type(val[i])
.
Ejemplos
// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))
uniform_dequantize
Semántica
Realiza la conversión a nivel de elementos del tensor cuantificado operand
en un
Tensor de punto flotante result
según los parámetros de cuantización definidos
por el tipo operand
.
Más formalmente, result = dequantize(operand)
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor cuantificado | (C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo de punto flotante | (C1) y (C2) |
Limitaciones
- (C1)
shape(operand) = shape(result)
- (C2)
element_type(result) = expressed_type(operand)
Ejemplos
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
uniform_quantize
Semántica
Realiza la conversión a nivel de elementos de un tensor de punto flotante o un tensor cuantificado.
operand
a un tensor cuantificado result
según la cuantización
parámetros definidos por el tipo result
.
Más formalmente,
- Si es
is_float(operand)
:result = quantize(operand, type(result))
.
- Si es
is_quantized(operand)
:float_result = dequantize(operand)
.result = quantize(float_result, type(result))
.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
tensor de punto flotante o tipo cuantificado | (C1) y (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor cuantificado | (C1) y (C2) |
Limitaciones
- (C1)
shape(operand) = shape(result)
- (C2)
expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand)
Ejemplos
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]
// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
mientras
Semántica
Produce el resultado de la ejecución de la función body
0 o más veces, mientras que
La función cond
da como resultado true
. De manera más formal, la semántica puede expresarse
con la sintaxis de Python de la siguiente manera:
internal_state = operand
while cond(*internal_state):
internal_state = body(*internal_state)
results = internal_state
Se está por definir el comportamiento de un bucle infinito (#383).
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | operand |
número variádico de tensores, tensores cuantificados o tokens | (C1-C3) |
(I2) | cond |
función | C1 |
(I3) | body |
función | (C2) |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
results |
número variádico de tensores, tensores cuantificados o tokens | (C3) |
Limitaciones
- (C1)
cond
tiene el tipo(T0, ..., TN-1) -> tensor<i1>
, en el queTi = type(operand[i])
- (C2)
body
tiene el tipo(T0, ..., TN-1) -> (T0, ..., TN-1)
, dondeTi = type(operand[i])
- (C3)
type(results...) = type(operand...)
Ejemplos
// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%cond = "stablehlo.compare"(%arg0, %ten) {
comparison_direction = #stablehlo<comparison_direction LT>
} : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %cond : tensor<i1>
}, {
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%new_sum = stablehlo.add %arg1, %one : tensor<i64>
%new_i = stablehlo.add %arg0, %one : tensor<i64>
stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10
xor
Semántica
Realiza XOR a nivel de elementos de dos tensores, lhs
y rhs
, y produce un result
.
tensor. Según el tipo de elemento, hace lo siguiente:
- Para valores booleanos: XOR lógico.
- Para números enteros: XOR a nivel de bits.
Entradas
Etiqueta | Nombre | Tipo | Limitaciones |
---|---|---|---|
(I1) | lhs |
tensor de tipo booleano o de número entero | C1 |
(I2) | rhs |
tensor de tipo booleano o de número entero | C1 |
Salidas
Nombre | Tipo | Limitaciones |
---|---|---|
result |
tensor de tipo booleano o de número entero | C1 |
Limitaciones
- (C1)
type(lhs) = type(rhs) = type(result)
Ejemplos
// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]
// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]
Interoperabilidad de dialectos
Por el momento, los programas StableHLO en el entorno a veces contienen operaciones no están definidas por StableHLO.
Módulo, función, llamada y devolución
StableHLO usa operaciones upstream de MLIR para ModuleOp, FuncOp, CallOp y Op. de retorno Esto se logró para mejorar la interoperabilidad con la maquinaria MLIR existente, ya que pases útiles están dirigidos a FuncOp y ModuleOp, y muchas compilaciones canalizaciones esperan que estas ops estén presentes. Las garantías de compatibilidad total que se aplica a estas operaciones. Si algo cambia en estas operaciones de forma incompatible (es decir, eliminación), se agregarán equivalentes de StableHLO para preservar compatibilidad.
CHLO
El opset de CHLO contiene operaciones de nivel superior que se descomponen en StableHLO. Por el momento, no hay garantías de compatibilidad para CHLO. Para la compatibilidad garantías, el pase chlo-legalize-to-stablehlo debe usarse antes de la serialización.
Operaciones de formas
Es un caso de uso común en la comunidad usar ciertas operaciones
Dialectos de MLIR en programas dinámicos StableHLO para realizar cálculos de formas.
Por lo general, se incluyen el dialecto shape
ops como shape_of
o num_elements
, o dialecto tensor
ops como dim
o from_elements
, y el tipo index
integrado.
El RFC de Dynamism > O2
indica que están fuera del alcance, sin embargo, cierta compatibilidad con los tipos index
es
se incluyen con fines de interoperabilidad. No hay garantías de compatibilidad para estos
ops o tipos. El shape-legalize-to-stablehlo
pase se puede usar para convertir estas operaciones en operaciones StableHLO totalmente compatibles.
Operaciones obsoletas
Hay varias operaciones StableHLO que se heredaron de MHLO que dejaron de estar disponibles y están fuera de StableHLO. Puedes conocer todos los detalles en StableHLO v1.0 Cleanup #2283. El problema de la herramienta de seguimiento para estas bajas es #2340.
Estas operaciones se dividen en las siguientes categorías:
- "No está en HLO" de operaciones StableHLO; inicialmente formaban parte de
el opset StableHLO, pero luego se consideró que no encajaba bien:
broadcast
,create_token
,cross-replica-sum
,dot
yeinsum
,torch_index_select
yunary_einsum
(n.o 3). - ops sin usar: estas operaciones pueden haber sido útiles en algún momento, pero
estaban subdesarrolladas, o las canalizaciones que usan estas operaciones han sido
o refactorización para no necesitarlos más. Esto incluye a
map
,tuple
(#598), Comparaciones entreget_tuple_element
,rng
ycomplex
#560, y convoluciónwindow_reversal
(#1181).
Algunas de estas operaciones se pueden quitar con facilidad, ya que se pueden expresar mediante
operaciones existentes (broadcast
, create_token
, cross-replica-sum
, dot
,
unary_einsum
) y se quitará después del período de compatibilidad existente
pases (6 meses). Aún se están explorando otras opciones para quitarlas (einsum
,
get_tuple_element
, map
, torch_index_select
de rng
, tuple
y complex
comparaciones, window_reversal
). Comentarios pendientes de la comunidad,
estas operaciones se quitarán o se agregarán a la especificación con compatibilidad total. Hasta
se conocen y solo se garantiza su compatibilidad durante 6 meses.
Ejecución
Ejecución secuencial
Un programa StableHLO se ejecuta proporcionando valores de entrada a la función main
.
y calcular valores de salida. Los valores de salida de una función se calculan
Ejecuta el gráfico de operaciones con permisos de administrador en la op return
correspondiente.
El orden de ejecución se define en la implementación siempre que esté alineado con
Dataflow, es decir, si las operaciones se ejecutan antes de su uso. En StableHLO, todas
consumen un token y producen un token (varios tokens pueden
multiplexar en un token a través de after_all
), por lo que el orden de ejecución del lado
efectos también se alinea con Dataflow. Por ejemplo, en el siguiente programa,
hay dos órdenes de ejecución posibles: %0
→ %1
→ %2
→ return
y
%1
→ %0
→ %2
→ return
.
func.func @main() -> tensor<f64> {
%0 = stablehlo.constant dense<1.0> : tensor<f64>
%1 = stablehlo.constant dense<2.0> : tensor<f64>
%2 = stablehlo.add %0, %1 : tensor<f64>
return %2 : tensor<f64>
}
Más formalmente, un proceso StableHLO es una combinación de lo siguiente:
1) un programa StableHLO, 2) estados de operación (aún no ejecutado)
ya se ejecutó) y 3) los valores intermedios en los que funciona el proceso.
El proceso comienza con valores de entrada a la función main
, avanza
el gráfico de ops que actualiza los estados de operación y los valores intermedios, y
finaliza con los valores de salida. Se está por definir una formalización adicional
(#484).
Ejecución paralela
Los programas StableHLO se pueden ejecutar en paralelo y se pueden organizar en una cuadrícula de procesos 2D.
de num_replicas
por num_partitions
, que ambos tienen el tipo ui32
.
En la cuadrícula de procesos de StableHLO, num_replicas * num_partitions
de StableHLO
procesos se ejecutan al mismo tiempo. Cada proceso tiene un
process_id = (replica_id, partition_id)
, donde
replica_id
en replica_ids = range(num_replicas)
y
partition_id
en partition_ids = range(num_partitions)
, que tienen
escribe ui32
.
El tamaño de la cuadrícula de procesos se conoce estáticamente para cada programa (en el
en el futuro, planeamos hacerla
parte explícita de los programas StableHLO
#650) y la posición
dentro de la cuadrícula de procesos se conoce estáticamente para cada proceso. Cada proceso tiene
acceso a su posición dentro de la cuadrícula de procesos a través de replica_id
y
partition_id
de operaciones
Dentro de la cuadrícula de procesos, los programas pueden ser todos iguales (en el módulo Programa, varios datos" estilo), todos pueden ser diferentes (en la sección "Varios programas, Varios datos" estilo) o algún elemento intermedio. En el futuro, estamos planeando para incorporar compatibilidad con otras expresiones idiomáticas de definición de programas paralelos de StableHLO, incluido GSPMD (#619).
Dentro de la matriz de procesos, los procesos son, en su mayoría, independientes entre sí. Tienen estados de operación separados, valores de entrada/intermedios/salida separados. y la mayoría de las ops se ejecutan por separado entre procesos, con el excepto por una pequeña cantidad de operaciones colectivas que se describen a continuación.
Dado que la ejecución de la mayoría de las ops solo usa valores de la misma
por lo general, no es ambiguo hacer referencia a estos valores por sus nombres.
Sin embargo, cuando se describe la semántica de ops colectivas, eso no es suficiente.
que genera la notación name@process_id
para referirse al valor name
dentro de un proceso particular. (Desde esa perspectiva, los name
descalificados pueden
visto como una abreviatura de name@(replica_id(), partition_id())
).
El orden de ejecución en los procesos está definido por la implementación, excepto por la sincronización introducida por la comunicación punto a punto y las operaciones colectivas como se describe a continuación.
Comunicación punto a punto
Los procesos del StableHLO pueden comunicarse entre sí a través de
Canales StableHLO. Un canal se representa con un tipo de ID positivo
si64
A través de varias operaciones, es posible enviar valores a canales y
recibirlos de los canales.
Formalización adicional, p.ej., de dónde provienen estos IDs de canal, cómo procesos que los programas los reconocen y qué tipo de sincronización es que presentan, se debe definir (#484).
Comunicación por transmisión
Cada proceso de StableHLO tiene acceso a dos interfaces de transmisión:
- Entrada que se puede leer.
- Salida en la que se puede escribir.
A diferencia de los canales, que se usan para la comunicación entre procesos y, por lo tanto, tienen procesos en ambos extremos, las entradas y salidas tienen sus otros final de la implementación definida.
Formalización adicional, p.ej., cómo la comunicación por transmisión influye en la ejecución el orden y el tipo de sincronización que este ingresa, por definir (#484).
Operaciones colectivas
Hay seis operaciones colectivas en StableHLO: all_gather
, all_reduce
,
all_to_all
, collective_broadcast
, collective_permute
y
reduce_scatter
Todas estas operaciones dividen los procesos en el proceso StableHLO
cuadrícula en grupos de procesos estables y ejecuta un cálculo conjunto
de cada grupo de procesos, independientemente
de otros grupos de procesos.
Dentro de cada grupo de procesos, las operaciones colectivas pueden introducir una sincronización barrera. Formalización adicional, p.ej., explicar cuándo exactamente se produce la sincronización, cómo llegan los procesos a esa barrera y qué sucede si no lo hacen, se está por definir (#484).
Si el grupo de procesos implica comunicación entre particiones, es decir, hay
procesos en el grupo de procesos cuyos IDs de partición son diferentes, entonces la ejecución
de la op colectiva necesita un canal y esta debe proporcionar
positivo channel_id
de tipo si64
. La comunicación entre réplicas no necesita
canales.
Los cálculos realizados por las operaciones colectivas son específicos de las operaciones individuales y se describen en las secciones de operaciones individuales anteriores. Sin embargo, las estrategias en la que la cuadrícula de procesos se divide en grupos de procesos y se comparten entre estas operaciones y se describen en esta sección. Más formalmente, StableHLO admite el siguiendo cuatro estrategias.
cross_replica
Solo las comunicaciones entre réplicas ocurren dentro de cada grupo de procesos. Esta
la estrategia toma replica_groups
, una lista de listas de IDs de réplica, y procesa
un producto cartesiano de replica_groups
por partition_ids
. replica_groups
debe tener elementos únicos y abarcar todos los replica_ids
. Más formalmente, usar
Sintaxis de Python:
def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
for partition_id in partition_ids:
process_group = []
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para replica_groups = [[0, 1], [2, 3]]
y num_partitions = 2
,
Se producirá cross_replica
[[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]]
cross_partition
Solo las comunicaciones entre particiones ocurren dentro de cada grupo de procesos. Esta
estrategia toma partition_groups
, una lista de listas de IDs de partición, y
calcula un producto cartesiano de partition_groups
por replica_ids
.
partition_groups
debe tener elementos únicos y abarcar todos los partition_ids
.
De manera más formal, usando la sintaxis de Python:
def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
for partition_group in partition_groups:
for replica_id in replica_ids:
process_group = []
for partition_id in partition_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para partition_groups = [[0, 1]]
y num_replicas = 4
,
Se producirá cross_partition
[[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]]
cross_replica_and_partition
Las comunicaciones entre réplicas y particiones pueden ocurrir dentro de cada
grupo de procesos. Esta estrategia toma replica_groups
, una lista de listas de
IDs de réplica y calcula los productos cartesianos de cada replica_group
según
partition_ids
Las replica_groups
deben tener elementos únicos y abarcar todas
replica_ids
De manera más formal, usando la sintaxis de Python:
def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
for replica_group in replica_groups:
process_group = []
for partition_id in partition_ids:
for replica_id in replica_group:
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para replica_groups = [[0, 1], [2, 3]]
y num_partitions = 2
,
Se producirá cross_replica_and_partition
[[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]]
flattened_ids
Esta estrategia toma flattened_id_groups
, una lista de listas "acopladas"
IDs de proceso con el formato replica_id * num_partitions + partition_id
y
los convierte en IDs de procesos. flattened_id_groups
debe tener elementos únicos
y abarca todos los process_ids
. De manera más formal, usando la sintaxis de Python:
def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
for flattened_id_group in flattened_id_groups:
process_group = []
for flattened_id in flattened_id_group:
replica_id = flattened_id // num_partitions
partition_id = flattened_id % num_partitions
process_group.append((replica_id, partition_id))
yield process_group
Por ejemplo, para flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]]
,
num_replicas = 4
y num_partitions = 2
, flattened_ids
producirán
[[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]]
Precisión
Por el momento, StableHLO no ofrece garantías sobre la precisión numérica, pero esto puede cambiar en el futuro (#1156).
Semántica de ejecución de operaciones cuantizadas
La interpretación de operaciones StableHLO cuantificadas puede variar según el los requisitos y capacidades de hardware. Por ejemplo, algunos hardware pueden optar por a interpretar operaciones cuantizadas con la fórmula "decuantizar, realizar operaciones y, finalmente, cuantizar" de administración de amenazas. Otros pueden realizar todo con aritmética de números enteros. Por lo tanto, la interpretación de las operaciones cuantificadas del StableHLO se determinan exclusivamente por el método para implementarlos. La interpretación de la cuantización híbrida (#1575) deben basarse en lo siguiente: la semántica según lo prescrito en la especificación (mediante 1792).
Errores
Los programas StableHLO se validan a través de un amplio conjunto de restricciones para ops individuales, lo que descarta muchas clases de errores antes del tiempo de ejecución. Sin embargo, aún son posibles las condiciones de error, p.ej., a través de desbordamientos de enteros, accesos fuera de los límites, etc. A menos que se indique explícitamente, todos estos errores dar como resultado un comportamiento definido por la implementación, pero esto puede cambiar future (#1157).
Excepciones de punto flotante
Como excepción a esta regla, las excepciones de punto flotante en los programas StableHLO
tienen comportamientos bien definidos. Las operaciones que dan como resultado excepciones definidas por el
estándar IEEE-754 (operación no válida, división por cero, desbordamiento, subdesbordamiento o
excepciones inexactas) producen resultados predeterminados (como se define en el estándar) y
continuar la ejecución sin activar la marca de estado correspondiente similares a
Control de excepciones de raiseNoFlag
del estándar. Excepciones para cargas no estándar
(p.ej., aritmética compleja y ciertas funciones trascendentales) se
definido por la implementación.
La forma no coincide
StableHLO admite tensores de forma dinámica. Sin embargo, las formas deben coincidir en tiempo de ejecución; de lo contrario, el comportamiento será indefinido. StableHLO no indica explícitamente proporciona una op que puede confirmar que un tensor tiene una forma dada en el tiempo de ejecución. Generar el código correcto es responsabilidad del productor.
Como ejemplo específico, el siguiente programa es válido. Sin embargo, en el tiempo de ejecución,
las formas exactas de %arg0
y %arg1
deberán ser iguales; de lo contrario,
el comportamiento del programa no está definido:
func.func @foo(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<?xi32>
return %0 : tensor<?xi32>
}
Notation
Para describir la sintaxis, en este documento se usa el sabor ISO modificado de EBNF
sintaxis (ISO/IEC 14977:1996,
Wikipedia),
con dos modificaciones: 1) las reglas se definen con ::=
en lugar de =
2) la concatenación se expresa mediante la yuxtaposición en lugar de ,
.
Para describir la semántica (es decir, dentro de las secciones “Tipos”, “Constantes” y “Operaciones”), usamos fórmulas basadas en la sintaxis de Python extendida con compatibilidad para expresar de forma concisa operaciones de array como se describe a continuación. Esto funciona bien para fragmentos pequeños de código, pero en casos excepcionales cuando se agregan fragmentos más grandes necesario, usamos la sintaxis normal de Python que siempre se presenta de forma explícita.
Fórmulas
Exploremos cómo funcionan las fórmulas a partir de un ejemplo de la dot_general
especificación. Una de las restricciones de esta operación se ve de la siguiente manera:
dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...)
Los nombres usados en esta fórmula provienen de dos fuentes: 1) funciones globales,
es decir, dim
y 2) definiciones de los miembros del elemento del programa correspondiente, es decir,
Entradas lhs
, lhs_batching_dimensions
, rhs
y rhs_batching_dimensions
definidos en el campo “Entradas” de dot_general
.
Como se mencionó anteriormente, la sintaxis de esta fórmula se basa en Python, con algunas extensiones orientadas a la brevedad. Para entender la fórmula, transformemos en la sintaxis normal de Python.
A) En estas fórmulas, usamos =
para representar la igualdad, así que el primer paso
Para obtener la sintaxis de Python, reemplaza =
por ==
, de la siguiente manera:
dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...)
B) Además, estas fórmulas admiten elipses (...
) que convierten expresiones escalares
en expresiones de tensor. En pocas palabras, f(xs...)
significa "para cada
los escalares x
en el tensor xs
, calcular un f(x)
escalar y, luego, mostrar todos
estos resultados escalares como un resultado de tensor". En la sintaxis básica de Python,
nuestra fórmula de ejemplo se convierte en:
[dim(lhs, dim1) for dim1 in lhs_batching_dimensions] ==
[dim(rhs, dim2) for dim2 in rhs_batching_dimensions]
Gracias a los puntos suspensivos, a menudo es posible evitar trabajar al nivel de
escalares individuales. Sin embargo, en algunos casos complicados, los datos semiinformales de nivel inferior
la sintaxis se puede usar como en la fórmula start_indices[bi0, ..., :, ..., biN]
de la especificación gather
. Para ser concisos, no
proporcionan un formalismo exacto para traducir dicha sintaxis a Python normal, en
espera que sea intuitivamente comprensible caso por caso.
Haznos saber si algunas fórmulas específicas parecen opacas, y trataremos de
y mejorarlas.
Además, notarás que las fórmulas usan puntos suspensivos para expandir todo tipo de listas, incluidos los tensores y las listas de tensores (que, por ejemplo, pueden surgir de un modelo de tensores), etc. Esta es otra área en la que no se proporciona formalismo (p.ej., las listas ni siquiera forman parte del sistema de tipos StableHLO) y en cambio, dependen de la comprensión intuitiva.
C) El último medio notable que empleamos es el implícito la transmisión de contenido. Aunque el opset StableHLO no admite transmisiones implícitas, hacen las fórmulas, también al servicio de la concisión. En pocas palabras, si un modelo escalar se usa en un contexto en el que se espera un tensor, el escalar se transmite a la forma esperada.
Para continuar con el ejemplo de dot_general
, esta es otra restricción:
0 <= lhs_batching_dimensions < rank(lhs)
Como se define en el dot_general
especificación, lhs_batching_dimensions
es un tensor, sin embargo, tanto 0
como
rank(lhs)
son escalares. Después de aplicar una transmisión implícita, la fórmula
se convierte en [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)]
.
Cuando se aplica a una operación dot_general
en particular, esta fórmula
evaluar en un tensor de booleanos. Cuando las fórmulas se usan como restricciones,
se aplica si la fórmula se evalúa como true
o como un tensor que
solo tiene true
elementos.
Nombres
En las fórmulas, el alcance léxico incluye: 1) funciones globales, 2) definiciones de miembros,
3) las definiciones locales. A continuación, se proporciona la lista de funciones globales. La lista de las definiciones de elementos depende del elemento del programa al que se aplica la notación se aplicó a:
- Para las operaciones, las definiciones de miembros incluyen nombres ingresados en "Entradas" y "Resultados" secciones.
- Para todo lo demás, las definiciones de miembros incluyen partes estructurales de la
Elemento de programa que lleva el nombre de los no terminales del EBNF correspondiente. La mayoría de
con el tiempo, los nombres de estas partes estructurales se obtienen convirtiendo la
nombres de las terminales no terminadas a snake case (p.ej.,
IntegerLiteral
=>integer_literal
), pero a veces los nombres se abrevian en el proceso (p.ej.,QuantizationStorageType
=>storage_type
), en cuyo caso los nombres son de manera explícita de manera similar a las “Entradas” / "Resultados" secciones en funcionamiento. y las especificaciones del servicio. - Además, las definiciones de miembros siempre incluyen
self
para hacer referencia al elemento de programa correspondiente.
Valores
Cuando se evalúan las fórmulas, funcionan con los siguientes tipos de valores:
1) Value
(valores reales, p.ej., dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
;
siempre conocen sus tipos),
2) Placeholder
(valores futuros, p.ej., lhs
, rhs
o result
; su valor real
aún no se conocen los valores, solo se conocen sus tipos),
3) Type
(tipos definidos en la sección "Tipos"),
4) Function
(funciones globales, según se define en la sección "Funciones").
Según el contexto, los nombres pueden referirse a diferentes valores. Más
específicamente, la “semántica” para ops (y equivalentes para otros programas
) define la lógica del entorno de ejecución, por lo que todas las entradas están disponibles como Value
.
En cambio, las "Restricciones" ops (y equivalentes) define
"tiempo de compilación" lógica, es decir, algo que generalmente se ejecuta antes del tiempo de ejecución,
por lo que solo están disponibles las entradas constantes como Value
, y otras entradas
disponible solo como Placeholder
.
Nombres | En "Semántica" | En “Restricciones” |
---|---|---|
Funciones globales | Function |
Function |
Entradas constantes | Value |
Value |
Entradas no constantes | Value |
Placeholder |
Salidas | Value |
Placeholder |
Definiciones locales | Depende de la definición | Depende de la definición |
Consideremos una operación transpose
de ejemplo:
%result = "stablehlo.transpose"(%operand) {
permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
Para esta operación, permutation
es una constante, por lo que está disponible como Value
.
tanto en semántica como en restricciones. Por el contrario, operand
y result
son
Está disponible como Value
en semántica, pero solo como Placeholder
en restricciones.
Funciones
Construcción de tipos
No hay funciones que se puedan usar para construir tipos. En cambio, directamente
usar sintaxis de tipo porque suele ser más conciso. P.ej.,
(tensor<E>, tensor<E>) -> (tensor<E>)
en lugar de function_type(
[tensor_type([], E), tensor_type([], E)], [tensor_type([], E)])
.
Funciones en tipos
element_type
se define en tipos de tensores y tipos de tensores cuantificados, y muestra, respectivamente,TensorElementType
oQuantizedTensorElementType
. parte delTensorType
oQuantizedTensorType
correspondiente.
def element_type(x: Value | Placeholder | Type):
if type(x) == TensorType:
return tensor_element_type(x)
if type(x) == QuantizedTensorType:
return quantized_tensor_element_type(x)
if type(x) is not Type:
return element_type(type(x))
is_per_axis_quantized(x: Value | Placeholder | Type) -> Value
es una combinación de teclas parais_quantized(x) and quantization_dimension(x) is not None
.is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value
es un atajo parais_quantized(x) and quantization_dimension(x) is None
.is_promotable(x: Type, y: Type) -> bool
verifica si se puede promocionar el tipox
. para escribiry
. Cuandox
yy
seanQuantizedTensorElementType
, la promoción solo se aplica astorage_type
. Esta versión específica de la promoción está que se usa actualmente en el contexto del procesamiento de la reducción (consultar RFC para obtener más detalles).
def is_promotable(x: Type, y: Type) -> Value:
is_same_type = (is_bool(x) and is_bool(y)) or
(is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
(is_complex(x) and is_complex(y)) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
if is_same_type == False:
return False
if is_integer(x) or is_float(x):
return bitwidth(x) <= bitwidth(y)
if is_complex(x):
return bitwidth(element_type(x)) <= bitwidth(element_type(y))
if is_quantized(x):
return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))
return false
is_quantized(x: Value | Placeholder | Type) -> Value
es la combinación de teclas parais_quantized_tensor_element_type(x)
is_type_name(x: Value | Placeholder | Type) -> Value
Disponible para todos de tipos de datos. Por ejemplo,is_float(x)
muestratrue
six
es unaFloatType
. Six
es un valor o un marcador de posición, esta función es un atajo parais_type_name(type(x))
max_value(x: Type) -> Value
muestra el valor máximo de unTensorElementType
Six
no es unTensorElementType
, muestraNone
.min_value(x: Type) -> Value
muestra el valor mínimo posible de unTensorElementType
Six
no es unTensorElementType
, muestraNone
.member_name(x: Value | Placeholder | Type) -> Any
Disponible para todos los miembros definicionesmember_name
de todos los tipos. Por ejemplo,tensor_element_type(x)
. muestra la parteTensorElementType
deTensorType
correspondiente. Six
es un valor o un marcador de posición, esta función es un atajo paramember_name(type(x))
Six
no es un tipo que tenga un miembro apropiado o un valor o un marcador de posición de ese tipo muestraNone
.is_empty_algorithm(*args: Type)
comprueba si se configuraron todos los campos del algoritmo de puntos. aNone
. Esto es necesario, ya que los algoritmos de punto tienen una implementación definida. comportamientos predeterminados, por lo que especificar un valor predeterminado sería incorrecto.
Construcción de valores
operation_name(*xs: Value | Type) -> Value
Disponible para todas las operaciones. Por ejemplo,add(lhs, rhs)
toma dos valores de tensor,lhs
yrhs
, y muestra el resultado de la evaluación de la operaciónadd
con estas entradas. Para algunas operaciones, p.ej.,broadcast_in_dim
, los tipos de sus resultados son los siguientes: “load-bearing”, es decir, necesario para evaluar una operación En este caso, la función toma estos tipos como argumentos.
Funciones en valores
Todos los operadores y funciones de Python están disponibles. P.ej., ambos suscripción y dividir las notaciones de Python están disponibles para indexarse en tensores, tensores cuantificados y tuplas.
to_destination_type(x: Value, destination_type: Type) -> Value
se define en tensores y muestra el valor convertido dex
basado entype(x)
ydestination_type
de la siguiente manera:
def to_destination_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(destination_type):
if is_quantized(type(x)):
return quantize(x, destination_type)
assert is_float(type(x))
return quantize(x, destination_type)
if is_quantized(type(x)):
assert destination_type = expressed_type(type(x))
return dequantize(type(x))
return convert(x, destination_type)
Hay una conversación temprana sobre la combinación de convert
, uniform_quantize
y
Operaciones uniform_dequantize
(#1576).
Después de la combinación, no necesitamos la función anterior y podemos usar el nombre de la operación.
para convert
.
is_nan(x: Value) -> Value
se define en tensores y muestratrue
si todos los elementos dex
sonNaN
ofalse
de lo contrario. Six
no es un tensor, muestraNone
.is_sorted(x: Value) -> Value
se define en tensores y muestratrue
si los elementos dex
se ordenan en forma ascendente con respecto al de los índices, o bienfalse
de lo contrario. Six
no es un tensor, muestraNone
.is_unique(x: Value) -> Value
se define en los tensores y muestratrue
six
. no tiene elementos duplicados nifalse
de lo contrario. Six
no es un tensor, muestraNone
.member_name(x: Value) -> Any
se define para todas las definiciones de los miembros. Elmember_name
de todos los valores. Por ejemplo,real_part(x)
muestraRealPart
. parte de unComplexConstant
correspondiente. Six
no es un valor que tenga un miembro correspondiente, muestraNone
.same(x: Value) -> Value
se define en tensores y muestratrue
si los elementos dex
son todos iguales entre sí, de lo contrario,false
. Si el tensor no tiene elementos, que cuentan como "todos iguales entre sí", es decir, el La función muestratrue
. Six
no es un tensor, muestraNone
.split(x: Value, num_results: Value, axis: Value) -> Value
se define en tensores y muestra segmentosnum_results
dex
a lo largo del ejeaxis
. Six
no es un tensor nidim(x, axis) % num_results != 0
, muestraNone
.is_defined_in_parent_scope(x: Value) -> Value
se define en las cadenas y muestratrue
six
es el nombre de una función definida en el mismo alcance. como la función principal de la op relevante.is_namespaced_op_name(x: Value) -> Value
se define en cadenas y muestratrue
six
es un nombre de operación válido, que respeta el siguiente comando normal expresión:[a-zA-Z][a-zA-Z0-9_]*([.][a-zA-Z0-9_$]+)+
Cálculos de formas
axes(x: Value | Placeholder | Type) -> Value
es la combinación de teclas pararange(rank(x))
dim(x: Value | Placeholder | Type, axis: Value) -> Value
es la combinación de teclas parashape(x)[axis]
dims(x: Value | Placeholder | Type, axes: List) -> List
es la combinación de teclas paralist(map(lambda axis: dim(x, axis), axes))
index_space(x: Value | Placeholder | Type) -> Value
se define en los tensores y muestra los índicessize(x)
para elTensorType
correspondiente ordenados de orden lexicográfico ascendente, es decir,[0, ..., 0]
,[0, ..., 1]
, ...,shape(x) - 1
Six
no es un tipo de tensor, un tipo de tensor cuantificado o un valor o un marcador de posición de uno de estos tipos, muestraNone
.rank(x: Value | Placeholder | Type) -> Value
es la combinación de teclas parasize(shape(x))
shape(x: Value | Placeholder | Type) -> Value
se define en la sección “Funciones de tipos" mediantemember_name
.size(x: Value | Placeholder | Type) -> Value
es la combinación de teclas parareduce(lambda x, y: x * y, shape(x))
Cálculos de cuantización
def baseline_element_type(x: Value | Placeholder | Type) -> Type
es un atajo paraelement_type(baseline_type(x))
.baseline_type
se define en tipos de tensores y tipos de tensores cuantificados, y las transforma en un “modelo de referencia”, es decir, un tipo con la misma forma, pero con la Los parámetros de cuantización del tipo de elemento se restablecen a los valores predeterminados. Este es Se usa como truco útil para comparar tipos de tensores y tensores cuantificados. de manera uniforme, lo que se necesita con bastante frecuencia. Para los tipos cuantizados, esto permite que comparan tipos ignorando los parámetros de cuantización, es decir,shape
storage_type
,expressed_type
,storage_min
,storage_max
yquantization_dimension
(para el tipo cuantizado por eje) debe coincidir, peroscales
yzero points
pueden variar.
def baseline_type(x: Value | Placeholder | Type) -> Type:
if type(x) == TensorType:
return x
if type(x) == QuantizedTensorType:
element_type = quantized_tensor_element_type(x)
baseline_element_type = QuantizedTensorElementType(
storage_type = storage_type(element_type),
storage_min = storage_min(element_type),
storage_max = storage_max(element_type),
expressed_type = expressed_type(element_type),
quantization_dimension = quantization_dimension(element_type),
scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
return QuantizedTensorType(shape(x), baseline_element_type)
if type(x) is not Type:
return baseline_element_type(type(x))
dequantize
se define en tipos de tensores cuantificados y los convierte en tipos de tensores de punto flotante. Esto sucede cuando se convierten elementos cuantizados que representan valores de números enteros del tipo de almacenamiento en valores valores de punto flotante del tipo expresado con el punto cero y la escala asociados con el tipo de elemento cuantizado.
def compute_zero_points(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
zero_points[i] = zero_points(quantized_type)[i[d]]
return zero_points
def compute_scales(quantized_type, result_type):
if is_per_tensor_quantized(quantized_type):
return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
type(result_type))
if is_per_axis_quantized(quantized_type):
for i in index_space(result_type):
d = quantization_dimension(quantized_type)
scales[i] = scales(quantized_type)[i[d]]
return scales
def dequantize(x: Value) -> Value:
assert is_quantized(x)
x_storage = bitcast_convert(x, storage_type(x))
x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
x_expressed_sub = convert(x_storage_sub, expressed_type(x))
return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
quantize
se define en tipos de tensores de punto flotante y los convierte en de tensores cuantificados. Esto sucede con la conversión de valores de punto flotante del tipo expresado en los valores de número entero correspondientes del tipo de almacenamiento con el punto cero y la escala asociadas al tipo de elemento cuantizado.
def quantize(x: Value, result_type: Type) -> Value:
assert is_float(x) and is_quantized(result_type)
zero_points = compute_zero_points(result_type, TensorType(shape(x), storage_type(result_type)))
converted_zero_points = convert(zero_points, expressed_type(result_type))
converted_min = convert(storage_min(result_type), expressed_type(result_type))
converted_max = convert(storage_max(result_type), expressed_type(result_type))
x_scaled = x / compute_scales(result_type, type(x))
x_scaled_add_zp = x_scaled + converted_zero_points
x_clamped = clamp(converted_min, x_scaled_add_zp, converted_max)
x_rounded = round_nearest_even(x_clamped)
return convert(x_rounded, result_type)
dequantize_op_quantize
se usa para especificar cálculos a nivel de elementos en con tensores cuantificados. Descuantiza, es decir, convierte elementos cuantificados en sus tipos expresados, realiza una operación y luego cuantiza, es decir, convierte los resultados a sus tipos de almacenamiento. Por el momento, esta función solo funciona para la cuantización por tensor. La cuantización por eje está en desarrollo (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
inputs = inputs_and_output_type[:-1]
output_type = inputs_and_output_type[-1]
float_inputs = map(dequantize, inputs)
float_result = op(*float_inputs)
return quantize(float_result, output_type)
def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
inputs = inputs_and_output_type[:-3]
float_inputs = map(dequantize, inputs)
float_results = op(*float_inputs)
return map(quantize, float_results, inputs_and_output_type[-3:])
def dequantize_compare(lhs, rhs, comparison_direction):
float_lhs = dequantize(lhs)
float_rhs = dequantize(rhs)
return compare(float_lhs, float_rhs, comparison_direction, FLOAT)
def dequantize_select_quantize(pred, on_true, on_false, output_type):
float_on_true = dequantize(on_true)
float_on_false = dequantize(on_false)
float_result = select(pred, float_on_true, float_on_false)
return quantize(float_result, output_type)
hybrid_dequantize_then_op
se usa para especificar la cuantización de solo pesos op híbrida que acepta lhs en punto flotante y rh en tipos cuantizados. Integra descuantiza entradas cuantificadas en sus tipos expresados y realiza cálculos en flote. Tipo de elemento de tensor de lhs de número de punto flotante y tipo expresado de HR cuantificada “tensor” debe ser idéntico.
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
Cálculos de cuadrícula
cross_partition(replica_groups: Value) -> Value
Ver la "cross_replica" sección anterior.cross_replica(replica_groups: Value) -> Value
Ver la "cross_replica" sección anterior.cross_replica_and_partition(replica_groups: Value) -> Value
Consulta la "cross_replica_and_partition" sección anterior.flattened_ids(replica_groups: Value) -> Value
Ver los "flattened_ids" sección anterior.
Dinamismo
Los valores StableHLO pueden tener tamaños de dimensión dinámicos, p.ej., tensor<?xi64>
Sin embargo, los valores StableHLO no pueden tener un número dinámico de dimensiones (sin clasificar
dinamismo, p.ej., tensor<*xi64>
). Los operandos y resultados pueden usar datos
tamaños de las dimensiones, incluso si hay limitaciones. Las restricciones serán
de forma estática, si es posible; de lo contrario, se difieren al entorno de ejecución
y las faltas de coincidencia darán
un comportamiento indefinido. Consulta los ejemplos a continuación.
Discrepancias de forma para operaciones unarias a nivel de elementos
Considera el siguiente programa de juguetes:
func.func @foo(%arg0: tensor<?xf64>) {
%0 = stablehlo.abs %arg0 : (tensor<?xf64>) -> tensor<2xf64>
return
}
Este programa es inusual, porque no es común conocer la forma del
pero no la forma de la entrada. No obstante, este es un StableHLO válido
. No es posible validar de forma estática la operación abs
en esta
porque se desconoce la forma exacta del operando. Sin embargo, las formas
sin duda son compatibles, lo que se puede comprobar de forma estática: ?
podría resultar
sea 2
en el tiempo de ejecución, y no habría problema. Sin embargo, ?
podría
también resulta ser algún otro número entero, en cuyo caso el comportamiento es indefinido.
Ten en cuenta que, si el tamaño de la dimensión es dinámico en el resultado, no puede haber comportamiento indefinido. De hecho, no hay una de tamaño, por lo que no puede haber un no coincide.
Incompatibilidad de formas para operaciones binarias a nivel de los elementos
Considera el siguiente programa de juguetes:
func.func @foo(%arg0: tensor<?xf64>, %arg1: tensor<?xf64>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf64>, tensor<?xf64>) -> tensor<?xf64>
return
}
Cuando se trata de operaciones binarias, las formas de las entradas el resultado debe coincidir en el tiempo de ejecución. En el tiempo de compilación, las dimensiones estáticas deben ser iguales. de lo contrario, solo deben ser compatibles. Si cualquier dimensión es dinámica en las entradas, podría haber unidades no definidas comportamiento del usuario durante el tiempo de ejecución, ya que es posible que el tamaño dinámico no coincida con el valor tamaño en el otro operando (ya sea estático o dinámico). Si todas las entradas están estático, entonces no importa si el resultado es dinámico o no: de forma las dimensiones conocidas se verificarán de forma estática, y las dimensiones dinámicas no imponer ninguna restricción.
Incompatibilidad de forma para ops que toman su forma de salida como un operando.
Considera el siguiente programa de juguetes:
func.func @foo(%arg0: tensor<2xi32>) {
%0 = stablehlo.dynamic_iota %arg0, dim = 0 : (tensor<2xi32>) -> tensor<3x4xi64>
return
}
Los valores del operando de forma en el tiempo de ejecución deben coincidir con la forma del resultado,
de lo contrario, el comportamiento será indefinido. Es decir, durante el tiempo de ejecución, %arg0
debe tener un
valor de dense<[3, 4]> : tensor<2xi32>
. Si el operando de forma es constante, este
pueden verificarse de forma estática. Si la forma del resultado es completamente dinámica,
no puede ser una falta de coincidencia.