Especificación de StableHLO

StableHLO es un conjunto de operaciones para operaciones de alto nivel (HLO) en modelos de aprendizaje automático (AA). StableHLO funciona como una capa de portabilidad entre diferentes frameworks de AA y compiladores de AA: los frameworks de AA que producen programas StableHLO son compatibles con los compiladores de AA que consumen programas StableHLO.

Nuestro objetivo es simplificar y acelerar el desarrollo del AA creando una mayor interoperabilidad entre varios frameworks de AA (como TensorFlow, JAX y PyTorch) y compiladores de AA (como IREE y XLA). En ese sentido, este documento proporciona una especificación para el lenguaje de programación StableHLO.

Esta especificación contiene tres secciones principales. Primero, en la sección Programas, se describe la estructura de los programas StableHLO que constan de funciones de StableHLO que, a su vez, constan de operaciones de StableHLO. Dentro de esa estructura, la sección Ops especifica la semántica de las operaciones individuales. La sección Execution proporciona semántica para todas estas operaciones que se ejecutan juntas dentro de un programa. Por último, en la sección Notación, se analiza la notación utilizada en toda la especificación.

Programas

Program ::= {Func}

Los programas StableHLO constan de una cantidad arbitraria de funciones StableHLO. A continuación, se muestra un programa de ejemplo con una función @main que tiene 3 entradas (%image, %weights y %bias) y 1 resultado. El cuerpo de la función tiene 6 operaciones.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Funciones

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

Las funciones estables (que también se denominan funciones con nombre) tienen un identificador, entradas y salidas, y un cuerpo. En el futuro, planeamos agregar metadatos adicionales para las funciones a fin de lograr una mejor compatibilidad con HLO (#425, #626, #740 y #744).

Identificadores

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Los identificadores estables son similares a los identificadores en muchos lenguajes de programación, con dos peculiaridades: 1) todos los identificadores tienen sigils que distinguen diferentes tipos de identificadores, 2) los identificadores de valores pueden ser completamente numéricos para simplificar la generación de programas StableHLO.

Tipos

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Los tipos de StableHLO se clasifican en tipos de valores (que también se denominan tipos de primera clase), que representan valores de StableHLO y tipos que no son de valores que describen otros elementos del programa. Los tipos StableHLO son similares a los tipos en muchos lenguajes de programación, cuya particularidad principal es la naturaleza específica del dominio de StableHLO, que genera algunos resultados inusuales (p.ej., los tipos escalares no son tipos de valores).

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}

Los tipos de tensor representan tensores, es decir, arrays multidimensionales. Tienen una forma y un tipo de elemento, en los que una forma representa tamaños de dimensión no negativos en orden ascendente de las dimensiones correspondientes (que también se denominan ejes) numeradas del 0 al R-1. La cantidad de dimensiones R se denomina clasificación. Por ejemplo, tensor<2x3xf32> es un tipo de tensor con forma 2x3 y tipo de elemento f32. Tiene dos dimensiones (o, en otras palabras, dos ejes): 0 y 1, cuyos tamaños son 2 y 3. Su clasificación es 2.

Esto define la compatibilidad con formas estáticas en las que los tamaños de dimensión se conocen estáticamente. En el futuro, planeamos agregar compatibilidad con formas dinámicas en las que los tamaños de las dimensiones sean desconocidos de forma parcial o total (#8). Además, planeamos explorar la extensión de los tipos de tensores más allá de los tamaños de dimensión y los tipos de elementos, por ejemplo, para incluir diseños (#629) y dispersión (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nombre Tipo Restricciones
storage_type tipo de número entero (C1-C4) o C9
storage_min constante de número entero (C2), (C4) y (C8)
storage_max constante de número entero (C3), (C4) y (C8)
expressed_type tipo de punto flotante (C1), (C5)
quantization_dimension constante de número entero opcional (C11-C13).
scales número variádico de constantes de punto flotante (C5-C7), (C10), (C11) y (C13)
zero_points número variádico de constantes de números enteros (C8-C10)

Los tipos de elementos cuantificados representan valores de números enteros de un tipo de almacenamiento en el rango de storage_min a storage_max (inclusive) que corresponden a valores de punto flotante de un tipo expresado. Para un número entero determinado i, el valor de punto flotante correspondiente f se puede calcular como f = (i - zero_point) * scale, en el que scale y zero_point se denominan parámetros de cuantización. Los valores storage_min y storage_max son opcionales en la gramática, pero tienen valores predeterminados de min_value(storage_type) y max_value(storage_type), respectivamente. Los tipos de elementos cuantizados tienen las siguientes restricciones:

  • (C1) num_bits(storage_type) < num_bits(expressed_type).
  • (C2) type(storage_min) = storage_type.
  • (C3) type(storage_max) = storage_type.
  • (C4) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C5) type(scales...) = expressed_type.
  • (C6) 0 < scales.
  • (C7) is_finite(scales...).
  • (C8) storage_min <= zero_points <= storage_max.
  • (C9) type(zero_points...) = storage_type.
  • (C10) size(scales) = size(zero_points).
  • (C11) Si es is_empty(quantization_dimension), entonces size(scales) = 1.
  • (C12) 0 <= quantization_dimension.

Por el momento, QuantizationScale es una constante de punto flotante, pero hay gran interés en las escalas basadas en números enteros, representadas con multiplicadores y cambios. Tenemos pensado explorar este tema en un futuro cercano (#1404).

Hay un debate en curso sobre la semántica de QuantizationZeroPoint, que incluye el tipo, los valores y si puede haber solo uno o varios puntos cero en un tipo de tensor cuantificado. En función de los resultados de este análisis, la especificación en torno a los puntos cero podría cambiar en el futuro (#1405).

Otro debate en curso involucra la semántica de QuantizationStorageMin y QuantizationStorageMax para determinar si se debe imponer alguna restricción a estos valores y a los de tensores cuantificados (#1406).

Por último, planeamos explorar la representación de escalas desconocidas y puntos cero, de manera similar a cómo planeamos explorar la representación de tamaños de dimensión desconocidos (#1407).

Los tipos de tensores cuantificados representan tensores con elementos cuantificados. Estos tensores son exactamente los mismos que los tensores normales, con la excepción de que sus elementos tienen tipos de elementos cuantificados, en lugar de tipos de elementos regulares.

En los tensores cuantizados, la cuantización puede ser por tensor, es decir, tener un scale y zero_point para todo el tensor o puede ser por eje, es decir, tener varios scales y zero_points, un par por porción de una dimensión en particular quantization_dimension. De manera más formal, en un tensor t con cuantización por eje, hay segmentos dim(t, quantization_dimension) de quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :], etc. Todos los elementos de la i.a porción usan scales[i] y zero_points[i] como sus parámetros de cuantización. Los tipos de tensores cuantificados tienen las siguientes restricciones:

  • Para la cuantización por tensor:
    • Sin restricciones adicionales.
  • Para la cuantización por eje:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Los tipos de token representan tokens, es decir, valores opacos que producen y consumen algunas operaciones. Los tokens se usan para imponer el orden de ejecución a las operaciones como se describe en la sección Ejecución.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Los tipos de tupla representan tuplas, es decir, listas heterogéneas. Las tuplas son una función heredada que solo existe para ser compatible con HLO. En HLO, las tuplas se usan para representar entradas y salidas variables. En StableHLO, las entradas y salidas variables son compatibles de forma nativa y el único uso de tuplas en StableHLO es representar de manera integral la ABI de HLO, en la que, p.ej., T, tuple<T> y tuple<tuple<T>> pueden ser sustancialmente diferentes según una implementación en particular. En el futuro, planeamos realizar cambios en la ABI de HLO, lo que podría permitirnos quitar los tipos de tuplas de StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Los tipos de elementos representan elementos de tipos de tensores. A diferencia de muchos lenguajes de programación, estos tipos no son de primera clase en StableHLO. Esto significa que los programas estables no pueden representar directamente valores de estos tipos (como resultado, es idiomático representar valores escalares de tipo T con valores de tensores de 0 dimensiones de tipo tensor<T>).

  • Tipo booleano representa los valores booleanos true y false.
  • Los tipos de número entero pueden ser con firma (si) o sin firma (ui) y tener uno de los anchos de bits admitidos (4, 8, 16, 32 o 64). Los tipos de siN firmados representan valores de números enteros de -2^(N-1) a 2^(N-1)-1 inclusive, y los tipos uiN sin firma representan valores de números enteros de 0 a 2^N-1 inclusive.
  • Los tipos de punto flotante pueden ser uno de los siguientes:
  • Los tipos complejos representan valores complejos que tienen una parte real y una parte imaginaria del mismo tipo de elemento. Los tipos complejos admitidos son complex<f32> (ambas partes son del tipo f32) y complex<f64> (ambas partes son del tipo f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Los tipos de funciones representan funciones con nombre y funciones anónimas. Tienen tipos de entrada (la lista de tipos en el lado izquierdo de ->) y tipos de salida (la lista de tipos a la derecha de ->). En muchos lenguajes de programación, los tipos de funciones son de primera clase, pero no en StableHLO.

StringType ::= 'string'

Tipo de string representa secuencias de bytes. A diferencia de muchos lenguajes de programación, el tipo de string no es la primera clase en StableHLO y solo se usa para especificar metadatos estáticos para los elementos del programa.

Operaciones

Las operaciones estables (que también se denominan ops) representan un conjunto cerrado de operaciones de alto nivel en modelos de aprendizaje automático. Como se mencionó antes, la sintaxis de StableHLO está inspirada en gran medida en MLIR, que no es necesariamente la alternativa más ergonómica, pero podría decirse que es la mejor opción para el objetivo de StableHLO de crear más interoperabilidad entre los frameworks de AA y los compiladores de AA.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

Las operaciones estables (que también se llaman ops) tienen un nombre, entradas y salidas, y una firma. El nombre consta del prefijo stablehlo. y un nombre mnemotécnico que identifica de forma única una de las operaciones admitidas. Consulta la siguiente lista para obtener una lista completa de todas las operaciones admitidas.

Por el momento, los programas StableHLO en la naturaleza a veces contienen operaciones que no se describen en este documento. En el futuro, planeamos absorber estas operaciones en el conjunto de operaciones StableHLO o prohibir que aparezcan en los programas StableHLO. Mientras tanto, la siguiente es la lista de estas operaciones:

  • builtin.module, func.func, func.call y func.return (#425).
  • chlo (#602).
  • Categoría "No en HLO" de operaciones StableHLO. Inicialmente, eran parte del conjunto de operaciones StableHLO, pero luego se determinó que no encajaban bien: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select, unary_einsum (#3).
  • Categoría “Dynamism” de las operaciones StableHLO. Se iniciaron a partir de MHLO, pero aún no las especificamos: compute_reshape_shape, cstr_reshapable, dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, real_dynamic_slice, set_dimension_size (#8).
  • Cálculos de forma, incluidas las operaciones arith, shape y tensor (#8).
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

Las operaciones consumen entradas y producen salidas. Las entradas se categorizan en valores de entrada (que se procesan durante la ejecución), funciones de entrada (proporcionadas de forma estática, porque en StableHLO las funciones no son valores de primera clase) y atributos de entrada (también proporcionados de forma estática). El tipo de entradas y salidas que consume y produce una op depende de su nemotécnica. Por ejemplo, la op add consume 2 valores de entrada y produce 1 valor de salida. En comparación, la op select_and_scatter consume 3 valores de entrada, 2 funciones de entrada y 3 atributos de entrada.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

Las funciones de entrada (también llamadas funciones anónimas) son muy similares a las funciones con nombre, con la excepción de que 1) no tienen un identificador (por eso el nombre es "anónimo"), y 2) no declaran tipos de salida (los tipos de salida se deducen de la operación return dentro de la función).

La sintaxis de las funciones de entrada incluye una parte que no se usa en este momento (consulta la producción de Unused más arriba) que brinda compatibilidad con MLIR. En MLIR, existe un concepto más general de “regiones” que pueden tener varios “bloques” de operaciones conectados a través de operaciones de salto. Estos bloques tienen IDs que corresponden a la producción de Unused, de modo que se puedan distinguir entre sí. StableHLO no tiene operaciones de salto, por lo que la parte correspondiente de la sintaxis de MLIR no se usa (pero sigue estando allí).

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Los atributos de entrada tienen un nombre y un valor que es una de las constantes admitidas. Son la forma principal de especificar metadatos estáticos para elementos del programa. Por ejemplo, la operación concatenate usa el atributo dimension para especificar la dimensión con la que se concatenan sus valores de entrada. De manera similar, la op slice usa varios atributos, como start_indices y limit_indices, a fin de especificar los límites que se usan para dividir el valor de entrada.

Por el momento, los programas StableHLO en la naturaleza a veces contienen atributos que no se describen en este documento. En el futuro, planeamos absorber estos atributos en el conjunto de operaciones StableHLO o prohibir que aparezcan en los programas StableHLO. Mientras tanto, esta es la lista de estos atributos:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Metadatos de ubicación (#594)
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

La firma de operaciones consta de los tipos de todos los valores de entrada (la lista de tipos en el lado izquierdo de ->) y los tipos de todos los valores de salida (la lista de tipos en el lado derecho de ->). En sentido estricto, los tipos de entrada son redundantes y los tipos de salida casi siempre son redundantes (porque para la mayoría de las operaciones StableHLO, los tipos de salida se pueden inferir de las entradas). No obstante, la firma de operaciones forma parte deliberadamente de la sintaxis de StableHLO para brindar compatibilidad con MLIR.

A continuación, se muestra un ejemplo de operación mnemotécnica select_and_scatter. Consume 3 valores de entrada (%operand, %source y %init_value), 2 funciones de entrada y 3 atributos de entrada (window_dimensions, window_strides y padding). Ten en cuenta que la firma de la op solo incluye los tipos de sus valores de entrada (pero no los tipos de funciones y atributos de entrada que se proporcionan intercalados).

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Constantes

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

Las constantes StableHLO tienen un literal y un tipo que, en conjunto, representan un valor StableHLO. Por lo general, el tipo forma parte de la sintaxis constante, excepto cuando no es ambiguo (p.ej., una constante booleana tiene el tipo i1 sin ninguna ambigüedad, mientras que una constante de número entero puede tener varios tipos posibles).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

Las constantes booleanas representan valores booleanos true y false. Las constantes booleanas tienen el tipo i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Las constantes de números enteros representan valores de números enteros a través de strings que usan notación decimal o hexadecimal. No se admiten otras bases, p.ej., binaria u octal. Las constantes de número entero tienen las siguientes restricciones:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Las constantes de punto flotante representan valores de punto flotante a través de cadenas que usan notación decimal o científica. Además, la notación hexadecimal se puede usar para especificar directamente los bits subyacentes en el formato de punto flotante del tipo correspondiente. Las constantes de punto flotante tienen las siguientes restricciones:

  • (C1) Si se usa la notación no hexadecimal, is_wellformed(float_literal, float_type).
  • (C2) Si se usa la notación hexadecimal, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Las constantes complejas representan valores complejos mediante listas de una parte real (va primero) y una parte imaginaria (va por segundo). Por ejemplo, (1.0, 0.0) : complex<f32> representa a 1.0 + 0.0i y (0.0, 1.0) : complex<f32> representa 0.0 + 1.0i. El orden en el que estas partes se almacenan en la memoria está definido por la implementación. Las constantes complejas tienen las siguientes restricciones:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

Las constantes tensoriales representan valores de tensor mediante listas anidadas especificadas a través de notación NumPy. Por ejemplo, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> representa un valor de tensor con la siguiente asignación de índices a elementos: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5 y {1, 2} => 6. El orden en el que estos elementos se almacenan en la memoria se define por la implementación. Las constantes de Tensor tienen las siguientes restricciones:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), donde:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), donde:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • de lo contrario, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

Las constantes de tensor cuantificadas representan valores de tensor cuantificados que usan la misma notación que las constantes del tensor, con elementos especificados como constantes de su tipo de almacenamiento. Las constantes de tensor cuantificadas tienen las siguientes restricciones:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

Los literales de string constan de bytes especificados mediante caracteres ASCII y secuencias de escape. Dado que son independientes de la codificación, la interpretación de estos bytes está definida por la implementación. Los literales de string tienen el tipo string.

Ops

abs

Semántica

Realiza una operación abs a nivel de elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números enteros con firma: módulo de números enteros.
  • Para números de punto flotante: abs de IEEE-754.
  • Para números complejos: módulo complejo.
  • Para tipos cuantizados: dequantize_op_quantize(abs, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de número entero con signo, punto flotante o tipo complejo, o tensor cuantificado por tensor (C1-C2).

Salidas

Nombre Tipo Restricciones
result tensor de número entero con signo, tipo de punto flotante o tensor cuantificado por tensor (C1-C2).

Restricciones

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) se define de la siguiente manera:
    • complex_element_type(element_type(operand)) si es is_complex(operand).
    • De lo contrario, baseline_element_type(operand).

Ejemplos

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

Más ejemplos

add

Semántica

Realiza una adición a nivel de elementos de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para valores booleanos: OR lógico.
  • Para números enteros: suma de números enteros.
  • Para números de punto flotante: addition de IEEE-754.
  • Para números complejos: suma compleja.
  • Para tipos cuantizados: dequantize_op_quantize(add, lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor o por tensor cuantizado (C1)
(I2). rhs tensor o por tensor cuantizado (C1)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1)

Restricciones

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Ejemplos

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Más ejemplos

after_all

Semántica

Garantiza que las operaciones que producen el inputs se ejecuten antes de cualquier operación que dependa de result. La ejecución de esta operación no hace nada; solo existe para establecer dependencias de datos de result a inputs.

Entradas

Etiqueta Nombre Tipo
(I1). inputs número variable de token

Salidas

Nombre Tipo
result token

Ejemplos

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

Más ejemplos

all_gather

Semántica

Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, concatena los valores del tensor operand de cada proceso junto con all_gather_dim y produce un tensor result.

La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:

  • cross_replica(replica_groups) si es channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si es channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si es channel_id > 0 and use_global_device_ids = true.

Luego, dentro de cada process_group, haz lo siguiente:

  • operands@receiver = [operand@sender for sender in process_group] para todos los receiver en process_group.
  • result@process = concatenate(operands@process, all_gather_dim) para todos los process en process_group.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1), (C6)
(I2). all_gather_dim constante de tipo si64 (C1), (C6)
(I3). replica_groups Constante tensorial bidimensional de tipo si64 (C2-C4).
(I4). channel_id constante de tipo si64 (C5)
(I5). use_global_device_ids constante de tipo i1 (C5)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C6)

Restricciones

  • (C1) 0 <= all_gather_dim < rank(operand).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) se define de la siguiente manera:
    • num_replicas si se usa cross_replica.
    • num_replicas si se usa cross_replica_and_partition.
    • num_processes si se usa flattened_ids.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Si es use_global_device_ids = true, entonces channel_id > 0.
  • (C6) type(result) = type(operand), excepto:
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

Ejemplos

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

Más ejemplos

all_reduce

Semántica

Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, aplica una función de reducción computation a los valores del tensor operand de cada proceso y produce un tensor result.

La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:

  • cross_replica(replica_groups) si es channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si es channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si es channel_id > 0 and use_global_device_ids = true.

Luego, dentro de cada process_group, haz lo siguiente:

  • result@process[result_index] = exec(schedule) para algún árbol binario schedule, en el que:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule es un árbol binario definido por la implementación cuyo recorrido en orden es to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0])).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C5) y (C6)
(I2). replica_groups número variable de constantes de tensor unidimensionales de tipo si64 (C1-C3).
(I3). channel_id constante de tipo si64 (C4)
(I4). use_global_device_ids constante de tipo i1 (C4)
(I5). computation la función (C5)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C6-C7).

Restricciones

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) se define de la siguiente manera:
    • num_replicas si se usa cross_replica.
    • num_replicas si se usa cross_replica_and_partition.
    • num_processes si se usa flattened_ids.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Si es use_global_device_ids = true, entonces channel_id > 0.
  • (C5) computation tiene el tipo (tensor<E>, tensor<E>) -> (tensor<E>), en el que is_promotable(element_type(operand), E).
  • (C6) shape(result) = shape(operand).
  • (C7) element_type(result) = E.

Ejemplos

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

Más ejemplos

all_to_all

Semántica

Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, se dividen los valores del tensor operand a lo largo de split_dimension en partes, dispersa las partes divididas entre los procesos, concatena las partes dispersas junto con concat_dimension y produce un tensor result.

La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:

  • cross_replica(replica_groups) si es channel_id <= 0.
  • cross_partition(replica_groups) si es channel_id > 0.

Luego, dentro de cada process_group, haz lo siguiente:

  • split_parts@sender = split(operand@sender, split_count, split_dimension) para todos los sender en process_group.
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group], donde receiver_index = process_group.index(receiver).
  • result@process = concatenate(scattered_parts@process, concat_dimension).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1-C3) o C9
(I2). split_dimension constante de tipo si64 (C1), (C2) y (C9)
(I3). concat_dimension constante de tipo si64 (C3) y (C9)
(I4). split_count constante de tipo si64 (C2), (C4), (C8) y (C9)
(I5). replica_groups Constante tensorial bidimensional de tipo si64 (C5-C8).
(I6). channel_id constante de tipo si64

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C9)

Restricciones

  • (C1) 0 <= split_dimension < rank(operand).
  • (C2) dim(operand, split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operand).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) se define de la siguiente manera:
    • num_replicas si se usa cross_replica.
    • num_partitions si se usa cross_partition.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(result) = type(operand), excepto:
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

Ejemplos

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

Más ejemplos

y

Semántica

Realiza el operador AND a nivel de elementos de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para valores booleanos: lógico AND.
  • Para números enteros: AND a nivel de bits.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de tipo booleano o número entero (C1)
(I2). rhs tensor de tipo booleano o número entero (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo booleano o número entero (C1)

Restricciones

  • (C1) type(lhs) = type(rhs) = type(result).

Ejemplos

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

atan2

Semántica

Realiza la operación atan2 a nivel de los elementos en los tensor lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: atan2 de IEEE-754.
  • Para números complejos: atan2 complejo.
  • Para tipos cuantizados: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)
(I2). rhs tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Ejemplos

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Más ejemplos

batch_norm_grad

Semántica

Calcula los gradientes de varias entradas de batch_norm_training que se propagan hacia atrás desde grad_output y produce los tensores grad_operand, grad_scale y grad_offset. De manera más formal, esta operación se puede expresar como una descomposición de operaciones StableHLO existentes mediante la sintaxis de Python de la siguiente manera:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Para tipos cuantizados, realiza dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de punto flotante o por tensor cuantificado (C1-C3) o C5
(I2). scale Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C2), (C4) y (C5)
(I3). mean Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C2) y (C4)
(I4). variance Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C2) y (C4)
(I5). grad_output tensor de tipo de punto flotante o por tensor cuantificado (C2) y (C3)
(I6). epsilon constante de tipo f32
(I7). feature_index constante de tipo si64 (C1), (C5)

Salidas

Nombre Tipo Restricciones
grad_operand tensor de tipo de punto flotante o por tensor cuantificado (C2) y (C3)
grad_scale Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C2) y (C4)
grad_offset Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C2) y (C4)

Restricciones

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale y grad_offset tienen el mismo baseline_element_type.
  • (C3) operand, grad_output y grad_operand tienen la misma forma.
  • (C4) scale, mean, variance, grad_scale y grad_offset tienen la misma forma.
  • (C5) size(scale) = dim(operand, feature_index).

Ejemplos

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semántica

Normaliza el tensor operand en todas las dimensiones, excepto feature_index, y produce un tensor result. De manera más formal, esta operación se puede expresar como una descomposición de operaciones StableHLO existentes mediante la sintaxis de Python de la siguiente manera:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Para tipos cuantizados, realiza dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de punto flotante o por tensor cuantificado (C1-C7).
(I2). scale Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C2) y (C3)
(I3). offset Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C2) y (C4)
(I4). mean Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C5)
(I5). variance Tensor unidimensional de punto flotante o tipo cuantificado por tensor (C2) y (C6)
(I6). epsilon constante de tipo f32
(I7). feature_index constante de tipo si64 (C1) o (C3-C6)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de punto flotante o por tensor cuantificado (C2) y (C7)

Restricciones

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance y result tienen el mismo baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semántica

Calcula la media y la varianza de todas las dimensiones, excepto la dimensión feature_index, y normaliza el tensor operand, lo que produce los tensores output, batch_mean y batch_var. De manera más formal, esta operación se puede expresar como una descomposición a operaciones StableHLO existentes mediante la sintaxis de Python de la siguiente manera:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Para tipos cuantizados, realiza dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de punto flotante o por tensor cuantificado (C1)
(I2). scale Tensor unidimensional de punto flotante o por tensor cuantificado (C2) y (C3)
(I3). offset Tensor unidimensional de punto flotante o por tensor cuantificado (C2) y (C4)
(I4). epsilon constante de tipo f32 (C1) o (C3-C6)
(I5). feature_index constante de tipo si64 (C1) o (C3-C6)

Salidas

Nombre Tipo Restricciones
output tensor de tipo de punto flotante o por tensor cuantificado (C7)
batch_mean Tensor unidimensional de punto flotante o por tensor cuantificado (C2) y (C5)
batch_var Tensor unidimensional de punto flotante o por tensor cuantificado (C2) y (C6)

Restricciones

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var y output tienen el mismo baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Ejemplos

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semántica

Realiza una operación de transmisión de bits en el tensor operand y produce un tensor result en el que los bits de todo el tensor operand se vuelven a interpretar con el tipo de tensor result.

De manera más formal, con E = element_type(operand), E' = element_type(result) y R = rank(operand):

  • Si es num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Si es num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Si es num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits muestra la representación en la memoria de un valor determinado, y su comportamiento está definido por la implementación porque la representación exacta de los tensores está definida por la implementación, y la representación exacta de los tipos de elementos también está definida por la implementación.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o tensor cuantificado (C1-C2).

Salidas

Nombre Tipo Restricciones
result tensor o tensor cuantificado (C1-C2).

Restricciones

  • (C1) Dados E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) y R = rank(operand):
    • Si es num_bits(E') = num_bits(E), es shape(result) = shape(operand).
    • Si es num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) para todos los 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Si es num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) para todos los 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Si es is_complex(operand) or is_complex(result), entonces is_complex(operand) and is_complex(result).

Ejemplos

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Más ejemplos

broadcast_in_dim

Semántica

Expande las dimensiones o la clasificación de un tensor de entrada mediante la duplicación de los datos en el tensor operand y produce un tensor result. De manera más formal, result[result_index] = operand[operand_index], donde, para todos los d en axes(operand):

  • operand_index[d] = 0 si es dim(operand, d) = 1.
  • De lo contrario, operand_index[d] = result_index[broadcast_dimensions[d]].

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o tensor cuantificado (C1-C2) o (C5-C6)
(I2). broadcast_dimensions Constante tensorial unidimensional de tipo si64 (C2-C6).

Salidas

Nombre Tipo Restricciones
result tensor o tensor cuantificado (C1), (C3) y (C5-C6)

Restricciones

  • (C1) element_type(result) se obtiene de la siguiente manera:
    • element_type(operand), si es !is_per_axis_quantized(operand).
    • element_type(operand), con la excepción de que quantization_dimension(operand), scales(operand) y zero_points(operand) pueden diferir de quantization_dimension(result), scales(result) y zero_points(result) en caso contrario.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Para todos los d en axes(operand):
    • dim(operand, d) = 1 o
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Si es is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Si es dim(operand, quantization_dimension(operand)) = 1, entonces scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Ejemplos

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Más ejemplos

caso

Semántica

Genera el resultado a partir de la ejecución de exactamente una función de branches según el valor de index. De manera más formal, result = selected_branch(), en el que sucede lo siguiente:

  • selected_branch = branches[index] si es 0 <= index < size(branches).
  • De lo contrario, selected_branch = branches[-1].

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). index Tensor de 0 dimensiones de tipo si32
(I2). branches número variable de funciones (C1-C4).

Salidas

Nombre Tipo Restricciones
results cantidad variable de tensores, tensores cuantificados o tokens (C4)

Restricciones

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Ejemplos

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Más ejemplos

CTR

Semántica

Realiza una operación de raíz cúbica a nivel de elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: rootn(x, 3) de IEEE-754.
  • Para números complejos: raíz cúbica compleja.
  • Para tipos cuantizados: dequantize_op_quantize(cbrt, operand, type(result))

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

Más ejemplos

ceil

Semántica

Ejecuta el ceil a nivel de elementos del tensor operand y produce un tensor result. Implementa la operación roundToIntegralTowardPositive de la especificación IEEE-754. Para tipos cuantizados, realiza dequantize_op_quantize(ceil, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de punto flotante o por tensor cuantificado (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de punto flotante o por tensor cuantificado (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Más ejemplos

colectivo

Semántica

Calcula la descomposición de Cholesky de un lote de matrices.

De manera más formal, para todo i en index_space(result), result[i0, ..., iR-3, :, :] es una descomposición de Cholesky de a[i0, ..., iR-3, :, :] en forma de una matriz triangular inferior (si lower es true) o una matriz triangular superior (si lower es false). Los valores de salida del triángulo opuesto, es decir, el triángulo superior estricto o al triángulo inferior estricto, respectivamente, están definidos por la implementación.

Si existe i en la que la matriz de entrada no es una matriz ermitiana definida de forma positiva, entonces el comportamiento es indefinido.

Para tipos cuantizados, realiza dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). a tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1-C3).
(I2). lower Constante tensorial de 0 dimensiones de tipo i1

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Ejemplos

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

restringir

Semántica

Sujeta cada elemento del tensor operand entre un valor mínimo y máximo, y produce un tensor result. De manera más formal, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), donde min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Para tipos cuantizados, realiza dequantize_op_quantize(clamp, min, operand, max, type(result)).

La imposición de un orden en números complejos implica una semántica sorprendente, por lo que en el futuro planeamos quitar la compatibilidad con números complejos para esta operación (#560).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). min tensor o por tensor cuantizado (C1), (C3)
(I2). operand tensor o por tensor cuantizado (C1-C4).
(I3). max tensor o por tensor cuantizado (C2) y (C3)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C4)

Restricciones

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Ejemplos

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Más ejemplos

collective_broadcast

Semántica

Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, envía el valor del tensor operand del proceso de origen a los procesos de destino y produce un tensor result.

La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:

  • cross_replica(replica_groups) si es channel_id <= 0.
  • cross_partition(replica_groups) si es channel_id > 0.

Luego, el valor de result@process se obtiene de la siguiente manera:

  • Es operand@process_groups[i, 0] si existe un i tal que el proceso esté en process_groups[i].
  • broadcast_in_dim(constant(0, element_type(result)), [], type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor (C3)
(I2). replica_groups número variable de constantes de tensor unidimensionales de tipo si64 (C1), (C2)
(I3). channel_id constante de tipo si64

Salidas

Nombre Tipo Restricciones
result tensor (C3)

Restricciones

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, en el que N se define de la siguiente manera:
    • num_replicas si se usa cross_replica.
    • num_partitions si se usa cross_partition.
  • (C3) type(result) = type(operand).

Ejemplos

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semántica

Dentro de cada grupo de procesos en la cuadrícula de procesos StableHLO, envía el valor del tensor operand del proceso de origen al proceso de destino y produce un tensor result.

La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:

  • cross_replica(source_target_pairs) si es channel_id <= 0.
  • cross_partition(source_target_pairs) si es channel_id > 0.

Luego, el valor de result@process se obtiene de la siguiente manera:

  • operand@process_groups[i, 0], si existe un i tal que process_groups[i, 1] = process.
  • broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C5)
(I2). source_target_pairs Constante tensorial bidimensional de tipo si64 (C1-C4).
(I3). channel_id constante de tipo si64

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1)

Restricciones

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, en la que N se define de la siguiente manera:
    • num_replicas si se usa cross_replica.
    • num_partitions si se usa cross_partition.
  • (C5) type(result) = type(operand).

Ejemplos

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Más ejemplos

compare

Semántica

Realiza una comparación con elementos de los tensores lhs y rhs de acuerdo con comparison_direction y compare_type, y produce un tensor result.

Los valores de comparison_direction y compare_type tienen la siguiente semántica:

Para los tipos de elementos booleanos y enteros:

  • EQ: lhs = rhs
  • NE: lhs != rhs
  • GE: lhs >= rhs
  • GT: lhs > rhs
  • LE: lhs <= rhs
  • LT: lhs < rhs

Para los tipos de elementos de punto flotante con compare_type = FLOAT, la op implementa las siguientes operaciones IEEE-754:

  • EQ: compareQuietEqual
  • NE: compareQuietNotEqual
  • GE: compareQuietGreaterEqual
  • GT: compareQuietGreater
  • LE: compareQuietLessEqual
  • LT: compareQuietLess

Para los tipos de elementos de punto flotante con compare_type = TOTALORDER, la operación usa la combinación de operaciones totalOrder y compareQuietEqual de IEEE-754. Esta función parece no estar en uso, por lo que, en el futuro, planeamos quitarla (#584).

Para los tipos de elementos complejos, la comparación lexicográfica de los pares (real, imag) se realiza con los comparison_direction y compare_type proporcionados. La imposición de un orden en números complejos implica una semántica sorprendente, por lo que, en el futuro, planeamos quitar la compatibilidad con números complejos cuando comparison_direction sea GE, GT, LE o LT (#560).

Para tipos cuantizados, se realiza dequantize_compare(lhs, rhs, comparison_direction).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor o por tensor cuantizado (C1-C3).
(I2). rhs tensor o por tensor cuantizado (C1-C2).
(I3). comparison_direction enum de EQ, NE, GE, GT, LE y LT
(I4). compare_type enum de FLOAT, TOTALORDER, SIGNED y UNSIGNED (C3)

Salidas

Nombre Tipo Restricciones
result tensor de tipo booleano (C2)

Restricciones

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type se define de la siguiente manera:
    • SIGNED si es is_signed_integer(element_type(lhs)).
    • UNSIGNED si es is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT o TOTALORDER si es is_float(element_type(lhs)).
    • FLOAT si es is_complex(element_type(lhs)).

Ejemplos

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

Más ejemplos

complejo

Semántica

Realiza la conversión a nivel de los elementos en un valor complejo a partir de un par de valores imaginarios y reales, lhs y rhs, y produce un tensor result.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de tipo f32 o f64 (C1-C3).
(I2). rhs tensor de tipo f32 o f64 (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo complejo (C2) y (C3)

Restricciones

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) tiene el tipo complex<E>, donde E = element_type(lhs).

Ejemplos

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

Más ejemplos

concatenate

Semántica

Concatena inputs a lo largo de la dimensión dimension en el mismo orden que los argumentos dados y produce un tensor result. De manera más formal, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], en el que sucede lo siguiente:

  1. id = d0 + ... + dk-1 + kd.
  2. d es igual a dimension, y d0, ... son los tamaños de da dimensión de inputs.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). inputs cantidad variádica de tensores o tensores cuantificados por tensor (C1-C6).
(I2). dimension constante de tipo si64 (C2), (C4) y (C6)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C5-C6).

Restricciones

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)), excepto dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]), excepto para:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Ejemplos

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Más ejemplos

constante

Semántica

Produce un tensor output a partir de una constante value.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). value constante (C1)

Salidas

Nombre Tipo Restricciones
output tensor o tensor cuantificado (C1)

Restricciones

  • (C1) type(value) = type(output).

Ejemplos

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Más ejemplos

generar una conversión

Semántica

Realiza una conversión a nivel de elementos de un tipo de elemento a otro en el tensor operand y produce un tensor result.

Para las conversiones de boolean-to-any-supported-type, el valor false se convierte en cero, y el valor true se convierte en uno. Para las conversiones any-supported-type-to-boolean, un valor cero se convierte en false y los valores distintos de cero se convierten en true. A continuación, se muestra cómo funciona esto en tipos complejos.

En el caso de las conversiones que incluyen integer-to-integer, integer-to-floating-point o floating-point-to-floating-point, si el valor de origen se puede representar con exactitud en el tipo de destino, el valor del resultado es esa representación exacta. De lo contrario, el comportamiento está por definir (#180).

En el caso de las conversiones que utilizan floating-point-to-integer, la parte fraccionaria se trunca. Si el valor truncado no se puede representar en el tipo de destino, el comportamiento se define por definir (#180).

Las conversiones que implican complejo a complejo siguen el mismo comportamiento de las conversiones de punto flotante a punto flotante para convertir partes reales y imaginarias.

Para las conversiones de complex-to-any-other-type y complex-to-any-other-type, se ignora el valor imaginario de origen o el valor imaginario de destino se pone en cero, respectivamente. La conversión de la parte real sigue a las conversiones de punto flotante.

En principio, esta operación podría expresar descuantización (conversión de tensores cuantificados a tensores regulares), cuantización (conversión de tensores regulares a tensores cuantificados) y recuantización (conversión entre tensores cuantificados), pero, por el momento, tenemos operaciones dedicadas: uniform_dequantize para el primer caso de uso y uniform_quantize para el segundo y el tercero. En el futuro, estas dos operaciones se pueden combinar en convert (#1576).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor (C1)

Restricciones

  • (C1) shape(operand) = shape(result).

Ejemplos

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Más ejemplos

convolución

Semántica

Calcula productos de puntos entre las ventanas de lhs y porciones de rhs y produce result. En el siguiente diagrama, se muestra cómo se calculan los elementos en result a partir de lhs y rhs con un ejemplo concreto.

De manera más formal, considera el siguiente reencuadre de las entradas en términos de lhs para poder expresar ventanas de lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Este reencuadre usa las siguientes funciones auxiliares:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1], donde j[d] = i[permutation[d]].

Si es feature_group_count = 1 y batch_group_count = 1, para todos los output_spatial_index en index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product, en el que sucede lo siguiente:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Parece que esta función no se usa, por lo que planeamos quitarla (#1181) en el futuro.
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Si es feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Si es batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Para tipos cuantizados, se realiza dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor o por tensor cuantizado (C1), (C10-C11), (C14) (C25), (C27-C30)
(I2). rhs tensor o tensor cuantificado (C1), (C14-C16), (C25) y (C27-C32)
(I3). window_strides Constante tensorial unidimensional de tipo si64 (C2-C3) o (C25)
(I4). padding Constante tensorial bidimensional de tipo si64 (C4) o (C25)
(I5). lhs_dilation Constante tensorial unidimensional de tipo si64 (C5-C6) o (C25)
(I6). rhs_dilation Constante tensorial unidimensional de tipo si64 (C7-C8) o (C25)
(I7). window_reversal Constante tensorial unidimensional de tipo i1 (C9)
(I8). input_batch_dimension constante de tipo si64 (C10), (C13) y (C25)
(I9). input_feature_dimension constante de tipo si64 (C11) o (C13-C14)
(I10). input_spatial_dimensions Constante tensorial unidimensional de tipo si64 (C12), (C13) y (C25)
(I11). kernel_input_feature_dimension constante de tipo si64 (C14) o (C18)
(I12). kernel_output_feature_dimension constante de tipo si64 (C15-C16), (C18), (C25), (C32)
(I13). kernel_spatial_dimensions Constante tensorial unidimensional de tipo si64 (C17-C18) o (C25)
(I14). output_batch_dimension constante de tipo si64 (C20) o C25
(I15). output_feature_dimension constante de tipo si64 (C20), (C25) y (C33)
(I16). output_spatial_dimensions Constante tensorial unidimensional de tipo si64 (C19-C20) o (C25)
(I17). feature_group_count constante de tipo si64 (C11), (C14), (C16), (C21), (C23)
(I18). batch_group_count constante de tipo si64 (C10), (C15), (C22), (C23), (C25)
(I19). precision_config cantidad variable de enumeraciones de DEFAULT, HIGH y HIGHEST (C24)

Salidas

Nombre Tipo Restricciones
result tensor o tensor cuantificado (C25-C28), (C30-C31) y (C33)

Restricciones

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Dado input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Dado kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Dado output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) se define de la siguiente manera:
    • dim(lhs, input_batch_dimension) / batch_group_count si es result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) si es result_dim = output_feature_dimension.
    • De lo contrario, num_windows, donde:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Si la operación usa tensores no cuantificados, haz lo siguiente:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Si la operación usa tensores cuantificados, haz lo siguiente:
    • (C28) is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result).
    • (C29) storage_type(lhs) = storage_type(rhs).
    • (C30) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C31) Si es is_per_tensor_quantized(rhs), entonces is_per_tensor_quantized(result).
    • (C32) Si es is_per_axis_quantized(rhs), entonces quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C33) Si es is_per_axis_quantized(result), entonces quantization_dimension(result) = output_feature_dimension.

Ejemplos

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs : [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = dense<4> : tensor<2xi64>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = dense<2> : tensor<2xi64>,
  rhs_dilation = dense<1> : tensor<2xi64>,
  window_reversal = dense<false> : tensor<2xi1>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

coseno

Semántica

Realiza la operación de coseno a nivel de elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: cos de IEEE-754.
  • Para números complejos: coseno complejo.
  • Para tipos cuantizados: dequantize_op_quantize(cosine, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Más ejemplos

count_leading_zeros

Semántica

Realiza un recuento en elementos del número de bits cero iniciales en el tensor operand y produce un tensor result.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de número entero (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de número entero (C1)

Restricciones

  • (C1) type(operand) = type(result).

Ejemplos

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Más ejemplos

custom_call

Semántica

Encapsula una operación call_target_name definida por la implementación que toma inputs y called_computations, y produce results. Se pueden usar has_side_effect, backend_config y api_version para proporcionar metadatos adicionales definidos por la implementación.

Por el momento, esta operación contiene una colección de metadatos bastante desorganizada que refleja la evolución orgánica de su operación equivalente en el compilador XLA. En el futuro, planeamos unificar estos metadatos (#741).

Entradas

Etiqueta Nombre Tipo
(I1). inputs número variable de valores
(I2). call_target_name constante de tipo string
(I3). has_side_effect constante de tipo i1
(I4). backend_config constante de tipo string
(I5). api_version constante de tipo si32
(I6). called_computations número variable de constantes de tipo string

Salidas

Nombre Tipo
results número variable de valores

Ejemplos

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

dividir

Semántica

Realiza la división a nivel de los elementos de los tensores lhs y divisor rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números enteros: división de números enteros que produce el cociente algebraico con cualquier parte fraccionaria descartada.
  • Para números de punto flotante: division de IEEE-754.
  • Para números complejos: división compleja.
  • Para tipos cuantizados:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)
(I2). rhs tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Ejemplos

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Más ejemplos

dot_general

Semántica

Calcula productos de puntos entre porciones de lhs y porciones de rhs, y produce un tensor result.

De manera más formal, result[result_index] = dot_product, donde:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index, donde size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) y size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y)).

Para tipos cuantizados, se realiza dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Esto solo especifica la semántica para la cuantización por tensor. La cuantización por eje está en desarrollo (#1574). Además, en el futuro, es posible que consideremos agregar compatibilidad con la cuantización híbrida (#1575).

precision_config controla la compensación entre velocidad y precisión para los cálculos en backends del acelerador. Puede ser una de las siguientes opciones (por el momento, la semántica de estos valores enum no se especifica de forma adecuada, pero planeamos abordar esto en #755):

  • DEFAULT: Es el cálculo más rápido, pero la aproximación menos precisa al número original.
  • HIGH: Es un cálculo más lento, pero una aproximación más precisa al número original.
  • HIGHEST: Es el cálculo más lento, pero la aproximación más precisa al número original.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor o por tensor cuantizado (C5-C6), (C9-C10) o (C12-C16)
(I2). rhs tensor o por tensor cuantizado (C7-C10) o C12
(I3). lhs_batching_dimensions Constante tensorial unidimensional de tipo si64 (C1), (C3), (C5), (C9) y (C12)
(I4). rhs_batching_dimensions Constante tensorial unidimensional de tipo si64 (C1), (C4), (C7) y (C9)
(I5). lhs_contracting_dimensions Constante tensorial unidimensional de tipo si64 (C2), (C3), (C6) y (C10)
(I6). rhs_contracting_dimensions Constante tensorial unidimensional de tipo si64 (C2), (C4), (C8) y (C10)
(I7). precision_config cantidad variable de enumeraciones de DEFAULT, HIGH y HIGHEST (C11)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C12), (C14) y (C16)

Restricciones

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Si la operación usa tensores no cuantificados, haz lo siguiente:
    • (C13) element_type(lhs) = element_type(rhs).
  • Si la operación usa tensores cuantificados, haz lo siguiente:
    • (C14) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C15) storage_type(lhs) = storage_type(rhs).
    • (C16) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C17) zero_points(rhs) = 0.

Ejemplos

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Más ejemplos

dynamic_slice

Semántica

Extrae una porción de operand mediante índices de inicio calculados de forma dinámica y produce un tensor result. start_indices contiene los índices iniciales de la porción para cada dimensión sujeta a un posible ajuste, y slice_sizes contiene los tamaños de la porción para cada dimensión. De manera más formal, result[result_index] = operand[operand_index], en el que sucede lo siguiente:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1), (C2) y (C4)
(I2). start_indices número variádico de tensores de 0 dimensiones de tipo de número entero (C2) y (C3)
(I3). slice_sizes Constante tensorial unidimensional de tipo si64 (C2), (C4) y (C5)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1), (C5)

Restricciones

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Ejemplos

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Más ejemplos

dynamic_update_slice

Semántica

Produce un tensor result que es igual al tensor operand, excepto que la porción que comienza en start_indices se actualiza con los valores en update. De manera más formal, result[result_index] se define de la siguiente manera:

  • update[update_index] si es 0 <= update_index < shape(update), donde:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • De lo contrario, operand[result_index].

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1-C4) o (C6)
(I2). update tensor o por tensor cuantizado (C2), (C3) y (C6)
(I3). start_indices número variádico de tensores de 0 dimensiones de tipo de número entero (C4), (C5)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1)

Restricciones

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Ejemplos

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Más ejemplos

exponencial

Semántica

Realiza una operación exponencial a nivel de elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: exp de IEEE-754.
  • Para números complejos: exponencial compleja.
  • Para tipos cuantizados: dequantize_op_quantize(exponential, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Más ejemplos

exponential_minus_one

Semántica

Realiza exponenciales a nivel de elementos menos una operación en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: expm1 de IEEE-754.
  • Para números complejos: exponencial compleja menos uno.
  • Para tipos cuantizados: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

Más ejemplos

FFT

Semántica

Realiza las transformaciones inversas y directas de Fourier para entradas y salidas reales y complejas.

fft_type es una de las siguientes opciones:

  • FFT: Reenvía la FFT compleja a compleja.
  • IFFT: FFT de complejo a complejo inverso.
  • RFFT: Reenvía la FFT real a compleja.
  • IRFFT: FFT inverso de real a complejo (es decir, toma complejo, muestra real)

Más formalmente, dada la función fft, que toma tensores unidimensionales de tipos complejos como entrada, produce tensores unidimensionales de los mismos tipos como salida y calcula la transformación discreta de Fourier:

Para fft_type = FFT, result se define como el resultado final de una serie de cálculos L en los que L = size(fft_length). Por ejemplo, para L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Además, dada la función ifft, que tiene la misma firma de tipo y procesa el inverso de fft, sucede lo siguiente:

En fft_type = IFFT, result se define como el inverso de los cálculos para fft_type = FFT. Por ejemplo, para L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Además, dada la función rfft, que toma tensores unidimensionales de tipos de punto flotante, produce tensores unidimensionales de tipos complejos con la misma semántica de punto flotante y funciona de la siguiente manera:

  • rfft(real_operand) = truncated_result, donde
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

(Cuando se calcula la transformación discreta de Fourier para operandos reales, los primeros elementos N/2 + 1 del resultado definen de manera inequívoca el resto del resultado, por lo que el resultado de rfft se trunca para evitar el cálculo de elementos redundantes).

Para fft_type = RFFT, result se define como el resultado final de una serie de cálculos L en los que L = size(fft_length). Por ejemplo, para L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Por último, con la función irfft, que tiene la misma firma de tipo y calcula el inverso de rfft, haz lo siguiente:

En fft_type = IRFFT, result se define como el inverso de los cálculos para fft_type = RFFT. Por ejemplo, para L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o de tipo complejo (C1), (C2), (C4) y (C5)
(I2). fft_type enum de FFT, IFFT, RFFT y IRFFT (C2) y (C5)
(I3). fft_length Constante tensorial unidimensional de tipo si64 (C1), (C3) y (C4)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o de tipo complejo (C2), (C4) y (C5)

Restricciones

  • (C1) size(fft_length) <= rank(operand).
  • (C2) La relación entre los tipos de elementos operand y result varía:
    • Si fft_type = FFT, element_type(operand) y element_type(result) tienen el mismo tipo complejo.
    • Si fft_type = IFFT, element_type(operand) y element_type(result) tienen el mismo tipo complejo.
    • Si es fft_type = RFFT, element_type(operand) es un tipo de punto flotante y element_type(result) es un tipo complejo de la misma semántica de punto flotante.
    • Si es fft_type = IRFFT, element_type(operand) es un tipo complejo y element_type(result) es un tipo de punto flotante de la misma semántica de punto flotante.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Si entre operand y result, hay un tensor real de un tipo de punto flotante, entonces shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand), excepto para:
    • Si es fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Si es fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Ejemplos

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

piso

Semántica

Realiza el piso a nivel de elementos del tensor operand y produce un tensor result. Implementa la operación roundToIntegralTowardNegative de la especificación IEEE-754. Para tipos cuantizados, realiza dequantize_op_quantize(floor, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de punto flotante o por tensor cuantificado (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de punto flotante o por tensor cuantificado (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Más ejemplos

recopilar

Semántica

Recopila segmentos del tensor operand de los desplazamientos especificados en start_indices y produce un tensor result.

En el siguiente diagrama, se muestra cómo se asignan los elementos de result a los elementos de operand con un ejemplo concreto. En el diagrama, se eligen algunos índices result de ejemplo y se explican en detalle a qué índices de operand corresponden.

De manera más formal, result[result_index] = operand[operand_index], donde:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index se define de la siguiente manera:
    • start_indices[bi0, ..., :, ..., biN], en el que bi son elementos individuales en batch_index y : se inserta en el índice index_vector_dim, si index_vector_dim < rank(start_indices).
    • De lo contrario, [start_indices[batch_index]].
  • Para d_operand en axes(operand):
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) si es d_operand = start_index_map[d_start].
    • De lo contrario, full_start_index[d_operand] = 0.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], en el que oi son elementos individuales en offset_index y 0 se inserta en los índices de collapsed_slice_dims.
  • operand_index = full_start_index + full_offset_index.

Si indices_are_sorted es true, la implementación puede suponer que start_indices se ordenan con respecto a start_index_map; de lo contrario, el comportamiento no está definido. De manera más formal, para todos los i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1), (C7), (C10-C12) y (C14)
(I2). start_indices tensor de tipo de número entero (C2), (C3) y (C13)
(I3). offset_dims Constante tensorial unidimensional de tipo si64 (C1), (C4-C5) y (C13)
(I4). collapsed_slice_dims Constante tensorial unidimensional de tipo si64 (C1), (C6-C8) y (C13)
(I5). start_index_map Constante tensorial unidimensional de tipo si64 (C3), (C9) y (C10)
(I6). index_vector_dim constante de tipo si64 (C2), (C3) y (C13)
(I7). slice_sizes Constante tensorial unidimensional de tipo si64 (C8) o (C11-C13)
(I8). indices_are_sorted constante de tipo i1

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C5) o (C13-C14)

Restricciones

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes), donde:
    • batch_dim_sizes = shape(start_indices), excepto que no se incluye el tamaño de la dimensión de start_indices correspondiente a index_vector_dim.
    • offset_dim_sizes = shape(slice_sizes), excepto que no se incluyen los tamaños de dimensión en slice_sizes correspondientes a collapsed_slice_dims.
    • combine coloca batch_dim_sizes en los ejes correspondientes a batch_dims y offset_dim_sizes en los ejes correspondientes a offset_dims.
  • (C14) element_type(operand) = element_type(result).

Ejemplos

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
  indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Más ejemplos

get_dimension_size

Semántica

Produce el tamaño del dimension determinado de operand. De manera más formal, result = dim(operand, dimension).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor (C1)
(I2). dimension constante de tipo si64 (C1)

Salidas

Nombre Tipo
result Tensor de 0 dimensiones de tipo si32

Restricciones

  • (C1) 0 <= dimension < rank(operand).

Ejemplos

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Más ejemplos

get_tuple_element

Semántica

Extrae el elemento en la posición index de la tupla operand y produce un result. Más formalmente, result = operand[index].

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tuple (C1), (C2)
(I2). index constante de tipo si32 (C1), (C2)

Salidas

Nombre Tipo Restricciones
result Cualquier tipo admitido (C2)

Restricciones

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Ejemplos

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

Más ejemplos

if

Semántica

Produce el resultado a partir de la ejecución de exactamente una función de true_branch o false_branch según el valor de pred. Más formalmente, result = pred ? true_branch() : false_branch().

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). pred Tensor de 0 dimensiones de tipo i1
(I2). true_branch la función (C1-C3).
(I3). false_branch la función (C1), (C2)

Salidas

Nombre Tipo Restricciones
results cantidad variable de tensores, tensores cuantificados o tokens (C3)

Restricciones

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Ejemplos

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Más ejemplos

imagen

Semántica

Extrae la parte imaginaria a nivel de elementos de operand y produce un tensor result. De manera más formal, para cada elemento x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o de tipo complejo (C1), (C2)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de punto flotante (C1), (C2)

Restricciones

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) se define de la siguiente manera:
    • complex_element_type(element_type(operand)) si es is_complex(operand).
    • De lo contrario, element_type(operand).

Ejemplos

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Más ejemplos

entrada

Semántica

Lee datos de la entrada y produce results.

La semántica de infeed_config está definida por la implementación.

results consisten en valores de carga útil que van primero y un token que va en último lugar. En el futuro, planeamos dividir la carga útil y el token en dos resultados separados para mejorar la claridad (#670).

Entradas

Etiqueta Nombre Tipo
(I1). token token
(I2). infeed_config constante de tipo string

Salidas

Nombre Tipo Restricciones
results cantidad variable de tensores, tensores cuantificados o tokens (C1-C3).

Restricciones

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) o is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Ejemplos

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Más ejemplos

Iota

Semántica

Rellena un tensor output con valores en orden creciente a partir de cero a lo largo de la dimensión iota_dimension. Más formalmente,

output[result_index] = constant(is_quantized(output) ? quantize(result_index[iota_dimension], element_type(output)) : result_index[iota_dimension], element_type(output)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). iota_dimension si64 (C1)

Salidas

Nombre Tipo Restricciones
output tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) 0 <= iota_dimension < rank(output).

Ejemplos

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Más ejemplos

is_finite

Semántica

Realiza una verificación a nivel de elementos si el valor en x es finito (es decir, no es +Inf, -Inf ni NaN) y produce un tensor y. Implementa la operación isFinite de la especificación IEEE-754. Para los tipos cuantizados, el resultado es siempre true.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). x tensor de tipo de punto flotante o por tensor cuantificado (C1)

Salidas

Nombre Tipo Restricciones
y tensor de tipo booleano (C1)

Restricciones

  • (C1) shape(x) = shape(y).

Ejemplos

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

Más ejemplos

log

Semántica

Realiza una operación de logaritmo a nivel de elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: log de IEEE-754.
  • Para números complejos: logaritmo complejo.
  • Para tipos cuantizados: dequantize_op_quantize(log, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Más ejemplos

log_plus_one

Semántica

Realiza el logaritmo a nivel de elementos más una operación en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: logp1 de IEEE-754.
  • Para números complejos: logaritmo complejo más uno.
  • Para tipos cuantizados: dequantize_op_quantize(log_plus_one, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Más ejemplos

logística

Semántica

Realiza una operación logística a nivel de elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: division(1, addition(1, exp(-x))) de IEEE-754.
  • Para números complejos: logística compleja.
  • Para tipos cuantizados: dequantize_op_quantize(logistic, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Más ejemplos

map

Semántica

Aplica una función de asignación computation a inputs junto al dimensions y produce un tensor result.

Más formalmente, result[result_index] = computation(inputs...[result_index]). Ten en cuenta que, actualmente, no se usan dimensions y es probable que se quiten en el futuro (#487).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). inputs cantidad variádica de tensores o tensores cuantificados por tensor (C1-C4).
(I2). dimensions Constante tensorial unidimensional de tipo si64 (C3)
(I3). computation la función (C4)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1), (C4)

Restricciones

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation tiene el tipo (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, en el que Ei = element_type(inputs[i]) y E' = element_type(result).

Ejemplos

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Más ejemplos

máxima

Semántica

Realiza una operación máxima a nivel de elementos en los tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para valores booleanos: OR lógico.
  • Para números enteros: número entero máximo.
  • Para números de punto flotante: maximum de IEEE-754.
  • Para números complejos: máximo lexicográfico para el par (real, imaginary). La imposición de un orden en números complejos implica una semántica sorprendente, por lo que en el futuro planeamos quitar la compatibilidad con números complejos para esta operación (#560).
  • Para tipos cuantizados:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor o por tensor cuantizado (C1)
(I2). rhs tensor o por tensor cuantizado (C1)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1)

Restricciones

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Ejemplos

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

Más ejemplos

mínima

Semántica

Realiza una operación mínima a nivel de elementos en los tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para valores booleanos: lógico AND.
  • Para números enteros: número entero mínimo.
  • Para números de punto flotante: minimum de IEEE-754.
  • En el caso de números complejos: mínimo lexicográfico para el par (real, imaginary). La imposición de un orden en números complejos implica una semántica sorprendente, por lo que en el futuro planeamos quitar la compatibilidad con números complejos para esta operación (#560).
  • Para tipos cuantizados:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor o por tensor cuantizado (C1)
(I2). rhs tensor o por tensor cuantizado (C1)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1)

Restricciones

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Ejemplos

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Más ejemplos

multiplicar

Semántica

Realiza el producto a nivel de elementos de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para valores booleanos: lógico AND.
  • Para números enteros: multiplicación de números enteros.
  • Para números de punto flotante: multiplication de IEEE-754.
  • Para números complejos: multiplicación compleja.
  • Para tipos cuantizados:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor o por tensor cuantizado (C1)
(I2). rhs tensor o por tensor cuantizado (C1)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

Más ejemplos

negate

Semántica

Realiza una negación a nivel de elementos del tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números enteros con firma: negación de números enteros.
  • Para números enteros sin firma: transmisión de bits a número entero con firma, negación de número entero, conversión de bits a número entero sin firma.
  • Para números de punto flotante: negate de IEEE-754.
  • Para números complejos: negación compleja.
  • Para tipos cuantizados: dequantize_op_quantize(negate, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Más ejemplos

no me encuentro en

Semántica

Realiza NOT a nivel de elementos del tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para valores booleanos: NOT lógico.
  • Para números enteros: NOT a nivel de bits.

Argumentos

Nombre Tipo Restricciones
operand tensor de tipo booleano o número entero (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo booleano o número entero (C1)

Restricciones

  • (C1) type(operand) = type(result).

Ejemplos

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

optimization_barrier

Semántica

Garantiza que las operaciones que producen operand se ejecuten antes que cualquier operación que dependa de result y evita que las transformaciones del compilador muevan las operaciones a través de la barrera. Aparte de eso, la operación es una identidad, es decir, result = operand.

Argumentos

Nombre Tipo Restricciones
operand cantidad variable de tensores, tokens cuantificados por tensor o tokens (C1)

Salidas

Nombre Tipo Restricciones
result cantidad variable de tensores, tokens cuantificados por tensor o tokens (C1)

Restricciones

  • (C1) type(operand...) = type(result...).

Ejemplos

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Más ejemplos

o

Semántica

Realiza OR a nivel de elementos de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para valores booleanos: OR lógico.
  • Para números enteros: OR a nivel de bits.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de tipo entero o booleano (C1)
(I2). rhs tensor de tipo entero o booleano (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo entero o booleano (C1)

Restricciones

  • (C1) type(lhs) = type(rhs) = type(result).

Ejemplos

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

salida

Semántica

Escribe inputs en el feed de salida y produce un token de result.

La semántica de outfeed_config está definida por la implementación.

Entradas

Etiqueta Nombre Tipo
(I1). inputs cantidad variable de tensores o tensores cuantificados
(I2). token token
(I3). outfeed_config constante de tipo string

Salidas

Nombre Tipo
result token

Ejemplos

%result = "stablehlo.outfeed"(%inputs0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Más ejemplos

almohadilla

Semántica

Expande operand mediante el relleno alrededor del tensor y entre los elementos del tensor con el padding_value determinado.

edge_padding_low y edge_padding_high especifican la cantidad de padding agregado en el extremo inferior (junto al índice 0) y en el extremo alto (junto al índice más alto) de cada dimensión, respectivamente. La cantidad de padding puede ser negativa, y el valor absoluto del padding negativo indica la cantidad de elementos que se quitarán de la dimensión especificada.

interior_padding especifica la cantidad de padding agregado entre dos elementos en cada dimensión que no puede ser negativo. El padding interior se produce antes del padding de bordes, de modo que el padding de borde negativo quitará elementos del operando con padding interno.

De manera más formal, result[result_index] se define de la siguiente manera:

  • operand[operand_index] si es result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • De lo contrario, padding_value.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1), (C2) y (C4)
(I2). padding_value Tensor de 0 dimensiones o tensor cuantificado por tensor (C1)
(I3). edge_padding_low Constante tensorial unidimensional de tipo si64 (C1), (C4)
(I4). edge_padding_high Constante tensorial unidimensional de tipo si64 (C1), (C4)
(I5). interior_padding Constante tensorial unidimensional de tipo si64 (C2-C4).

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C3-C6).

Restricciones

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Ejemplos

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
  edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
  interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Más ejemplos

partition_id

Semántica

Se produce el partition_id del proceso actual.

Salidas

Nombre Tipo
result Tensor de 0 dimensiones de tipo ui32

Ejemplos

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

Más ejemplos

popcnt

Semántica

Realiza un recuento a nivel de elementos del número de bits configurado en el tensor operand y produce un tensor result.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de número entero (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de número entero (C1)

Restricciones

  • (C1) type(operand) = type(result).

Ejemplos

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Más ejemplos

potencia

Semántica

Realiza la exponente a nivel de elementos del tensor lhs con el tensor rhs y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números enteros: exponente de números enteros.
  • Para números de punto flotante: pow de IEEE-754.
  • Para números complejos: exponente complejo.
  • Para tipos cuantizados: dequantize_op_quantize(power, lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)
(I2). rhs tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Más ejemplos

real

Semántica

Extrae la parte real a nivel de elementos del operand y produce un tensor result. De manera más formal, para cada elemento x: real(x) = is_complex(x) ? real_part(x) : x.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o de tipo complejo (C1), (C2)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de punto flotante (C1), (C2)

Restricciones

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) se define de la siguiente manera:
    • complex_element_type(element_type(operand)) si es is_complex(operand).
    • De lo contrario, element_type(operand).

Ejemplos

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

Más ejemplos

recibir

Semántica

Recibe datos de un canal con channel_id y produce results.

Si is_host_transfer es true, la operación transfiere datos desde el host. De lo contrario, se transferirán los datos desde otro dispositivo. Esto significa que está definido por la implementación. Esta marca duplica la información proporcionada en channel_type, por lo que, en el futuro, planeamos conservar solo una de ellas (#666).

results consisten en valores de carga útil que van primero y un token que va en último lugar. En el futuro, planeamos dividir la carga útil y el token en dos resultados separados para mejorar la claridad (#670).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). token token (C4)
(I2). channel_id constante de tipo si64
(I3). channel_type enum de DEVICE_TO_DEVICE y HOST_TO_DEVICE (C1)
(I4). is_host_transfer constante de tipo i1 (C1)

Salidas

Nombre Tipo Restricciones
results cantidad variable de tensores, tensores cuantificados o tokens (C2-C4).

Restricciones

  • (C1) channel_type se define de la siguiente manera:
    • HOST_TO_DEVICE si es is_host_transfer = true,
    • De lo contrario, DEVICE_TO_DEVICE.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) o is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Ejemplos

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Más ejemplos

Reducir

Semántica

Aplica una función de reducción body a inputs y init_values junto a dimensions y produce tensores results.

El orden de las reducciones está definido por la implementación, lo que significa que body y init_values deben formar un monoide a fin de garantizar que la operación produzca los mismos resultados para todas las entradas en todas las implementaciones. Sin embargo, esta condición no se aplica a muchas reducciones populares. Por ejemplo, la suma de punto flotante para body y cero para init_values en realidad no forman un monoide porque la suma de punto flotante no es asociativa.

De manera más formal, results...[j0, ..., jR-1] = reduce(input_slices_converted), donde:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], donde : se insertan en dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) para algún árbol binario schedule, en el que:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule es un árbol binario completo definido por la implementación cuyo recorrido en orden consiste en lo siguiente:
    • Valores input_slices_converted...[index], para todos los index en index_space(input_slices_converted) en el orden lexicográfico ascendente de index.
    • Se intercala con una cantidad definida por la implementación de init_values_converted en posiciones definidas por la implementación.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). inputs cantidad variádica de tensores o tensores cuantificados por tensor (C1-C4), (C6) y (C7)
(I2). init_values número variádico de tensores de 0 dimensiones o tensores cuantificados por tensor (C2) y (C3)
(I3). dimensions Constante tensorial unidimensional de tipo si64 (C4), (C5) y (C7)
(I4). body la función (C6)

Salidas

Nombre Tipo Restricciones
results cantidad variádica de tensores o tensores cuantificados por tensor (C3), (C7) y (C8)

Restricciones

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body tiene el tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), en el que is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), excepto que no se incluyen los tamaños de dimensión de inputs... que corresponden a dimensions.
  • (C8) element_type(results[i]) = Ei para todos los i en [0,N).

Ejemplos

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

Más ejemplos

reduce_precision

Semántica

Realiza la conversión a nivel de elementos de operand a otro tipo de punto flotante que usa exponent_bits y mantissa_bits, y de vuelta al tipo de punto flotante original, y produce un tensor output.

Más formalmente:

  • Los bits de mantisa del valor original se actualizan para redondear el valor original al valor más cercano representable con mantissa_bits usando la semántica de roundToIntegralTiesToEven.
  • Entonces, si mantissa_bits son menores que la cantidad de bits mantisa del valor original, los bits mantisas se truncan como mantissa_bits.
  • Luego, si los bits exponentes del resultado intermedio no caben en el rango proporcionado por exponent_bits, el resultado intermedio desborda hasta el infinito con el signo original o se desborda a cero con el signo original.
  • Para tipos cuantizados, se realiza dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de punto flotante o por tensor cuantificado (C1)
(I2). exponent_bits constante de tipo si32 (C2)
(I3). mantissa_bits constante de tipo si32 (C3)

Salidas

Nombre Tipo Restricciones
output tensor de tipo de punto flotante o por tensor cuantificado (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Ejemplos

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Más ejemplos

reduce_scatter

Semántica

Dentro de cada grupo de procesos en la cuadrícula de procesos de StableHLO, realiza la reducción mediante computations sobre los valores del tensor operand de cada proceso, divide el resultado de la reducción junto con scatter_dimension en partes y dispersa las partes divididas entre los procesos para producir el result.

La operación divide la cuadrícula de procesos de StableHLO en process_groups, que se define de la siguiente manera:

  • cross_replica(replica_groups) si es channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) si es channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) si es channel_id > 0 and use_global_device_ids = true.

Luego, dentro de cada process_group, haz lo siguiente:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] para todos los sender en process_group, donde receiver_index = process_group.index(receiver).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1), (C2), (C7) y (C8)
(I2). scatter_dimension constante de tipo si64 (C1), (C2) y (C8)
(I3). replica_groups Constante tensorial bidimensional de tipo si64 (C3-C5).
(I4). channel_id constante de tipo si64 (C6)
(I5). use_global_device_ids constante de tipo i1 (C6)
(I6). computation la función (C7)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C8-C9).

Restricciones

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) se define de la siguiente manera:
    • num_replicas si se usa cross_replica.
    • num_replicas si se usa cross_replica_and_partition.
    • num_processes si se usa flattened_ids.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Si es use_global_device_ids = true, entonces channel_id > 0.
  • (C7) computation tiene el tipo (tensor<E>, tensor<E>) -> (tensor<E>), en el que is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand), excepto:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Ejemplos

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Más ejemplos

reduce_window

Semántica

Aplica una función de reducción body a las ventanas de inputs y init_values, y produce results.

En el siguiente diagrama, se muestra cómo se calculan los elementos en results... a partir de inputs... mediante un ejemplo concreto.

De manera más formal, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (consulta reducir) donde:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). inputs cantidad variádica de tensores o tensores cuantificados por tensor (C1-C4), (C6), (C8), (C10), (C12), (C13) y (C15)
(I2). init_values número variádico de tensores de 0 dimensiones o tensores cuantificados por tensor (C1) o (C13)
(I3). window_dimensions Constante tensorial unidimensional de tipo si64 (C4), (C5) y (C15)
(I4). window_strides Constante tensorial unidimensional de tipo si64 (C6), (C7) y (C15)
(I5). base_dilations Constante tensorial unidimensional de tipo si64 (C8), (C9) y (C15)
(I6). window_dilations Constante tensorial unidimensional de tipo si64 (C10), (C11) y (C15)
(I7). padding Constante tensorial bidimensional de tipo si64 (C12) o (C15)
(I8). body la función (C13)

Salidas

Nombre Tipo Restricciones
results cantidad variádica de tensores o tensores cuantificados por tensor (C1) o (C14-C16)

Restricciones

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body tiene el tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), en el que is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows, donde:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei para todos los i en [0,N).

Ejemplos

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[2, 1]> : tensor<2xi64>,
  window_strides = dense<[4, 1]> : tensor<2xi64>,
  base_dilations = dense<[2, 1]> : tensor<2xi64>,
  window_dilations = dense<[3, 1]> : tensor<2xi64>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

Más ejemplos

resto

Semántica

Realiza el resto a nivel de los elementos de los tensores lhs y divisor rhs, y produce un tensor result.

Más formalmente, el signo del resultado se toma del dividendo, y el valor absoluto del resultado siempre es menor que el valor absoluto del divisor. El resto se calcula como lhs - d * rhs, en el que d se calcula de la siguiente manera:

  • Para números enteros: stablehlo.divide(lhs, rhs).
  • Para números de punto flotante: division(lhs, rhs) de IEEE-754 con el atributo de redondeo roundTowardZero.
  • Para números complejos: Por definir (#997).
  • Para tipos cuantizados:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Para los tipos de elementos de punto flotante, esta operación contrasta con la operación remainder de la especificación IEEE-754, en la que d es un valor integral más cercano al valor exacto de lhs/rhs con vínculos a pares.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)
(I2). rhs tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Más ejemplos

replica_id

Semántica

Se produce el replica_id del proceso actual.

Salidas

Nombre Tipo
result Tensor de 0 dimensiones de tipo ui32

Ejemplos

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

Más ejemplos

cambiar forma

Semántica

Realiza el cambio de forma del tensor operand a uno result. Conceptualmente, significa mantener la misma representación canónica, pero posiblemente cambiar la forma, p.ej., de tensor<2x3xf32> a tensor<3x2xf32> o tensor<6xf32>.

De manera más formal, result[result_index] = operand[operand_index], en el que result_index y operand_index tienen la misma posición en el orden lexicográfico de index_space(result) y index_space(operand).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o tensor cuantificado (C1-C3).

Salidas

Nombre Tipo Restricciones
result tensor o tensor cuantificado (C1-C3).

Restricciones

  • (C1) element_type(result) se obtiene de la siguiente manera:
    • element_type(operand), si es !is_per_axis_quantized(operand).
    • element_type(operand), excepto que quantization_dimension(operand) y quantization_dimension(result) pueden diferir.
  • (C2) size(operand) = size(result).
  • (C3) Si es is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Ejemplos

// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

Más ejemplos

reverse

Semántica

Invierte el orden de los elementos en operand a lo largo del dimensions especificado y produce un tensor result. De manera más formal, result[result_index] = operand[operand_index], en el que sucede lo siguiente:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 si es d en dimensions.
  • De lo contrario, operand_index[d] = result_index[d].

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1), (C3)
(I2). dimensions Constante tensorial unidimensional de tipo si64 (C2) y (C3)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1), (C3)

Restricciones

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Ejemplos

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

Más ejemplos

rng

Semántica

Genera números aleatorios mediante el algoritmo rng_distribution y produce un tensor result de una forma shape determinada.

Si es rng_distribution = UNIFORM, los números aleatorios se generan siguiendo la distribución uniforme en el intervalo [a, b). Si es a >= b, el comportamiento no está definido.

Si es rng_distribution = NORMAL, los números aleatorios se generan según la distribución normal con una media = a y una desviación estándar = b. Si es b < 0, el comportamiento no está definido.

La implementación define exactamente cómo se generan los números aleatorios. Por ejemplo, pueden o no ser deterministas y pueden o no usar el estado oculto.

En conversaciones con muchas partes interesadas, esta operación dejó de estar disponible, por lo que en el futuro planeamos quitarla (#597).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). a Tensor de 0 dimensiones de tipo entero, booleano o punto flotante (C1), (C2)
(I2). b Tensor de 0 dimensiones de tipo entero, booleano o punto flotante (C1), (C2)
(I3). shape Constante tensorial unidimensional de tipo si64 (C3)
(I4). rng_distribution enum de UNIFORM y NORMAL (C2)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de número entero, booleano o punto flotante (C1-C3).

Restricciones

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Si es rng_distribution = NORMAL, entonces is_float(a).
  • (C3) shape(result) = shape.

Ejemplos

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semántica

Muestra un output con bits aleatorios uniformes y un estado de salida actualizado output_state mediante el algoritmo generador de números seudoaleatorio rng_algorithm, con un estado inicial initial_state. Se garantiza que el resultado sea una función determinista de initial_state, pero no que sea determinista entre implementaciones.

rng_algorithm es una de las siguientes opciones:

  • DEFAULT: Es un algoritmo definido por la implementación.
  • THREE_FRY: Variante definida por la implementación del algoritmo de Threefry.*
  • PHILOX: Variante definida por la implementación del algoritmo Philox.*

* Consultar: Salmon et al. SC 2011. Números aleatorios paralelos: tan sencillos como 1, 2, 3

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). rng_algorithm Enum de DEFAULT, THREE_FRY y PHILOX (C2)
(I2). initial_state Tensor unidimensional de tipo ui64 (C1), (C2)

Salidas

Nombre Tipo Restricciones
output_state Tensor unidimensional de tipo ui64 (C1)
output tensor de tipo de número entero o punto flotante

Restricciones

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) se define de la siguiente manera:
    • definida por la implementación si es rng_algorithm = DEFAULT.
    • 2 si es rng_algorithm = THREE_FRY.
    • 2 o 3 si es rng_algorithm = PHILOX.

Ejemplos

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semántica

Realiza un redondeo a nivel de elementos hacia el número entero más cercano, rompiendo los empates de cero en el tensor operand y produce un tensor result. Implementa la operación roundToIntegralTiesToAway de la especificación IEEE-754. Para tipos cuantizados, realiza dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de punto flotante o por tensor cuantificado (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de punto flotante o por tensor cuantificado (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Más ejemplos

round_nearest_even

Semántica

Realiza un redondeo a nivel de elementos hacia el número entero más cercano, rompiendo los empates hacia el número entero par en el tensor operand y produce un tensor result. Implementa la operación roundToIntegralTiesToEven de la especificación IEEE-754. Para tipos cuantizados, realiza dequantize_op_quantize(round_nearest_even, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de tipo de punto flotante o por tensor cuantificado (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de punto flotante o por tensor cuantificado (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Más ejemplos

RQR

Semántica

Realiza una operación de raíz cuadrada recíproca a nivel de los elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: rSqrt de IEEE-754.
  • Para números complejos: raíz cuadrada recíproca compleja.
  • Para tipos cuantizados: dequantize_op_quantize(rsqrt, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Más ejemplos

scatter

Semántica

Produce tensores results que son iguales a inputs, excepto que varias porciones especificadas por scatter_indices se actualizan con los valores updates mediante update_computation.

En el siguiente diagrama, se muestra cómo se asignan los elementos de updates... a los elementos de results... con un ejemplo concreto. En el diagrama, se eligen algunos índices updates... de ejemplo y se explican en detalle a qué índices results... corresponden.

De manera más formal, para todos los update_index en index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index se define de la siguiente manera:
    • scatter_indices[si0, ..., :, ..., siN], en el que si son elementos individuales en update_scatter_index y : se inserta en el índice index_vector_dim, si index_vector_dim < rank(scatter_indices).
    • De lo contrario, [scatter_indices[update_scatter_index]].
  • Para d_input en axes(inputs[0]):
    • full_start_index[d_input] = start_index[d_start] si es d_input = scatter_dims_to_operand_dims[d_start].
    • De lo contrario, full_start_index[d_input] = 0.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], en el que wi son elementos individuales en update_window_index y 0 se inserta en los índices de inserted_window_dims.
  • result_index = full_start_index + full_window_index.

Por lo tanto, results = exec(schedule, inputs), donde:

  • schedule es una permutación definida por la implementación de index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results), donde:
    • Si result_index está dentro de los límites de shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results es una copia de results con results...[result_index] establecido en updated_values....
    • De lo contrario
    • updated_results = results.
  • exec([], results) = results.

Si indices_are_sorted es true, la implementación puede suponer que scatter_indices se ordenan con respecto a scatter_dims_to_operand_dims; de lo contrario, el comportamiento no está definido. De manera más formal, para todos los i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Si unique_indices es true, la implementación puede suponer que todos los índices result_index que se dispersan son únicos. Si unique_indices es true, pero los índices que se dispersan no son únicos, el comportamiento no está definido.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). inputs cantidad variádica de tensores o tensores cuantificados por tensor (C1), (C2), (C4-C6), (C10), (C13), (C15-C16)
(I2). scatter_indices tensor de tipo de número entero (C4), (C11) y (C14)
(I3). updates cantidad variádica de tensores o tensores cuantificados por tensor (C3-C6) o (C8)
(I4). update_window_dims Constante tensorial unidimensional de tipo si64 (C2), (C4), (C7) y (C8)
(I5). inserted_window_dims Constante tensorial unidimensional de tipo si64 (C2), (C4), (C9) y (C10)
(I6). scatter_dims_to_operand_dims Constante tensorial unidimensional de tipo si64 (C11-C13).
(I7). index_vector_dim constante de tipo si64 (C4), (C11) y (C14)
(I8). indices_are_sorted constante de tipo i1
(I9). unique_indices constante de tipo i1
(I10). update_computation la función (C15)

Salidas

Nombre Tipo Restricciones
results cantidad variádica de tensores o tensores cuantificados por tensor (C15-C17).

Restricciones

  • (C1) same(shape(inputs...)).
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes), donde:
    • update_scatter_dim_sizes = shape(scatter_indices), excepto que no se incluye el tamaño de la dimensión de scatter_indices correspondiente a index_vector_dim.
    • update_window_dim_sizes <= shape(inputs[0]), excepto que no se incluyen los tamaños de dimensión en inputs[0] correspondientes a inserted_window_dims.
    • combine coloca update_scatter_dim_sizes en los ejes correspondientes a update_scatter_dims y update_window_dim_sizes en los ejes correspondientes a update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(inserted_window_dims) and is_sorted(update_window_dims).
  • (C10) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C11) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C12) is_unique(scatter_dims_to_operand_dims).
  • (C13) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C14) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C15) update_computation tiene el tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), en el que is_promotable(element_type(inputs[i]), Ei).
  • (C16) shape(inputs...) = shape(results...).
  • (C17) element_type(results[i]) = Ei para todos los i en [0,N).

Ejemplos

// %input: [
//          [[1, 2], [3, 4], [5, 6], [7, 8]],
//          [[9, 10], [11, 12], [13, 14], [15, 16]],
//          [[17, 18], [19, 20], [21, 22], [23, 24]]
//         ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [2, 3],
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
//           [[1, 2], [5, 6], [7, 8], [7, 8]],
//           [[10, 11], [12, 13], [14, 15], [16, 17]],
//           [[18, 19], [20, 21], [21, 22], [23, 24]]
//          ]

Más ejemplos

select

Semántica

Produce un tensor result en el que cada elemento se selecciona del tensor on_true o on_false según el valor del elemento correspondiente de pred. De manera más formal, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], donde pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Para tipos cuantizados, realiza dequantize_select_quantize(pred, on_true, on_false, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). pred tensor de tipo i1 (C1)
(I2). on_true tensor o por tensor cuantizado (C1-C2).
(I3). on_false tensor o por tensor cuantizado (C2)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C2)

Restricciones

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Ejemplos

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

Más ejemplos

select_and_scatter

Semántica

Pasa los valores del tensor source mediante scatter en función del resultado de reduce_window del tensor input con select y produce un tensor result.

En el siguiente diagrama, se muestra cómo se calculan los elementos en result a partir de operand y source con un ejemplo concreto.

Más formalmente:

  • selected_values = reduce_window_without_init(...) por las siguientes entradas:

    • `inputs = [operando].
    • window_dimensions, window_strides y padding, que se usan tal como están.
    • base_dilations = windows_dilations = 1.
    • body se define de la siguiente manera:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    donde E = element_type(operand) y reduce_window_without_init funcionan exactamente como reduce_window, excepto que el schedule del reduce subyacente (consulta reducir) no incluye valores de inicio. Por el momento, no se especifica qué sucede si la ventana correspondiente no tiene valores (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter), en el que:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index si selected_values[source_index] tiene el elemento operand de operand_index
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1-C4), (C6) y (C8-C11)
(I2). source tensor o por tensor cuantizado (C1), (C2)
(I3). init_value Tensor de 0 dimensiones o tensor cuantificado por tensor (C3)
(I4). window_dimensions Constante tensorial unidimensional de tipo si64 (C2), (C4) y (C5)
(I5). window_strides Constante tensorial unidimensional de tipo si64 (C2), (C6) y (C7)
(I6). padding Constante tensorial bidimensional de tipo si64 (C2) y (C8)
(I7). select la función (C9)
(I8). scatter la función (C10)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C11-C12).

Restricciones

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows, donde:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select tiene el tipo (tensor<E>, tensor<E>) -> tensor<i1>, en el que E = element_type(operand).
  • (C10) scatter tiene el tipo (tensor<E>, tensor<E>) -> tensor<E>, donde is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Ejemplos

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Más ejemplos

enviar

Semántica

Envía inputs al canal channel_id y produce un token result.

Si is_host_transfer es true, la operación transfiere datos al host. De lo contrario, transferirá los datos a otro dispositivo. Esto significa que está definido por la implementación. Esta marca duplica la información proporcionada en channel_type, por lo que, en el futuro, planeamos conservar solo una de ellas (#666).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). inputs cantidad variable de tensores o tensores cuantificados
(I2). token token
(I3). channel_id constante de tipo si64
(I4). channel_type enum de DEVICE_TO_DEVICE y DEVICE_TO_HOST (C1)
(I5). is_host_transfer constante de tipo i1 (C1)

Salidas

Nombre Tipo
result token

Restricciones

  • (C1) channel_type se define de la siguiente manera:
    • DEVICE_TO_HOST si es is_host_transfer = true,
    • De lo contrario, DEVICE_TO_DEVICE.

Ejemplos

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

Más ejemplos

shift_left

Semántica

Realiza una operación de desplazamiento a la izquierda a nivel de elementos en el tensor lhs según la cantidad de bits rhs y produce un tensor result.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de tipo de número entero (C1)
(I2). rhs tensor de tipo de número entero (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de número entero (C1)

Restricciones

  • (C1) type(lhs) = type(rhs) = type(result).

Ejemplos

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

Más ejemplos

shift_right_arithmetic

Semántica

Realiza una operación aritmética de desplazamiento a la derecha a nivel de los elementos en el tensor lhs según la cantidad de bits rhs y produce un tensor result.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de tipo de número entero (C1)
(I2). rhs tensor de tipo de número entero (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de número entero (C1)

Restricciones

  • (C1) type(lhs) = type(rhs) = type(result).

Ejemplos

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

Más ejemplos

shift_right_logical

Semántica

Realiza una operación lógica de desplazamiento a la derecha a nivel de elementos en el tensor lhs según la cantidad de bits de rhs y produce un tensor result.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de tipo de número entero (C1)
(I2). rhs tensor de tipo de número entero (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de número entero (C1)

Restricciones

  • (C1) type(lhs) = type(rhs) = type(result).

Ejemplos

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Más ejemplos

igual.

Semántica

Muestra el signo del operand a nivel de elementos y produce un tensor result. De manera más formal, para cada elemento x, la semántica se puede expresar mediante la sintaxis de Python de la siguiente manera:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Para tipos cuantizados, realiza dequantize_op_quantize(sign, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de número entero con signo, punto flotante o tipo complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de número entero con signo, punto flotante o tipo complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Más ejemplos

seno

Semántica

Realiza una operación de seno a nivel de elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: sin de IEEE-754.
  • Para números complejos: seno complejo.
  • Para tipos cuantizados: dequantize_op_quantize(sine, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

Más ejemplos

slice

Semántica

Extrae una porción de operand mediante índices de inicio calculados de forma estática y produce un tensor result. start_indices contiene los índices iniciales de la porción para cada dimensión, limit_indices contiene los índices finales (exclusivos) de la porción para cada dimensión y strides contiene los segmentos de cada dimensión.

De manera más formal, result[result_index] = operand[operand_index], donde operand_index = start_indices + result_index * strides.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o por tensor cuantizado (C1-C3) o C5
(I2). start_indices Constante tensorial unidimensional de tipo si64 (C2), (C3) y (C5)
(I3). limit_indices Constante tensorial unidimensional de tipo si64 (C2), (C3) y (C5)
(I4). strides Constante tensorial unidimensional de tipo si64 (C2) y (C4)

Salidas

Nombre Tipo Restricciones
result tensor o por tensor cuantizado (C1), (C5)

Restricciones

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Ejemplos

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = dense<[1, 2]> : tensor<2xi64>,
  limit_indices = dense<[3, 4]> : tensor<2xi64>,
  strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Más ejemplos

sort

Semántica

Ordena porciones unidimensionales de inputs a lo largo de la dimensión dimension juntas, de acuerdo con comparator, y produce results.

A diferencia de entradas similares en otras operaciones, dimension permite valores negativos, con la semántica que se describe a continuación. En el futuro, es posible que esto se inhabilite por motivos de coherencia (#1377).

Si is_stable es verdadero, el orden es estable, es decir, se conserva el orden relativo de los elementos que el comparador considera iguales. Para el caso en el que hay una sola entrada, el comparador considera que dos elementos e1 y e2 son iguales solo si comparator(e1, e2) = comparator(e2, e1) = false. Consulta la formalización a continuación para ver cómo esto se generaliza a varias entradas.

De manera más formal, para todos los result_index en index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1], en el que riN son elementos individuales en result_index y : se inserta en adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • donde sort ordena una porción unidimensional en orden no descendente y espera que comparator_together muestre true si el argumento del lado izquierdo es menor que el segundo argumento de la derecha.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). inputs cantidad variádica de tensores o tensores cuantificados por tensor (C1-C5).
(I2). dimension constante de tipo si64 (C4)
(I3). is_stable constante de tipo i1
(I4). comparator la función (C5)

Salidas

Nombre Tipo Restricciones
results cantidad variádica de tensores o tensores cuantificados por tensor (C2) y (C3)

Restricciones

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, donde R = rank(inputs[0]).
  • (C5) comparator tiene el tipo (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, en el que Ei = element_type(inputs[i]).

Ejemplos

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Más ejemplos

sqrt

Semántica

Realiza una operación de raíz cuadrada a nivel de elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: squareRoot de IEEE-754.
  • Para números complejos: raíz cuadrada compleja.
  • Para tipos cuantizados: dequantize_op_quantize(sqrt, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Más ejemplos

subtract

Semántica

Realiza la resta a nivel de elementos de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números enteros: resta de números enteros.
  • Para números de punto flotante: subtraction de IEEE-754.
  • Para números complejos: resta compleja.
  • Para tipos cuantizados:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)
(I2). rhs tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de número entero, punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Ejemplos

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Más ejemplos

tanh

Semántica

Realiza la operación tangente hiperbólica a nivel de los elementos en el tensor operand y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para números de punto flotante: tanh de IEEE-754.
  • Para números complejos: tangente hiperbólica compleja.
  • Para tipos cuantizados:
    • dequantize_op_quantize(tanh, operand, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_type(operand) = baseline_type(result).

Ejemplos

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

Más ejemplos

transponer

Semántica

Permuta las dimensiones del tensor operand con permutation y produce un tensor result. De manera más formal, result[result_index] = operand[operand_index], donde result_index[d] = operand_index[permutation[d]].

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor o tensor cuantificado (C1-C4).
(I2). permutation Constante tensorial unidimensional de tipo si64 (C2-C4).

Salidas

Nombre Tipo Restricciones
result tensor o tensor cuantificado (C1) o (C3-C4)

Restricciones

  • (C1) element_type(result) se obtiene de la siguiente manera:
    • element_type(operand), si es !is_per_axis_quantized(operand).
    • element_type(operand), excepto que quantization_dimension(operand) y quantization_dimension(result) pueden diferir.
  • (C2) permutation es una permutación de range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Si es is_per_axis_quantized(result), entonces quantization_dimension(operand) = permutation(quantization_dimension(result)).

Ejemplos

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Más ejemplos

triangular_solve

Semántica

Resuelve lotes de sistemas de ecuaciones lineales con matrices de coeficientes triangulares inferior o superior.

De manera más formal, con a y b, result[i0, ..., iR-3, :, :] es la solución para op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] cuando left_side es true o x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] cuando left_side es false, lo que resuelve la variable x donde op(a) está determinado por transpose_a, que puede ser una de las siguientes opciones:

  • NO_TRANSPOSE: Realiza la operación con a tal como está.
  • TRANSPOSE: realiza una operación de transposición de a.
  • ADJOINT: Realiza una operación sobre la transposición conjugada de a.

Los datos de entrada solo se leen desde el triángulo inferior de a, si lower es true o, de lo contrario, el triángulo superior de a. Los datos de salida se muestran en el mismo triángulo; los valores en el otro triángulo están definidos por la implementación.

Si el valor de unit_diagonal es verdadero, la implementación puede suponer que los elementos diagonales de a son iguales a 1. De lo contrario, el comportamiento no está definido.

Para tipos cuantizados, realiza dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). a tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1-C3).
(I2). b tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1-C4).
(I3). left_side constante de tipo i1 (C3)
(I4). lower constante de tipo i1
(I5). unit_diagonal constante de tipo i1
(I6). transpose_a Enum de NO_TRANSPOSE, TRANSPOSE y ADJOINT

Salidas

Nombre Tipo Restricciones
result tensor de punto flotante o complejo, o tensor cuantificado por tensor (C1)

Restricciones

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) La relación entre shape(a) y shape(b) se define de la siguiente manera:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Ejemplos

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Semántica

Produce una tupla result a partir de los valores val.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). val número variable de valores (C1)

Salidas

Nombre Tipo Restricciones
result tuple (C1)

Restricciones

  • (C1) result tiene el tipo tuple<E0, ..., EN-1>, donde Ei = type(val[i]).

Ejemplos

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Más ejemplos

uniform_dequantize

Semántica

Realiza la conversión a nivel de elementos del tensor cuantificado operand en un tensor de punto flotante result de acuerdo con los parámetros de cuantización definidos por el tipo operand.

Más formalmente, result = dequantize(operand).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor cuantizado (C1), (C2)

Salidas

Nombre Tipo Restricciones
result tensor de tipo de punto flotante (C1), (C2)

Restricciones

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Ejemplos

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Semántica

Realiza la conversión a nivel de elementos del tensor de punto flotante o del tensor cuantificado operand en un tensor cuantificado result de acuerdo con los parámetros de cuantización definidos por el tipo result.

Más formalmente,

  • Si es is_float(operand):
    • result = quantize(operand, type(result)).
  • Si es is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand tensor de punto flotante o tipo cuantizado (C1), (C2)

Salidas

Nombre Tipo Restricciones
result tensor cuantizado (C1), (C2)

Restricciones

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Ejemplos

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

mientras

Semántica

Produce el resultado de la ejecución de la función body 0 o más veces, mientras que la función cond genera true. De manera más formal, la semántica se puede expresar mediante la sintaxis de Python de la siguiente manera:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

Por definir el comportamiento de un bucle infinito (#383).

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). operand cantidad variable de tensores, tensores cuantificados o tokens (C1-C3).
(I2). cond la función (C1)
(I3). body la función (C2)

Salidas

Nombre Tipo Restricciones
results cantidad variable de tensores, tensores cuantificados o tokens (C3)

Restricciones

  • (C1) cond tiene el tipo (T0, ..., TN-1) -> tensor<i1>, en el que Ti = type(operand[i]).
  • (C2) body tiene el tipo (T0, ..., TN-1) -> (T0, ..., TN-1), en el que Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Ejemplos

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Más ejemplos

xor

Semántica

Realiza XOR a nivel de elementos de dos tensores lhs y rhs, y produce un tensor result. Según el tipo de elemento, hace lo siguiente:

  • Para valores booleanos: XOR lógico.
  • Para números enteros: XOR a nivel de bits.

Entradas

Etiqueta Nombre Tipo Restricciones
(I1). lhs tensor de tipo booleano o número entero (C1)
(I2). rhs tensor de tipo booleano o número entero (C1)

Salidas

Nombre Tipo Restricciones
result tensor de tipo booleano o número entero (C1)

Restricciones

  • (C1) type(lhs) = type(rhs) = type(result).

Ejemplos

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Ejecución

Ejecución secuencial

Para ejecutar un programa StableHLO, se proporcionan valores de entrada a la función main y se calculan los valores de salida. Los valores de salida de una función se calculan mediante la ejecución del grafo de ops con raíces en la op return correspondiente.

El orden de ejecución se define por la implementación, siempre y cuando esté alineado con el flujo de datos, es decir, si las operaciones se ejecutan antes de sus usos. En StableHLO, todas las operaciones con efectos secundarios consumen un token y producen uno (varios tokens se pueden multiplexar en uno a través de after_all), por lo que el orden de ejecución de los efectos secundarios también se alinea con el flujo de datos. Los posibles órdenes de ejecución del programa de ejemplo anterior son %0%1%2%3%4return o %3%0%1%2%4return.

De manera más formal, un proceso EstableHLO es una combinación de: 1) un programa estable, 2) estados de operación (aún no ejecutado, ya ejecutado) y 3) valores intermedios en los que funciona el proceso. El proceso comienza con los valores de entrada en la función main, avanza por el grafo de operaciones que actualizan los estados de las operaciones y los valores intermedios, y finaliza con los valores de salida. Aún no se formaliza el proceso (#484).

Ejecución paralela

Los programas StableHLO se pueden ejecutar en paralelo y se organizan en una cuadrícula de procesos en 2D de num_replicas por num_partitions, que ambos tienen el tipo ui32.

En la cuadrícula de procesos de StableHLO, num_replicas * num_partitions de los procesos de StableHLO se ejecutan al mismo tiempo. Cada proceso tiene un process_id = (replica_id, partition_id) único, en el que replica_id en replica_ids = range(num_replicas) y partition_id en partition_ids = range(num_partitions), ambos tienen el tipo ui32.

El tamaño de la cuadrícula de procesos se conoce estáticamente para cada programa (en el futuro, planeamos hacerlo una parte explícita de los programas StableHLO #650), y la posición dentro de la cuadrícula de procesos se conoce estáticamente para cada proceso. Cada proceso tiene acceso a su posición dentro de la cuadrícula de procesos mediante las operaciones replica_id y partition_id.

Dentro de la cuadrícula de procesos, todos los programas pueden ser iguales (en el estilo "Programa único, varios datos"), pueden ser todos diferentes (en el estilo "Varios programas, varios datos") o algo intermedio. En el futuro, planeamos agregar compatibilidad con otros modismos de definición de programas StableHLO paralelos, incluido GSPMD (#619).

Dentro de la cuadrícula de procesos, los procesos son mayormente independientes entre sí: tienen estados de operación separados, valores de entrada/intermedio/salida independientes y la mayoría de las operaciones se ejecutan por separado entre procesos, a excepción de una pequeña cantidad de operaciones colectivas que se describen a continuación.

Debido a que la ejecución de la mayoría de las operaciones solo usa valores del mismo proceso, no suele ser ambiguo hacer referencia a estos valores por sus nombres. Sin embargo, cuando se describe la semántica de ops colectivas, esto no es suficiente y da lugar a la notación name@process_id para hacer referencia al valor name dentro de un proceso en particular. (Desde esa perspectiva, se puede ver name no calificado como una abreviatura de name@(replica_id(), partition_id())).

El orden de ejecución entre los procesos está definido por la implementación, excepto por la sincronización introducida por la comunicación punto a punto y las operaciones colectivas, como se describe a continuación.

Comunicación punto a punto

Los procesos StableHLO pueden comunicarse entre sí a través de canales estables. Un canal se representa con un ID positivo del tipo si64. A través de varias operaciones, es posible enviar valores a los canales y recibirlos de ellos.

Aún no se requiere más formalización, p.ej., de dónde provienen estos IDs de canal, cómo los procesos los programas toman conocimiento de ellos y qué tipo de sincronización introducen (#484).

Comunicación en vivo

Cada proceso de StableHLO tiene acceso a dos interfaces de transmisión:

  • Entrada que se puede leer.
  • Salida en la que se pueden escribir.

A diferencia de los canales, que se usan para la comunicación entre procesos y, por lo tanto, tienen procesos en ambos extremos, las entradas y salidas tienen definida la otra implementación final.

Aún no se define más formalización, p.ej., cómo la comunicación de transmisión influye en el orden de ejecución y qué tipo de sincronización introduce (#484).

Operaciones colectivas

Hay seis operaciones colectivas en StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute y reduce_scatter. Todas estas operaciones dividen los procesos en la cuadrícula de procesos de StableHLO en grupos de procesos de StableHLO y ejecutan un cálculo conjunto dentro de cada grupo de procesos, independientemente de otros grupos de procesos.

Dentro de cada grupo de procesos, las operaciones colectivas pueden introducir una barrera de sincronización. No se requiere más formalización, p.ej., explicar cuándo exactamente se produce esta sincronización, cómo llegan exactamente los procesos a esta barrera y qué sucede si no lo hacen (#484).

Si el grupo de procesos involucra una comunicación de partición cruzada, es decir, hay procesos en el grupo de procesos cuyos ID de partición son diferentes, la ejecución de la operación colectiva necesita un canal, y la operación colectiva debe proporcionar un channel_id positivo de tipo si64. La comunicación entre réplicas no necesita canales.

Los cálculos que realizan las operaciones colectivas son específicos de cada operación individual y se describen en las secciones anteriores de las operaciones individuales. Sin embargo, las estrategias mediante las cuales la cuadrícula de procesos se divide en grupos de procesos se comparten entre estas operaciones y se describen en esta sección. De manera más formal, StableHLO admite las siguientes cuatro estrategias.

cross_replica

Solo las comunicaciones entre réplicas ocurren dentro de cada grupo de procesos. Esta estrategia toma replica_groups, una lista de listas de IDs de réplica, y calcula un producto cartesiano de replica_groups según partition_ids. replica_groups debe tener elementos únicos y abarcar todos los replica_ids. De manera más formal, con la sintaxis de Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Por ejemplo, para replica_groups = [[0, 1], [2, 3]] y num_partitions = 2, cross_replica producirá [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Solo las comunicaciones de particiones cruzadas ocurren dentro de cada grupo de procesos. Esta estrategia toma partition_groups (una lista de listas de ID de partición) y calcula un producto cartesiano de partition_groups por replica_ids. partition_groups debe tener elementos únicos y abarcar toda la partition_ids. De manera más formal, con la sintaxis de Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Por ejemplo, para partition_groups = [[0, 1]] y num_replicas = 4, cross_partition producirá [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

Las comunicaciones entre réplicas y particiones pueden ocurrir dentro de cada grupo de procesos. Esta estrategia toma replica_groups, una lista de listas de IDs de réplica, y calcula los productos cartesianos de cada replica_group por partition_ids. replica_groups debe tener elementos únicos y abarcar todas las replica_ids. De manera más formal, con la sintaxis de Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Por ejemplo, para replica_groups = [[0, 1], [2, 3]] y num_partitions = 2, cross_replica_and_partition producirá [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Esta estrategia toma flattened_id_groups, una lista de IDs de procesos "compactos" en forma de replica_id * num_partitions + partition_id, y los convierte en IDs de proceso. flattened_id_groups debe tener elementos únicos y abarcar todos los process_ids. De manera más formal, con la sintaxis de Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Por ejemplo, para flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 y num_partitions = 2, flattened_ids producirá [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Exactitud

Por el momento, StableHLO no proporciona garantías sobre la exactitud numérica, pero esto puede cambiar en el futuro (#1156).

Errores

Los programas StableHLO se validan a través de un amplio conjunto de restricciones para operaciones individuales, que descarta muchas clases de errores antes del tiempo de ejecución. Sin embargo, las condiciones de error aún son posibles, p.ej., a través de desbordamientos de números enteros, accesos fuera de los límites, etc. A menos que se llamen explícitamente, todos estos errores generan un comportamiento definido por la implementación, pero esto puede cambiar en el futuro (#1157).

Como excepción a esta regla, las excepciones de punto flotante en los programas StableHLO tienen un comportamiento bien definido. Las operaciones que generan excepciones definidas por el estándar IEEE-754 (operación no válida, división por cero, desbordamiento, subdesbordamiento o excepciones inexactas) producen resultados predeterminados (como se define en el estándar) y continúan su ejecución sin elevar la marca de estado correspondiente; similar al control de excepciones raiseNoFlag del estándar. La implementación define las excepciones para las operaciones no estándar (p.ej., las aritméticas complejas y ciertas funciones trascendentales).

Notation

Para describir la sintaxis, en este documento se usa la variante ISO modificada de la sintaxis EBNF (ISO/IEC 14977:1996, Wikipedia), con dos modificaciones: 1) las reglas se definen con ::= en lugar de =;

2) la concatenación se expresa mediante la yuxtaposición en lugar de ,.

Para describir la semántica (es decir, dentro de las secciones “Tipos”, “Constantes” y “Operaciones”), usamos fórmulas que se basan en la sintaxis de Python, ampliada con compatibilidad para expresar operaciones de arreglo de forma concisa, como se describe a continuación. Esto funciona bien para pequeños fragmentos de código, pero en casos excepcionales, cuando se necesitan fragmentos de código más grandes, usamos la sintaxis normal de Python, que siempre se presenta de manera explícita.

Fórmulas

Exploremos cómo funcionan las fórmulas con un ejemplo de la especificación dot_general. Una de las restricciones de esta operación se ve de la siguiente manera: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Los nombres usados en esta fórmula provienen de dos fuentes: 1) funciones globales, es decir, dim, y 2) definiciones de miembros del elemento del programa correspondiente, es decir, las entradas lhs, lhs_batching_dimensions, rhs y rhs_batching_dimensions definidas en la sección "Entradas" de dot_general.

Como se mencionó antes, la sintaxis de esta fórmula se basa en Python con algunas extensiones orientadas a la brevedad. Para entender la fórmula, vamos a transformarla en la sintaxis normal de Python.

A) En estas fórmulas, usamos = para representar la igualdad, por lo que el primer paso para obtener la sintaxis de Python es reemplazar = por ==, de la siguiente manera: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Además, estas fórmulas admiten elipses (...) que convierten las expresiones escalares en expresiones de tensor. En pocas palabras, f(xs...) significa "para cada x escalar en el tensor xs, calcula un f(x) escalar y, luego, muestra todos estos resultados escalares juntos como un resultado del tensor". En la sintaxis normal de Python, nuestra fórmula de ejemplo se convierte en: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Gracias a los puntos suspensivos, a menudo es posible evitar trabajar al nivel de escalares individuales. Sin embargo, en algunos casos difíciles, se puede usar la sintaxis semiinformal de nivel inferior, como en la fórmula start_indices[bi0, ..., :, ..., biN] de la especificación gather. Por razones de brevedad, no proporcionamos un formalismo exacto para traducir esa sintaxis al formato clásico de Python, con la esperanza de que sea intuitivamente comprensible según el caso. Avísanos si algunas fórmulas específicas se ven opacas y trataremos de mejorarlas.

Además, verás que las fórmulas usan elipses para expandir todo tipo de listas, incluidos los tensores, las listas de tensores (que, p.ej., pueden surgir de una cantidad variable de tensores), entre otras. Esta es otra área en la que no proporcionamos un formalismo exacto (p.ej., las listas ni siquiera son parte del sistema de comprensibilidad intuitiva).

C) El último vehículo notacional notable que empleamos es la transmisión implícita. Si bien el opset StableHLO no admite la transmisión implícita, las fórmulas sí lo hacen, para garantizar la concisión. En pocas palabras, si se usa un escalar en un contexto donde se espera un tensor, el escalar se transmite a la forma esperada.

Para continuar con el ejemplo de dot_general, esta es otra restricción: 0 <= lhs_batching_dimensions < rank(lhs). Como se define en la especificación dot_general, lhs_batching_dimensions es un tensor; sin embargo, tanto 0 como rank(lhs) son escalares. Después de aplicar la transmisión implícita, la fórmula se convertirá en [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Cuando se aplica a una operación dot_general específica, esta fórmula se evaluará como un tensor de valores booleanos. Cuando las fórmulas se usan como restricciones, la restricción se mantiene si la fórmula se evalúa como true o como un tensor que solo tiene elementos true.

Nombres

En las fórmulas, el alcance léxico incluye lo siguiente: 1) las funciones globales, 2) las definiciones de los miembros,

3) las definiciones locales. A continuación, se proporciona la lista de funciones globales. La lista de definiciones de elementos depende del elemento del programa al que se aplica la notación:

  • Para las operaciones, las definiciones de los miembros incluyen los nombres ingresados en las secciones “Entradas” y “Salidas”.
  • Para todo lo demás, las definiciones de los miembros incluyen partes estructurales del elemento del programa, nombradas según los no terminales de EBNF correspondientes. La mayoría de las veces, los nombres de esas partes estructurales se obtienen convirtiendo los nombres de las no terminales en snake case (p.ej., IntegerLiteral => integer_literal), pero a veces los nombres se abrevian en el proceso (p.ej., QuantizationStorageType => storage_type), en cuyo caso los nombres se ingresan de manera explícita de manera similar a las secciones "Inputs" y "Outputs" en las especificaciones de operación.
  • Además, las definiciones de los miembros siempre incluyen self para hacer referencia al elemento del programa correspondiente.

Valores

Cuando se evalúan las fórmulas, funcionan con los siguientes tipos de valores: 1) Value (valores reales, p.ej., dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; siempre conocen sus tipos), 2) Placeholder (valores futuros, p.ej., lhs, rhs o result; sus valores reales aún no se conocen, solo se conocen sus tipos), 3) Type (tipos definidos en la sección "Tipos") 4) Function (funciones globales según se definen en la sección "Funciones").

Según el contexto, los nombres pueden referirse a valores diferentes. Específicamente, la sección "Semántica" de las operaciones (y equivalentes de otros elementos del programa) define la lógica del entorno de ejecución, por lo que todas las entradas están disponibles como Value. Por el contrario, la sección "Restricciones" de las operaciones (y equivalentes) define la lógica de "tiempo de compilación", es decir, algo que se suele ejecutar antes del tiempo de ejecución, por lo que solo las entradas constantes están disponibles como Value y las demás solo están disponibles como Placeholder.

Nombres En "Semántica" En “Restricciones”
Funciones globales Function Function
Entradas constantes Value Value
Entradas no constantes Value Placeholder
Salidas Value Placeholder
Definiciones locales Depende de la definición Depende de la definición

Consideremos una operación transpose de ejemplo:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Para esta operación, permutation es una constante, por lo que está disponible como Value tanto en semántica como en restricciones. Por el contrario, operand y result están disponibles como Value en la semántica, pero solo como Placeholder en restricciones.

Funciones

Construcción de tipos

No hay funciones que se puedan usar para construir tipos. En su lugar, usamos directamente la sintaxis de tipos, ya que suele ser más concisa. P.ej., (tensor<E>, tensor<E>) -> (tensor<E>) en lugar de function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Funciones en tipos

  • element_type se define en los tipos de tensor y los tipos de tensor cuantificados y muestra, respectivamente, la parte TensorElementType o QuantizedTensorElementType del TensorType o QuantizedTensorType correspondientes.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value es un atajo para is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value es un acceso directo para is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool verifica si el tipo x se puede ascender al tipo y. Cuando x y y son elementos QuantizedTensorElementType, la promoción se aplica solo a storage_type. Actualmente, esta versión específica de la promoción se usa en el contexto del cálculo de reducción (consulta RFC para obtener más detalles).

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value es un acceso directo para is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Disponible para todos los tipos. Por ejemplo, is_float(x) muestra true si x es un FloatType. Si x es un valor o un marcador de posición, esta función es un acceso directo para is_type_name(type(x)).

  • max_value(x: Type) -> Value muestra el valor máximo de un TensorElementType. Si x no es un TensorElementType, muestra None.

  • min_value(x: Type) -> Value muestra el valor mínimo posible de un TensorElementType. Si x no es un TensorElementType, muestra None.

  • member_name(x: Value | Placeholder | Type) -> Any. Disponible para todas las definiciones de miembros member_name de todos los tipos. Por ejemplo, tensor_element_type(x) muestra la parte TensorElementType de un TensorType correspondiente. Si x es un valor o un marcador de posición, esta función es un acceso directo para member_name(type(x)). Si x no es un tipo que tenga un miembro adecuado, o un valor o un marcador de posición de ese tipo, muestra None.

Construcción de valores

  • operation_name(*xs: Value | Type) -> Value. Disponible para todas las operaciones. Por ejemplo, add(lhs, rhs) toma dos valores de tensor lhs y rhs, y muestra el resultado de evaluar la operación add con estas entradas. Para algunas operaciones, p. ej., broadcast_in_dim, los tipos de sus resultados son “de carga”, es decir, necesarios para evaluar una operación. En este caso, la función toma estos tipos como argumentos.

Función en valores

  • Todos los operadores y las funciones de Python están disponibles. P.ej., las anotaciones de suscripción y segmentación de Python están disponibles para indexarse en tensores, tensores cuantificados y de tuplas.

  • to_destination_type(x: Value, destination_type: Type) -> Value se define en los tensores y muestra el valor convertido de x según type(x) y destination_type de la siguiente manera:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Hay un debate anticipado sobre la combinación de operaciones convert, uniform_quantize y uniform_dequantize (#1576). Después de la combinación, no necesitamos la función anterior y, en su lugar, podemos usar el nombre de la operación para convert.

  • is_nan(x: Value) -> Value se define en los tensores y muestra true si todos los elementos de x son NaN o, de lo contrario, false. Si x no es un tensor, muestra None.

  • is_sorted(x: Value) -> Value se define en los tensores y muestra true si los elementos de x se ordenan de manera ascendente con respecto al orden ascendente lexicográfico de sus índices o, en caso contrario, false. Si x no es un tensor, muestra None.

  • is_unique(x: Value) -> Value se define en los tensores y muestra true si x no tiene elementos duplicados o false de lo contrario. Si x no es un tensor, muestra None.

  • member_name(x: Value) -> Any se define para todas las definiciones de miembros member_name de todos los valores. Por ejemplo, real_part(x) muestra la parte RealPart de un ComplexConstant correspondiente. Si x no es un valor que tenga un miembro apropiado, muestra None.

  • same(x: Value) -> Value se define en los tensores y muestra true si todos los elementos de x son iguales entre sí o, de lo contrario, muestra false. Si el tensor no tiene elementos, se cuenta como "todos son iguales entre sí", es decir, la función muestra true. Si x no es un tensor, muestra None.

  • split(x: Value, num_results: Value, axis: Value) -> Value se define en los tensores y muestra porciones num_results de x a lo largo del eje axis. Si x no es un tensor o dim(x, axis) % num_results != 0, muestra None.

Cálculos de formas

  • axes(x: Value | Placeholder | Type) -> Value es un acceso directo para range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value es un acceso directo para shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List es un acceso directo para list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value se define en los tensores y muestra índices size(x) para el TensorType correspondiente ordenado en orden lexicográfico ascendente, es decir, [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Si x no es un tipo de tensor, un tipo de tensor cuantificado, un valor o un marcador de posición de uno de estos tipos, muestra None.

  • rank(x: Value | Placeholder | Type) -> Value es un acceso directo para size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value se define en la sección "Funciones en tipos" mediante member_name.

  • size(x: Value | Placeholder | Type) -> Value es un acceso directo para reduce(lambda x, y: x * y, shape(x)).

Cálculos de cuantización

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type es un acceso directo para element_type(baseline_type(x)).

  • baseline_type se define en los tipos de tensores y de tipos de tensores cuantificados y los transforma en un "modelo de referencia", es decir, un tipo con la misma forma, pero con los parámetros de cuantización del tipo de elemento restablecidos a los valores predeterminados. Esto se usa como un truco útil para comparar los tipos de tensor y los cuantizados de manera uniforme, lo que es necesario con bastante frecuencia. Para los tipos cuantizados, esto permite comparar tipos que ignoran los parámetros de cuantización, es decir, shape, storage_type, expressed_type, storage_min, storage_max y quantization_dimension (para el tipo cuantizado por eje) deben coincidir, pero scales y zero points pueden diferir.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize se define en tipos de tensores cuantizados y los convierte en tipos de tensores de punto flotante. Esto sucede a través de la conversión de elementos cuantizados que representan valores enteros del tipo de almacenamiento en valores de punto flotante correspondientes del tipo expresado mediante el punto cero y la escala asociados con el tipo de elemento cuantificado.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize se define en tipos de tensores de punto flotante y los convierte en tipos de tensores cuantizados. Esto sucede cuando se convierten los valores de punto flotante del tipo expresado en valores de números enteros correspondientes del tipo de almacenamiento con el punto cero y la escala asociados con el tipo de elemento cuantificado.
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
  • dequantize_op_quantize se usa para especificar cálculos a nivel de elementos en tensores cuantificados. Simplifica, es decir, convierte los elementos cuantificados en sus tipos expresados, realiza una operación y, luego, cuantiza los resultados, es decir, convierte los resultados en sus tipos de almacenamiento. Por el momento, esta función solo funciona con la cuantización por tensor. La cuantización por eje es un trabajo en curso (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)

Cálculos de cuadrícula

  • cross_partition(replica_groups: Value) -> Value. Consulta la sección "cross_replica" anterior.

  • cross_replica(replica_groups: Value) -> Value. Consulta la sección "cross_replica" anterior.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Consulta la sección "cross_replica_and_partition" más arriba.

  • flattened_ids(replica_groups: Value) -> Value. Consulta la sección "Flated_ids" anterior.