Especificação StableHLO

StableHLO é um conjunto para operações de alto nível (HLO, na sigla em inglês) em modelos de machine learning (ML). StableHLO funciona como uma camada de portabilidade entre diferentes frameworks e compiladores de ML: frameworks de ML que produzem programas StableHLO são compatíveis com compiladores de ML que consomem programas StableHLO.

Nosso objetivo é simplificar e acelerar o desenvolvimento de ML criando mais interoperabilidade entre vários frameworks de ML (como TensorFlow, JAX e PyTorch) e compiladores de ML (como XLA e IREE). Para isso, este documento fornece uma especificação para a linguagem de programação StableHLO.

Essa especificação contém três seções principais. Primeiro, a seção Programas descreve a estrutura dos programas do StableHLO, que consistem em funções do StableHLO, que também consistem em operações do StableHLO. Nessa estrutura, a seção Ops especifica a semântica de operações individuais. A seção Execution (execução) inclui semântica para todas essas operações executadas em conjunto em um programa. Por fim, a seção Notação discute a notação usada em toda a especificação.

Programas

Program ::= {Func}

Os programas StableHLO consistem em um número arbitrário de funções do StableHLO. Veja abaixo um programa de exemplo com uma função @main que tem três entradas (%image, %weights e %bias) e uma saída. O corpo da função tem 6 ops.

func.func @main(
  %image: tensor<28x28xf32>,
  %weights: tensor<784x10xf32>,
  %bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
  %0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
  %1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
  %2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  %3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
  %4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
  "func.return"(%4): (tensor<1x10xf32>) -> ()
}

Funções

Func        ::= 'func' '.' 'func' FuncId FuncInputs FuncOutputs '{' FuncBody '}'
FuncInputs  ::= '(' [FuncInput {',' FuncInput}] `)`
FuncInput   ::= '%' ValueId ':' ValueType
FuncOutputs ::= ['->' FuncOutput, {',' FuncOutput}]
FuncOutput  ::= ValueType
FuncBody    ::= {Op}

As funções StableHLO, também chamadas de funções nomeadas, têm um identificador, entradas/saídas e um corpo. No futuro, planejamos introduzir outros metadados para funções para ter melhor compatibilidade com HLO (#425, #626, #740 #744).

Identificadores

FuncId  ::= '@' letter {letter | digit}
ValueId ::= '%' digit {digit}
          | '%' letter {letter | digit}
letter  ::= 'a' | ... | 'z' | 'A' | ... | 'Z' | '_'
digit   ::= '0' | ... | '9'

Os identificadores StableHLO são semelhantes aos de várias linguagens de programação, com duas peculiaridades: 1) todos os identificadores têm sigils que distinguem diferentes tipos de identificadores, 2) os identificadores de valor podem ser completamente numéricos para simplificar a geração de programas StableHLO.

Tipos

Type         ::= ValueType | NonValueType
ValueType    ::= TensorType | QuantizedTensorType | TokenType | TupleType
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType

Os tipos de StableHLO são categorizados em tipos de valor, que também são chamados de tipos de primeira classe, que representam valores de StableHLO, e tipos sem valor, que descrevem outros elementos do programa. Os tipos StableHLO são semelhantes aos tipos em muitas linguagens de programação.A principal peculiaridade é a natureza específica do domínio de StableHLO, que resulta em alguns resultados incomuns. Por exemplo, tipos escalares não são tipos de valor.

TensorType ::= 'tensor' '<' Shape TensorElementType '>'
Shape ::= {DimensionSize 'x'}
DimensionSize ::= digit {digit}

Os tipos de tensor representam tensores, ou seja, matrizes multidimensionais. Elas têm uma forma e um tipo de elemento, em que uma forma representa tamanhos de dimensão não negativos na ordem crescente das dimensões correspondentes, que também são chamadas de eixos, numeradas de 0 a R-1. O número de dimensões R é chamado de rank. Por exemplo, tensor<2x3xf32> é um tipo de tensor com a forma 2x3 e o tipo de elemento f32. Ela tem duas dimensões (ou, em outras palavras, dois eixos): 0a e 1a dimensão, com tamanhos 2 e 3. A classificação é 2.

Isso define o suporte para formas estáticas, em que os tamanhos das dimensões são estaticamente conhecidos. No futuro, planejamos também oferecer suporte a formas dinâmicas, em que os tamanhos das dimensões são parcial ou totalmente desconhecidos (#8). Além disso, planejamos estender os tipos de tensores além dos tamanhos de dimensão e tipos de elemento, por exemplo, para incluir layouts (#629) e esparsidade (#1078).

QuantizedTensorType ::= 'tensor' '<' Shape QuantizedTensorElementType '>'
QuantizedTensorElementType ::= '!quant.uniform' '<'
                  QuantizationStorageType
                  ['<' QuantizationStorageMin ':' QuantizationStorageMax '>']
                  ':' QuantizationExpressedType
                  [':' QuantizationDimension]
                  ',' QuantizationParameters '>'
QuantizationStorageType ::= IntegerType
QuantizationStorageMin ::= IntegerConstant
QuantizationStorageMax ::= IntegerConstant
QuantizationExpressedType ::= FloatType
QuantizationDimension ::= IntegerConstant
QuantizationParameters ::= QuantizationParameter
                         | '{' QuantizationParameter {',' QuantizationParameter} '}'
QuantizationParameter ::= QuantizationScale ':' QuantizationZeroPoint
QuantizationScale ::= FloatConstant
QuantizationZeroPoint ::= IntegerConstant
Nome Tipo Restrições
storage_type tipo de número inteiro (C1 a C4) (C9)
storage_min constante de número inteiro (C2), (C4) (C8)
storage_max constante de número inteiro (C3), (C4) (C8)
expressed_type tipo de ponto flutuante (C1) e (C5)
quantization_dimension constante opcional de número inteiro (C11 a C13)
scales número variado de constantes de ponto flutuante (C5-C7), (C10), (C11) (C13)
zero_points número variável de constantes inteiras (C8 a C10)

Os tipos de elementos quantizados representam valores inteiros de um tipo de armazenamento no intervalo de storage_min a storage_max (inclusivo) que correspondem aos valores de ponto flutuante de um tipo expresso. Para um determinado valor inteiro i, o valor de ponto flutuante correspondente f pode ser calculado como f = (i - zero_point) * scale, em que scale e zero_point são chamados de parâmetros de quantização. storage_min e storage_max são opcionais na gramática, mas têm valores padrão de min_value(storage_type) e max_value(storage_type), respectivamente. Os tipos de elementos quantizados têm as seguintes restrições:

  • (C1) num_bits(storage_type) < num_bits(expressed_type).
  • (C2) type(storage_min) = storage_type.
  • (C3) type(storage_max) = storage_type.
  • (C4) min_value(storage_type) <= storage_min < storage_max <= max_value(storage_type).
  • (C5) type(scales...) = expressed_type.
  • (C6) 0 < scales.
  • (C7) is_finite(scales...).
  • (C8) storage_min <= zero_points <= storage_max.
  • (C9) type(zero_points...) = storage_type.
  • (C10) size(scales) = size(zero_points).
  • (C11) Se is_empty(quantization_dimension), então size(scales) = 1.
  • (C12) 0 <= quantization_dimension.

No momento, QuantizationScale é uma constante de ponto flutuante, mas há um forte interesse em escalas baseadas em números inteiros, representadas por multiplicadores e mudanças. Planejamos explorar isso em um futuro próximo (#1404).

Há uma discussão contínua sobre a semântica de QuantizationZeroPoint, incluindo o tipo, os valores e se pode haver apenas um ou potencialmente vários pontos zero em um tipo de tensor quantizado. Com base nos resultados dessa discussão, a especificação sobre zero pontos pode mudar no futuro (#1405).

Outra discussão em andamento envolve a semântica de QuantizationStorageMin e QuantizationStorageMax para determinar se alguma restrição precisa ser imposta a esses valores e aos valores de tensores quantizados (#1406).

Por fim, planejamos explorar a representação de escalas desconhecidas e pontos zero, da mesma forma que planejamos representar tamanhos de dimensão desconhecidos (#1407).

Os tipos de tensores quantizados representam tensores com elementos quantizados. Esses tensores são exatamente os mesmos que os regulares, exceto pelo fato de que os elementos têm tipos de elemento quantizados, em vez de tipos de elementos regulares.

Em tensores quantizados, a quantização pode ser por tensor, o que significa ter um scale e zero_point para o tensor inteiro, ou pode ser por eixo, ou seja, ter vários scales e zero_points, um par por fração de uma determinada dimensão quantization_dimension. Mais formalmente, em um tensor t com quantização por eixo, há frações dim(t, quantization_dimension) do quantization_dimension: t[:, ..., 0, ..., :], t[:, ..., 1, ..., :] etc. Todos os elementos na ia fração usam scales[i] e zero_points[i] como seus parâmetros de quantização. Os tipos de tensores quantizados têm as seguintes restrições:

  • Para a quantização por tensor:
    • Sem restrições adicionais.
  • Para a quantização por eixo:
    • (C12) quantization_dimension < rank(self).
    • (C13) dim(self, quantization_dimension) = size(scales).
TokenType ::= 'token'

Os tipos de token representam tokens, ou seja, valores opacos produzidos e consumidos por algumas operações. Os tokens são usados para impor uma ordem de execução a operações, conforme descrito na seção Execução.

TupleType ::= 'tuple' '<' TupleElementTypes '>'
TupleElementTypes ::= [ValueType {',' ValueType}]

Os tipos de tuplos representam tuplas, ou seja, listas heterogêneas. As tuplas são um recurso legado que só existe para compatibilidade com HLO. No HLO, as tuplas são usadas para representar entradas e saídas variadas. No StableHLO, entradas e saídas variáveis têm suporte nativo, e o único uso de tuplas em StableHLO é representar de forma abrangente a ABI HLO, em que, por exemplo, T, tuple<T> e tuple<tuple<T>> podem ser significativamente diferentes, dependendo de uma implementação específica. No futuro, estamos planejando fazer alterações na HLO ABI que poderão nos permitir remover tipos de tupla do StableHLO (#598).

TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'

Tipos de elementos representam elementos de tipos de tensor. Ao contrário de muitas linguagens de programação, esses tipos não são de primeira classe em StableHLO. Isso significa que os programas StableHLO não podem representar diretamente os valores desses tipos. Por isso, é idiomático representar valores escalares do tipo T com valores de tensores 0-dimensionais do tipo tensor<T>.

  • O tipo booleano representa os valores booleanos true e false.
  • Os tipos de números inteiros podem ser assinados (si) ou não assinados (ui) e ter uma das larguras de bits compatíveis (4, 8, 16, 32 ou 64). Os tipos siN assinados representam valores inteiros de -2^(N-1) a 2^(N-1)-1 inclusivos, e os tipos uiN não assinados representam valores inteiros de 0 a 2^N-1, inclusive.
  • Os tipos de ponto flutuante podem ser um dos seguintes:
  • Os tipos complexos representam valores complexos que têm uma parte real e uma parte imaginária do mesmo tipo de elemento. Os tipos complexos com suporte são complex<f32> (as duas partes são do tipo f32) e complex<f64> (as duas partes são do tipo f64).
FunctionType ::= '(' InputTypes ')' '->' '(' OutputTypes ')'
InputTypes ::= [ValueType {',' ValueType}]
OutputTypes ::= [ValueType {',' ValueType}]

Os tipos de função representam funções nomeadas e anônimas. Elas têm tipos de entrada (a lista de tipos no lado esquerdo de ->) e tipos de saída (a lista de tipos no lado direito de ->). Em muitas linguagens de programação, os tipos de função são de primeira classe, mas não em StableHLO.

StringType ::= 'string'

O tipo de string representa sequências de bytes. Ao contrário de muitas linguagens de programação, o tipo de string não é de primeira classe em StableHLO e é usado apenas para especificar metadados estáticos de elementos do programa.

Operações

As operações de StableHLO, também chamadas de ops, representam um conjunto fechado de operações de alto nível em modelos de machine learning. Como discutido acima, a sintaxe de StableHLO é muito inspirada no MLIR, que não é necessariamente a alternativa mais ergonômica, mas é provavelmente a melhor opção para o objetivo do StableHLO de criar mais interoperabilidade entre os frameworks de ML e os compiladores de ML.

Op            ::= [OpOutputs] OpName OpInputs ':' OpSignature
OpName        ::= '"' 'stablehlo' '.' OpMnemonic '"'
OpMnemonic    ::= 'abs' | 'add' | ...

As operações StableHLO, também chamadas de ops, têm um nome, entradas/saídas e uma assinatura. O nome consiste no prefixo stablehlo. e um mnemônico que identifica exclusivamente uma das operações compatíveis. Veja abaixo uma lista abrangente de todas as operações compatíveis.

No momento, os programas StableHLO em estado selvagem às vezes contêm operações que não são descritas neste documento. No futuro, planejamos absorver essas operações na opset do StableHLO ou proibir que elas apareçam nos programas do StableHLO. Enquanto isso, aqui está a lista dessas operações:

  • builtin.module, func.func, func.call e func.return (#425).
  • Operações chlo (#602).
  • Categoria "Not in HLO" de operações StableHLO. Inicialmente, elas faziam parte da opset StableHLO, mas depois foram considerados que não se encaixavam bem: broadcast, create_token, cross-replica-sum, dot, einsum, torch_index_select e unary_einsum (#3).
  • Categoria "Dinamismo" de operações StableHLO. Elas foram inicializadas a partir da MHLO, mas ainda não as especificamos: compute_reshape_shape, cstr_reshapable, dynamic_broadcast_in_dim, dynamic_conv, dynamic_gather, dynamic_iota, dynamic_pad, dynamic_reshape, real_dynamic_slice, set_dimension_size (#8).
  • Computações de formas, incluindo as operações arith, shape e tensor (#8).
OpInputs        ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues   ::= '(' [OpInputValue {',' OpInputValue}] ')'
OpInputValue    ::= ValueId
OpInputFuncs    ::= ['(' OpInputFunc {',' OpInputFunc} ')']
OpInputAttrs    ::= ['{' OpInputAttr {',' OpInputAttr} '}']
OpOutputs       ::= [OpOutput {',' OpOutput} '=']
OpOutput        ::= ValueId

As operações consomem entradas e produzem saídas. As entradas são categorizadas em valores de entrada (calculados durante a execução), funções de entrada (fornecidos estaticamente, porque as funções de StableHLO não são valores de primeira classe) e atributos de entrada (também fornecidos estaticamente). O tipo de entradas e saídas consumidos e produzidos por uma operação depende do mnemônico dela. Por exemplo, a operação add consome dois valores de entrada e produz um valor de saída. Em comparação, a operação select_and_scatter consome três valores de entrada, duas funções de entrada e três atributos de entrada.

OpInputFunc ::= '{' Unused FuncInputs ':' FuncBody '}'
Unused      ::= '^' digit {digit}
              | '^' letter {letter | digit}

As funções de entrada, também chamadas de funções anônimas, são muito semelhantes às funções nomeadas. A diferença é que: 1) elas não têm um identificador (por isso, o nome "anônimo"), 2) não declaram tipos de saída (os tipos de saída são inferidos da op return dentro da função).

A sintaxe das funções de entrada inclui uma parte não utilizada no momento (consulte a produção de Unused acima) que existe para compatibilidade com o MLIR. Em MLIR, há um conceito mais geral de "regiões" que pode ter vários "blocos" de operações conectados entre si por operações de salto. Esses blocos têm IDs que correspondem à produção de Unused, para que possam ser diferenciados entre si. O StableHLO não tem operações de salto. Portanto, a parte correspondente da sintaxe de MLIR não é usada, mas ainda está presente.

OpInputAttr      ::= OpInputAttrName '=' OpInputAttrValue
OpInputAttrName  ::= letter {letter | digit}
OpInputAttrValue ::= Constant

Os atributos de entrada têm um nome e um valor, que é uma das constantes compatíveis. Eles são a principal maneira de especificar metadados estáticos para elementos do programa. Por exemplo, a operação concatenate usa o atributo dimension para especificar a dimensão com que os valores de entrada são concatenados. Da mesma forma, a operação slice usa vários atributos, como start_indices e limit_indices, para especificar os limites usados para dividir o valor de entrada.

No momento, os programas StableHLO à disposição às vezes contêm atributos que não são descritos neste documento. No futuro, planejamos absorver esses atributos na opset do StableHLO ou proibir que eles apareçam nos programas do StableHLO. Enquanto isso, veja a lista destes atributos:

  • layout (#629).
  • mhlo.frontend_attributes (#628).
  • mhlo.sharding (#619).
  • output_operand_aliases (#740).
  • Metadados de local (#594).
OpSignature ::= '(' [ValueType {',' ValueType}] ')' '->' '(' [ValueType {',' ValueType}] ')'

A assinatura de operações consiste nos tipos de todos os valores de entrada (a lista de tipos no lado esquerdo de ->) e nos tipos de todos os valores de saída (a lista de tipos no lado direito de ->). Estritamente, os tipos de entrada são redundantes e os de saída também são quase sempre redundantes, porque para a maioria das operações StableHLO, os tipos de saída podem ser inferidos a partir das entradas. No entanto, a assinatura de operação faz parte deliberadamente da sintaxe do StableHLO para compatibilidade com o MLIR.

Confira abaixo um exemplo de operação com mnemônico select_and_scatter. Ela consome três valores de entrada (%operand, %source e %init_value), duas funções de entrada e três atributos de entrada (window_dimensions, window_strides e padding). Observe como a assinatura da operação inclui apenas os tipos dos valores de entrada, mas não os tipos de funções e atributos de entrada fornecidos inline.

%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "stablehlo.return"(%0) : (tensor<i32>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x2xi32>

Constantes

Constant ::= BooleanConstant
           | IntegerConstant
           | FloatConstant
           | ComplexConstant
           | TensorConstant
           | QuantizedTensorConstant
           | StringConstant
           | EnumConstant

As constantes de StableHLO têm um literal e um tipo que, juntos, representam um valor de StableHLO. Geralmente, o tipo faz parte da sintaxe constante, exceto quando não é ambíguo (por exemplo, uma constante booleana tem inequivocamente o tipo i1, enquanto uma constante de número inteiro pode ter vários tipos possíveis).

BooleanConstant ::= BooleanLiteral
BooleanLiteral  ::= 'true' | 'false'

As constantes booleanas representam os valores booleanos true e false. As constantes booleanas têm o tipo i1.

IntegerConstant   ::= IntegerLiteral ':' IntegerType
IntegerLiteral    ::= ['-' | '+'] DecimalDigits
                    | ['-' | '+'] '0x' HexadecimalDigits
DecimalDigits     ::= decimalDigit {decimalDigit}
HexadecimalDigits ::= hexadecimalDigit {hexadecimalDigit}
decimalDigit      ::= '0' | ... | '9'
hexadecimalDigit  ::= decimalDigit | 'a' | ... | 'f' | 'A' | ... | 'F'

Constantes inteiras representam valores inteiros por meio de strings que usam notação decimal ou hexadecimal. Outras bases, como binária ou octal, não são aceitas. Constantes inteiras têm as seguintes restrições:

  • (C1) is_wellformed(integer_literal, integer_type).
FloatConstant  ::= FloatLiteral ':' FloatType
FloatLiteral   ::= SignPart IntegerPart FractionalPart ScientificPart
                 | '0x' [HexadecimalDigits]
SignPart       ::= ['-' | '+']
IntegerPart    ::= DecimalDigits
FractionalPart ::= ['.' [DecimalDigits]]
ScientificPart ::= [('e' | 'E') ['-' | '+'] DecimalDigits]

Constantes de ponto flutuante representam valores de ponto flutuante por meio de strings que usam notação decimal ou científica. Além disso, a notação hexadecimal pode ser usada para especificar diretamente os bits subjacentes no formato de ponto flutuante do tipo correspondente. As constantes de ponto flutuante têm as seguintes restrições:

  • (C1) Se a notação não hexadecimal for usada, is_wellformed(float_literal, float_type).
  • (C2) Se a notação hexadecimal for usada, size(hexadecimal_digits) = num_bits(float_type) / 4.
ComplexConstant ::= ComplexLiteral ':' ComplexType
ComplexLiteral  ::= '(' RealPart ',' ImaginaryPart ')'
RealPart        ::= FloatLiteral
ImaginaryPart   ::= FloatLiteral

Constantes complexas representam valores complexos usando listas de uma parte real (vem primeiro) e uma parte imaginária (vem depois). Por exemplo, (1.0, 0.0) : complex<f32> representa 1.0 + 0.0i e (0.0, 1.0) : complex<f32> representa 0.0 + 1.0i. A ordem em que essas partes são armazenadas na memória é definida pela implementação. Constantes complexas têm as seguintes restrições:

  • (C1) is_wellformed(real_part, complex_element_type(complex_type)).
  • (C2) is_wellformed(imaginary_part, complex_element_type(complex_type)).
TensorConstant ::= TensorLiteral ':' TensorType
TensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'
DenseLiteral   ::= DenseDimension | DenseElements
DenseDimension ::= '[' [DenseLiteral {',' DenseLiteral}] ']'
DenseElements  ::= [ElementLiteral {',' ElementLiteral}]
ElementLiteral ::= BooleanLiteral | IntegerLiteral | FloatLiteral | ComplexLiteral

As constantes de tensor representam os valores de tensor que usam listas aninhadas especificadas pela notação NumPy. Por exemplo, dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> representa um valor de tensor com o seguinte mapeamento de índices para elementos: {0, 0} => 1, {0, 1} => 2, {0, 2} => 3, {1, 0} => 4, {1, 1} => 5, {1, 2} => 6. A ordem em que esses elementos são armazenados na memória é definida pela implementação. As constantes do tensor têm as seguintes restrições:

  • (C1) has_syntax(tensor_literal, element_type(tensor_type)), em que:
    • has_syntax(element_literal: Syntax, element_type: Type) = is_wellformed(element_literal, type).
    • has_syntax(tensor_literal: List, element_type: Type) = has_syntax(tensor_literal..., element_type).
  • (C2) has_shape(tensor_literal, shape(tensor_type)), em que:
    • has_shape(element_literal: Syntax, []) = true.
    • has_shape(tensor_literal: List, shape: List) = size(tensor_literal) = shape[0] and has_shape(tensor_literal..., shape[1:]).
    • Caso contrário, false.
QuantizedTensorConstant ::= QuantizedTensorLiteral ':' QuantizedTensorType
QuantizedTensorLiteral  ::= 'dense' '<' (DenseLiteral | ElementLiteral) '>'

As constantes de tensores quantizadas representam os valores de tensores quantizados usando a mesma notação das constantes de tensor, com elementos especificados como constantes do tipo de armazenamento. As constantes de tensor quantizadas têm as seguintes restrições:

  • (C1) has_syntax(quantized_tensor_literal, storage_type(quantized_tensor_type)).
  • (C2) has_shape(quantized_tensor_literal, shape(quantized_tensor_type)).
StringConstant  ::= StringLiteral
StringLiteral   ::= '"' {stringCharacter | escapeSequence} '"'
stringCharacter ::= all ASCII characters except '\00', '\01', ... '\1f' and '"'
escapeSequence  ::= '\' ('"' | '\' | 'n' | 't' | (hexadecimalDigit hexadecimalDigit))

literais de string consistem em bytes especificados usando caracteres ASCII e sequências de escape. Eles não dependem de codificação, portanto, a interpretação desses bytes é definida pela implementação. Os literais de string têm o tipo string.

Operações

abs

Semântica

Executa uma operação abs com elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para números inteiros com sinal: módulo de número inteiro.
  • Para dados flutuantes: abs do IEEE-754.
  • Para números complexos: módulo complexo.
  • Para tipos quantizados: dequantize_op_quantize(abs, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de número inteiro assinado, ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1 a C2)

Saídas

Nome Tipo Restrições
result tensor de número inteiro assinado ou tipo de ponto flutuante ou tensor quantizado por tensor (C1 a C2)

Restrições

  • (C1) shape(result) = shape(operand).
  • (C2) baseline_element_type(result) é definida como:
    • complex_element_type(element_type(operand)) se is_complex(operand).
    • Caso contrário, baseline_element_type(operand).

Exemplos

// %operand: [-2, 0, 2]
%result = "stablehlo.abs"(%operand) : (tensor<3xi32>) -> tensor<3xi32>
// %result: [2, 0, 2]

Mais exemplos

adicionar

Semântica

Executa a adição com elementos de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para booleanos: OR lógico.
  • Para números inteiros: adição de números inteiros.
  • Para dados flutuantes: addition do IEEE-754.
  • Para números complexos: adição complexa.
  • Para tipos quantizados: dequantize_op_quantize(add, lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor ou tensor quantizado por tensor (C1)
(I2) rhs tensor ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemplos

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[6, 8], [10, 12]]

Mais exemplos

after_all

Semântica

Garante que as operações que produzem o inputs sejam executadas antes de qualquer operação que dependa de result. A execução dessa operação não faz nada, ela existe apenas para estabelecer dependências de dados de result a inputs.

Entradas

Rótulo Nome Tipo
(I1) inputs número variável de token

Saídas

Nome Tipo
result token

Exemplos

// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

Mais exemplos

all_gather

Semântica

Dentro de cada grupo de processos na grade de processo do StableHLO, os valores do tensor operand são concatenados de cada processo com all_gather_dim e produz um tensor result.

A operação divide a grade do processo StableHLO em process_groups, que é definida da seguinte maneira:

  • cross_replica(replica_groups) se channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) se channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) se channel_id > 0 and use_global_device_ids = true.

Em seguida, em cada process_group:

  • operands@receiver = [operand@sender for sender in process_group] para todos os receiver em process_group.
  • result@process = concatenate(operands@process, all_gather_dim) para todos os process em process_group.

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1) e (C6)
(I2) all_gather_dim constante do tipo si64 (C1) e (C6)
(I3) replica_groups constante do tensor bidimensional do tipo si64 (C2 a C4)
(I4) channel_id constante do tipo si64 (C5)
(I5) use_global_device_ids constante do tipo i1 (C5)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C6)

Restrições

  • (C1) 0 <= all_gather_dim < rank(operand).
  • (C2) is_unique(replica_groups).
  • (C3) size(replica_groups) é definida como:
    • num_replicas se cross_replica for usado.
    • num_replicas se cross_replica_and_partition for usado.
    • num_processes se flattened_ids for usado.
  • (C4) 0 <= replica_groups < size(replica_groups).
  • (C5) Se use_global_device_ids = true, então channel_id > 0.
  • (C6) type(result) = type(operand) exceto:
    • dim(result, all_gather_dim) = dim(operand, all_gather_dim) * dim(process_groups, 1).

Exemplos

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
%result = "stablehlo.all_gather"(%operand) {
  all_gather_dim = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  // channel_id = 0
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
  // use_global_device_ids = false
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
// %result@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]

Mais exemplos

all_reduce

Semântica

Dentro de cada grupo de processos na grade de processos do StableHLO, aplica-se uma função de redução computation aos valores do tensor operand de cada processo e produz um tensor result.

A operação divide a grade do processo StableHLO em process_groups, que é definida da seguinte maneira:

  • cross_replica(replica_groups) se channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) se channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) se channel_id > 0 and use_global_device_ids = true.

Em seguida, em cada process_group:

  • result@process[result_index] = exec(schedule) para alguma árvore binária schedule em que:
    • exec(node) = computation(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule é uma árvore binária definida pela implementação cuja travessia em ordem é to_destination_type(operands@process_group...[result_index], type(func_inputs(computation)[0])).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C5) e (C6)
(I2) replica_groups número variado de constantes de tensor unidimensionais do tipo si64 (C1 a C3)
(I3) channel_id constante do tipo si64 (C4)
(I4) use_global_device_ids constante do tipo i1 (C4)
(I5) computation função (C5)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C6 a C7)

Restrições

  • (C1) is_unique(replica_groups).
  • (C2) size(replica_groups) é definida como:
    • num_replicas se cross_replica for usado.
    • num_replicas se cross_replica_and_partition for usado.
    • num_processes se flattened_ids for usado.
  • (C3) 0 <= replica_groups < size(replica_groups).
  • (C4) Se use_global_device_ids = true, então channel_id > 0.
  • (C5) computation tem o tipo (tensor<E>, tensor<E>) -> (tensor<E>), em que is_promotable(element_type(operand), E).
  • (C6) shape(result) = shape(operand).
  • (C7) element_type(result) = E.

Exemplos

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [1, 2, 3, 4]
// %operand@(1, 0): [5, 6, 7, 8]
%result = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<i64>) -> tensor<i64>
// %result@(0, 0): [6, 8, 10, 12]
// %result@(1, 0): [6, 8, 10, 12]

Mais exemplos

all_to_all

Semântica

Dentro de cada grupo de processos na grade de processo StableHLO, divide os valores do tensor operand ao longo de split_dimension em partes, dispersa as partes divididas entre os processos, concatena as partes dispersas ao longo de concat_dimension e produz um tensor result.

A operação divide a grade do processo StableHLO em process_groups, que é definida da seguinte maneira:

  • cross_replica(replica_groups) se channel_id <= 0.
  • cross_partition(replica_groups) se channel_id > 0.

Em seguida, em cada process_group:

  • split_parts@sender = split(operand@sender, split_count, split_dimension) para todos os sender em process_group.
  • scattered_parts@receiver = [split_parts@sender[receiver_index] for sender in process_group], em que receiver_index = process_group.index(receiver).
  • result@process = concatenate(scattered_parts@process, concat_dimension).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1 a C3) (C9)
(I2) split_dimension constante do tipo si64 (C1), (C2) e (C9)
(I3) concat_dimension constante do tipo si64 (C3) e (C9)
(I4) split_count constante do tipo si64 (C2), (C4), (C8) (C9)
(I5) replica_groups constante do tensor bidimensional do tipo si64 (C5 a C8)
(I6) channel_id constante do tipo si64

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C9)

Restrições

  • (C1) 0 <= split_dimension < rank(operand).
  • (C2) dim(operand, split_dimension) % split_count = 0.
  • (C3) 0 <= concat_dimension < rank(operand).
  • (C4) 0 < split_count.
  • (C5) is_unique(replica_groups).
  • (C6) size(replica_groups) é definida como:
    • num_replicas se cross_replica for usado.
    • num_partitions se cross_partition for usado.
  • (C7) 0 <= replica_groups < size(replica_groups).
  • (C8) dim(replica_groups, 1) = split_count.
  • (C9) type(result) = type(operand) exceto:
    • dim(result, split_dimension) = dim(operand, split_dimension) / split_count.
    • dim(result, concat_dimension) = dim(operand, concat_dimension) * split_count.

Exemplos

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.all_to_all"(%operand) {
  split_dimension = 1 : i64,
  concat_dimension = 0 : i64,
  split_count = 2 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
// %result@(0, 0): [[1, 2],
//                  [5, 6],
//                  [9, 10],
//                  [13, 14]]
// %result@(1, 0): [[3, 4],
//                  [7, 8],
//                  [11, 12],
//                  [15, 16]]

Mais exemplos

e

Semântica

Executa o elemento AND de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para booleanos: AND lógico.
  • Para números inteiros: AND bit a bit.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs Tensor do tipo booleano ou inteiro (C1)
(I2) rhs Tensor do tipo booleano ou inteiro (C1)

Saídas

Nome Tipo Restrições
result Tensor do tipo booleano ou inteiro (C1)

Restrições

  • (C1) type(lhs) = type(rhs) = type(result).

Exemplos

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 0]]

atan2

Semântica

Executa a operação atan2 com elementos nos tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: atan2 do IEEE-754.
  • Para números complexos: atan2 complexo.
  • Para tipos quantizados: dequantize_op_quantize(atan2, lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)
(I2) rhs tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemplos

// %lhs: [0.0, 1.0, -1.0]
// %rhs: [0.0, 0.0, 0.0]
%result = "stablehlo.atan2"(%lhs, %rhs) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
// %result: [0.0, 1.57079637, -1.57079637] // [0.0, pi/2, -pi/2]

Mais exemplos

batch_norm_grad

Semântica

Calcula os gradientes de várias entradas de batch_norm_training retropropagando de grad_output e produz tensores grad_operand, grad_scale e grad_offset. Mais formalmente, essa operação pode ser expressa como uma decomposição para operações StableHLO existentes usando a sintaxe do Python da seguinte maneira:

def compute_sum(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  return sum

def compute_mean(operand, feature_index):
  sum = compute_sum(operand, feature_index)
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index):
  # Broadcast inputs to type(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance`
  # Intermediate values will be useful for computing gradients
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)

  # Use the implementation from batchnorm_expander.cc in XLA
  # Temporary variables have exactly the same names as in the C++ code
  elements_per_feature = broadcast_in_dim(
      constant(divide(size(operand), dim(operand, feature_index)),
               element_type(grad_output)),
      [], type(operand))
  i1 = multiply(grad_output, elements_per_feature)
  i2 = broadcast_in_dim(
      compute_sum(grad_output, feature_index), [feature_index], type(operand))
  i3 = broadcast_in_dim(
      compute_sum(multiply(grad_output, centered_operand), feature_index),
      [feature_index], type(operand))
  i4 = multiply(i3, centered_operand)
  i5 = divide(i4, add(variance_bcast, epsilon_bcast))
  i6 = subtract(subtract(i1, i2), i5)

  grad_operand =
      multiply(divide(divide(scale_bcast, stddev), elements_per_feature), i6)
  grad_scale =
      compute_sum(multiply(grad_output, normalized_operand), feature_index)
  grad_offset = compute_sum(grad_output, feature_index)

  return grad_operand, grad_scale, grad_offset

Para tipos quantizados, executa dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, mean, variance, grad_output: batch_norm_grad(operand, scale, mean, variance, grad_output, epsilon, feature_index), operand, scale, mean, variance, grad_output, type(grad_operand), type(grad_scale), type(feature_index)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1 a C3) e (C5)
(I2) scale Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C2), (C4) (C5)
(I3) mean Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C2) e (C4)
(I4) variance Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C2) e (C4)
(I5) grad_output tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C2) e (C3)
(I6) epsilon constante do tipo f32
(I7) feature_index constante do tipo si64 (C1) e (C5)

Saídas

Nome Tipo Restrições
grad_operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C2) e (C3)
grad_scale Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C2) e (C4)
grad_offset Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C2) e (C4)

Restrições

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, mean, variance, grad_output, grad_operand, grad_scale e grad_offset têm o mesmo baseline_element_type.
  • (C3) operand, grad_output e grad_operand têm a mesma forma.
  • (C4) scale, mean, variance, grad_scale e grad_offset têm a mesma forma.
  • (C5) size(scale) = dim(operand, feature_index).

Exemplos

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
// %grad_output: [
//                [[0.1, 0.1], [0.1, 0.1]],
//                [[0.1, 0.1], [0.1, 0.1]]
//               ]
%grad_operand, %grad_scale, %grad_offset =
"stablehlo.batch_norm_grad"(%operand, %scale, %mean, %variance, %grad_output) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>,
     tensor<2x2x2xf64>) -> (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %grad_operand: [
//                 [[0.0, 0.0], [0.0, 0.0]],
//                 [[0.0, 0.0], [0.0, 0.0]]
//                ]
// %grad_scale:  [0.0, 0.0]
// %grad_offset: [0.4, 0.4]

batch_norm_inference

Semântica

Normaliza o tensor operand em todas as dimensões, exceto a dimensão feature_index, e produz um tensor result. Mais formalmente, essa operação pode ser expressa como uma decomposição para operações StableHLO existentes usando a sintaxe do Python da seguinte maneira:

def batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index):
  # Broadcast inputs to shape(operand)
  scale_bcast = broadcast_in_dim(scale, [feature_index], type(operand))
  offset_bcast = broadcast_in_dim(offset, [feature_index], type(operand))
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  variance_bcast = broadcast_in_dim(variance, [feature_index], type(operand))
  epsilon_bcast = broadcast_in_dim(constant(epsilon, element_type(operand)), [],
                                   type(operand))

  # Perform normalization using the provided `mean` and `variance` instead of
  # computing them like `batch_norm_training` does.
  centered_operand = subtract(operand, mean_bcast)
  stddev = sqrt(add(variance_bcast, epsilon_bcast))
  normalized_operand = divide(centered_operand, stddev)
  return add(multiply(scale_bcast, normalized_operand), offset_bcast)

Para tipos quantizados, executa dequantize_op_quantize(lambda operand, scale, offset, mean, variance: batch_norm_inference(operand, scale, offset, mean, variance, epsilon, feature_index), operand, scale, offset, mean, variance, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1 a C7)
(I2) scale Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C2) e (C3)
(I3) offset Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C2) e (C4)
(I4) mean Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C5)
(I5) variance Tensor unidimensional do tipo quantizado de ponto flutuante ou por tensor (C2) (C6)
(I6) epsilon constante do tipo f32
(I7) feature_index constante do tipo si64 (C1), (C3 a C6)

Saídas

Nome Tipo Restrições
result tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C2) e (C7)

Restrições

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, mean, variance e result têm o mesmo baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(mean) = dim(operand, feature_index).
  • (C6) size(variance) = dim(operand, feature_index).
  • (C7) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
// %mean: [2.0, 3.0]
// %variance: [1.0, 1.0]
%result = "stablehlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> tensor<2x2x2xf64>
// %result: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]

batch_norm_training

Semântica

Calcula a média e a variância em todas as dimensões, exceto a dimensão feature_index, e normaliza o tensor operand que produz os tensores output, batch_mean e batch_var. Mais formalmente, essa operação pode ser expressa como uma decomposição para operações StableHLO existentes usando a sintaxe do Python da seguinte maneira:

def compute_mean(operand, feature_index):
  (sum,) = reduce(
      inputs=[operand],
      init_values=[constant(0, element_type(operand))],
      dimensions=[i for i in range(rank(operand)) if i != feature_index],
      body=lambda x, y: add(x, y))
  divisor = constant(size(operand) / dim(operand, feature_index),
                     element_type(operand))
  divisor_bcast = broadcast_in_dim(divisor, [], type(sum))
  return divide(sum, divisor_bcast)

def compute_variance(operand, feature_index):
  mean = compute_mean(operand, feature_index)
  mean_bcast = broadcast_in_dim(mean, [feature_index], type(operand))
  centered_operand = subtract(operand, mean_bcast)
  return compute_mean(mul(centered_operand, centered_operand), feature_index)

def batch_norm_training(operand, scale, offset, epsilon, feature_index):
  mean = compute_mean(operand, feature_index)
  variance = compute_variance(operand, feature_index)
  return batch_norm_inference(operand, scale, offset, mean, variance, epsilon,
                              feature_index),
         mean, variance

Para tipos quantizados, executa dequantize_batch_norm_grad_or_training_quantize(lambda operand, scale, offset: batch_norm_training(operand, scale, offset, epsilon, feature_index), operand, scale, offset, type(output), type(batch_mean), type(batch_var)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)
(I2) scale Tensor unidimensional de ponto flutuante ou por tensor quantizado (C2) e (C3)
(I3) offset Tensor unidimensional de ponto flutuante ou por tensor quantizado (C2) e (C4)
(I4) epsilon constante do tipo f32 (C1), (C3 a C6)
(I5) feature_index constante do tipo si64 (C1), (C3 a C6)

Saídas

Nome Tipo Restrições
output tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C7)
batch_mean Tensor unidimensional de ponto flutuante ou por tensor quantizado (C2) e (C5)
batch_var Tensor unidimensional de ponto flutuante ou por tensor quantizado (C2) (C6)

Restrições

  • (C1) 0 <= feature_index < rank(operand).
  • (C2) operand, scale, offset, batch_mean, batch_var e output têm o mesmo baseline_element_type.
  • (C3) size(scale) = dim(operand, feature_index).
  • (C4) size(offset) = dim(operand, feature_index).
  • (C5) size(batch_mean) = dim(operand, feature_index).
  • (C6) size(batch_var) = dim(operand, feature_index).
  • (C7) baseline_type(output) = baseline_type(operand).

Exemplos

// %operand: [
//            [[1.0, 2.0], [3.0, 4.0]],
//            [[3.0, 4.0], [1.0, 2.0]]
//           ]
// %scale: [1.0, 1.0]
// %offset: [1.0, 1.0]
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%operand, %scale, %offset) {
  epsilon = 0.0 : f32,
  feature_index = 2 : i64
} : (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>) ->
    (tensor<2x2x2xf64>, tensor<2xf64>, tensor<2xf64>)
// %output: [
//           [[0.0, 0.0], [2.0, 2.0]],
//           [[2.0, 2.0], [0.0, 0.0]]
//          ]
// %batch_mean: [2.0, 3.0]
// %batch_var: [1.0, 1.0]

bitcast_convert

Semântica

Executa uma operação de bitcast no tensor operand e produz um tensor result, em que os bits do tensor operand inteiro são reinterpretados usando o tipo de tensor result.

Mais formalmente, considerando E = element_type(operand), E' = element_type(result) e R = rank(operand):

  • Se for num_bits(E') < num_bits(E), bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
  • Se for num_bits(E') > num_bits(E), bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
  • Se for num_bits(E') = num_bits(E), bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

bits retorna a representação na memória de um determinado valor, e o comportamento é definido pela implementação, porque a representação exata dos tensores é definida pela implementação, e a representação exata dos tipos de elementos também é definida pela implementação.

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado (C1 a C2)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado (C1 a C2)

Restrições

  • (C1) Considerando E = is_quantized(operand) ? storage_type(operand) : element_type(operand), E' = is_quantized(result) ? storage_type(result) : element_type(result) e R = rank(operand):
    • Se for num_bits(E') = num_bits(E), shape(result) = shape(operand).
    • Se num_bits(E') < num_bits(E):
    • rank(result) = R + 1.
    • dim(result, i) = dim(operand, i) para todos os 0 <= i < R.
    • dim(result, R) * num_bits(E') = num_bits(E).
    • Se num_bits(E') > num_bits(E):
    • rank(result) = R - 1.
    • dim(result, i) = dim(operand, i) para todos os 0 <= i < R.
    • dim(operand, R - 1) * num_bits(E) = num_bits(E').
  • (C2) Se is_complex(operand) or is_complex(result), então is_complex(operand) and is_complex(result).

Exemplos

// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation

Mais exemplos

broadcast_in_dim

Semântica

Expande as dimensões e/ou a classificação de um tensor de entrada duplicando os dados no tensor operand e produz um tensor result. Mais formalmente, result[result_index] = operand[operand_index] em que para todos os d em axes(operand):

  • operand_index[d] = 0 se dim(operand, d) = 1.
  • Caso contrário, operand_index[d] = result_index[broadcast_dimensions[d]].

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado (C1 a C2) e (C5 a C6)
(I2) broadcast_dimensions constante do tensor unidimensional do tipo si64 (C2 a C6)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado (C1), (C3) (C5 a C6)

Restrições

  • (C1) element_type(result) é fornecido por:
    • element_type(operand), se !is_per_axis_quantized(operand).
    • element_type(operand), exceto que quantization_dimension(operand), scales(operand) e zero_points(operand) podem ser diferentes da resp. quantization_dimension(result), scales(result) e zero_points(result). Caso contrário.
  • (C2) size(broadcast_dimensions) = rank(operand).
  • (C3) 0 <= broadcast_dimensions < rank(result).
  • (C4) is_unique(broadcast_dimensions).
  • (C5) Para todos os d em axes(operand):
    • dim(operand, d) = 1 ou
    • dim(operand, d) = dim(result, broadcast_dimensions[d]).
  • (C6) Se is_per_axis_quantized(result):
    • quantization_dimension(result) = broadcast_dimensions[quantization_dimension(operand)].
    • Se for dim(operand, quantization_dimension(operand)) = 1, então scales(result)[i] = scales(operand)[0] and zero_points(result)[i] = zero_points(operand)[0] for i in range(dim(result, quantization_dimension(result))).

Exemplos

// %operand: [
//            [1, 2, 3]
//           ]
%result = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 2, 1>
} : (tensor<1x3xi32>) -> tensor<2x3x2xi32>
// %result: [
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ],
//            [
//             [1, 1],
//             [2, 2],
//             [3, 3]
//            ]
//          ]

Mais exemplos

caso

Semântica

Produz a saída da execução de exatamente uma função de branches, dependendo do valor de index. Mais formalmente, result = selected_branch(), em que:

  • selected_branch = branches[index] se 0 <= index < size(branches).
  • Caso contrário, selected_branch = branches[-1].

Entradas

Rótulo Nome Tipo Restrições
(I1) index Tensor dimensional 0 do tipo si32
(I2) branches número variável de funções (C1 a C4)

Saídas

Nome Tipo Restrições
results número variad de tensores, tensores quantizados ou tokens (C4)

Restrições

  • (C1) 0 < size(branches).
  • (C2) input_types(branches...) = [].
  • (C3) same(output_types(branches...)).
  • (C4) type(results...) = output_types(branches[0]).

Exemplos

// %index: -1
// %result_branch0: [0, 0]
// %result_branch1: [1, 1]
%result0, %result1 = "stablehlo.case"(%index) ({
  "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
  "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
// %result0: [1, 1]
// %result1: [1, 1]

Mais exemplos

CBRT

Semântica

Executa uma operação de raiz cúbica com elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: rootn(x, 3) do IEEE-754.
  • Para números complexos: raiz cúbica complexa.
  • Para tipos quantizados: dequantize_op_quantize(cbrt, operand, type(result))

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [0.0, 1.0, 8.0, 27.0]
%result = "stablehlo.cbrt"(%operand) : (tensor<4xf64>) -> tensor<4xf64>
// %result: [0.0, 1.0, 2.0, 3.0]

Mais exemplos

ceil

Semântica

Executa o ceil com elementos do tensor operand e produz um tensor result. Implementa a operação roundToIntegralTowardPositive da especificação IEEE-754. Para tipos quantizados, executa dequantize_op_quantize(ceil, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.ceil"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-0.0, -0.0, 1.0, 1.0, 2.0]

Mais exemplos

Cholesky

Semântica

Calcula a decomposição de Cholesky de um lote de matrizes.

Mais formalmente, para todos os i em index_space(result), result[i0, ..., iR-3, :, :] é uma decomposição de Cholesky de a[i0, ..., iR-3, :, :], na forma de uma matriz triangular inferior (se lower for true) ou triangular superior (se lower for false). Os valores de saída no triângulo oposto, ou seja, o triângulo superior restrito ou o triângulo inferior restrito correspondentes, são definidos pela implementação.

Se existir i em que a matriz de entrada não for uma matriz hermitiana definida positiva, o comportamento será indefinido.

Para tipos quantizados, executa dequantize_op_quantize(lambda operand: cholesky(operand, lower), a, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) a tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1 a C3)
(I2) lower constante do tensor 0-dimensional do tipo i1

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(a) = baseline_type(result).
  • (C2) 2 <= rank(a).
  • (C3) dim(a, -2) = dim(a, -1).

Exemplos

// %a: [
//      [1.0, 2.0, 3.0],
//      [2.0, 20.0, 26.0],
//      [3.0, 26.0, 70.0]
//     ]
%result = "stablehlo.cholesky"(%a) {
  lower = true
} : (tensor<3x3xf32>) -> tensor<3x3xf64>
// %result: [
//           [1.0, 0.0, 0.0],
//           [2.0, 4.0, 0.0],
//           [3.0, 5.0, 6.0]
//          ]

limitar

Semântica

Fixa todos os elementos do tensor operand entre um valor mínimo e um máximo e produz um tensor result. Mais formalmente, result[result_index] = minimum(maximum(operand[result_index], min_element), max_element), em que min_element = rank(min) = 0 ? min[] : min[result_index], max_element = rank(max) = 0 ? max[] : max[result_index]. Para tipos quantizados, executa dequantize_op_quantize(clamp, min, operand, max, type(result)).

Impor uma ordenação em números complexos envolve semânticas surpreendentes. Portanto, no futuro, planejamos remover o suporte a números complexos para essa operação (#560).

Entradas

Rótulo Nome Tipo Restrições
(I1) min tensor ou tensor quantizado por tensor (C1) e (C3)
(I2) operand tensor ou tensor quantizado por tensor (C1 a C4)
(I3) max tensor ou tensor quantizado por tensor (C2) e (C3)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C4)

Restrições

  • (C1) rank(min) = 0 or shape(min) = shape(operand).
  • (C2) rank(max) = 0 or shape(max) = shape(operand).
  • (C3) baseline_element_type(min) = baseline_element_type(operand) = baseline_element_type(max).
  • (C4) baseline_type(operand) = baseline_type(result).

Exemplos

// %min: [5, 10, 15]
// %operand: [3, 13, 23]
// %max: [10, 15, 20]
%result = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// %result: [5, 13, 20]

Mais exemplos

collective_broadcast

Semântica

Dentro de cada grupo de processos na grade de processos do StableHLO, envie o valor do tensor operand do processo de origem para os processos de destino e produza um tensor result.

A operação divide a grade do processo StableHLO em process_groups, que é definida da seguinte maneira:

  • cross_replica(replica_groups) se channel_id <= 0.
  • cross_partition(replica_groups) se channel_id > 0.

Depois disso, result@process é fornecido por:

  • operand@process_groups[i, 0] se existir um i de modo que o processo esteja em process_groups[i].
  • Caso contrário, broadcast_in_dim(constant(0, element_type(result)), [], type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand Tensor (C3)
(I2) replica_groups número variado de constantes de tensor unidimensionais do tipo si64 (C1) e (C2)
(I3) channel_id constante do tipo si64

Saídas

Nome Tipo Restrições
result Tensor (C3)

Restrições

  • (C1) is_unique(replica_groups).
  • (C2) 0 <= replica_groups < N, em que N é definido como:
    • num_replicas se cross_replica for usado.
    • num_partitions se cross_partition for usado.
  • (C3) type(result) = type(operand).

Exemplos

// num_replicas: 4
// num_partitions: 1
// %operand@(0, 0): [[1, 2]]
// %operand@(1, 0): [[3, 4]]
// %operand@(2, 0): [[5, 6]]
// %operand@(3, 0): [[7, 8]]
%result = "stablehlo.collective_broadcast"(%operand) {
  replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor1x2xi64>) -> tensor<1x2xi64>
// %result@(0, 0): [[0, 0]]
// %result@(1, 0): [[5, 6]]
// %result@(2, 0): [[5, 6]]
// %result@(3, 0): [[0, 0]]

collective_permute

Semântica

Dentro de cada grupo de processos na grade de processo StableHLO, envia o valor do tensor operand do processo de origem para o processo de destino e produz um tensor result.

A operação divide a grade do processo StableHLO em process_groups, que é definida da seguinte maneira:

  • cross_replica(source_target_pairs) se channel_id <= 0.
  • cross_partition(source_target_pairs) se channel_id > 0.

Depois disso, result@process é fornecido por:

  • operand@process_groups[i, 0], se existir um i como process_groups[i, 1] = process.
  • Caso contrário, broadcast_in_dim(constant(is_quantized(result) ? quantize(0, element_type(result)) : 0, element_type(result)), [], type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C5)
(I2) source_target_pairs constante do tensor bidimensional do tipo si64 (C1 a C4)
(I3) channel_id constante do tipo si64

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1)

Restrições

  • (C1) dim(source_target_pairs, 1) = 2.
  • (C2) is_unique(source_target_pairs[:, 0]).
  • (C3) is_unique(source_target_pairs[:, 1]).
  • (C4) 0 <= source_target_pairs < N, em que N é definido como:
    • num_replicas se cross_replica for usado.
    • num_partitions se cross_partition for usado.
  • (C5) type(result) = type(operand).

Exemplos

// num_replicas: 3
// num_partitions: 1
// %operand@(0, 0): [[1, 2], [3, 4]]
// %operand@(1, 0): [[5, 6], [7, 8]]
// %operand@(2, 0): [[9, 10], [11, 12]]
%result = "stablehlo.collective_permute"(%operand) {
  source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[0, 0], [0, 0]]
// %result@(1, 0): [[1, 2], [3, 4]]
// %result@(2, 0): [[5, 6], [7, 8]]

Mais exemplos

compare

Semântica

Executa a comparação com elementos dos tensores lhs e rhs de acordo com comparison_direction e compare_type, e produz um tensor result.

Os valores de comparison_direction e compare_type têm a seguinte semântica:

Para tipos de elementos booleanos e inteiros:

  • EQ: lhs = rhs.
  • NE: lhs != rhs.
  • GE: lhs >= rhs.
  • GT: lhs > rhs.
  • LE: lhs <= rhs.
  • LT: lhs < rhs.

Para tipos de elementos de ponto flutuante com compare_type = FLOAT, a operação implementa as seguintes operações IEEE-754:

  • EQ: compareQuietEqual.
  • NE: compareQuietNotEqual.
  • GE: compareQuietGreaterEqual.
  • GT: compareQuietGreater.
  • LE: compareQuietLessEqual.
  • LT: compareQuietLess.

Para tipos de elementos de ponto flutuante com compare_type = TOTALORDER, a operação usa a combinação de operações totalOrder e compareQuietEqual do IEEE-754. Esse recurso parece não ser usado, por isso planejamos removê-lo no futuro (#584).

Para tipos de elementos complexos, a comparação lexicográfica de pares (real, imag) é realizada usando os comparison_direction e compare_type fornecidos. Impor uma ordenação em números complexos envolve semânticas surpreendentes. Portanto, no futuro, vamos remover o suporte para números complexos quando comparison_direction for GE, GT, LE ou LT (#560).

Para tipos quantizados, executa dequantize_compare(lhs, rhs, comparison_direction).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor ou tensor quantizado por tensor (C1 a C3)
(I2) rhs tensor ou tensor quantizado por tensor (C1 a C2)
(I3) comparison_direction enumeração de EQ, NE, GE, GT, LE e LT
(I4) compare_type enumeração de FLOAT, TOTALORDER, SIGNED e UNSIGNED (C3)

Saídas

Nome Tipo Restrições
result tensor de tipo booleano (C2)

Restrições

  • (C1) baseline_element_type(lhs) = baseline_element_type(rhs).
  • (C2) shape(lhs) = shape(rhs) = shape(result).
  • (C3) compare_type é definida como:
    • SIGNED se is_signed_integer(element_type(lhs)).
    • UNSIGNED se is_unsigned_integer(element_type(lhs)) or is_boolean(element_type(lhs)).
    • FLOAT ou TOTALORDER se is_float(element_type(lhs)).
    • FLOAT se is_complex(element_type(lhs)).

Exemplos

// %lhs: [1.0, 3.0]
// %rhs: [1.1, 2.9]
%result = "stablehlo.compare"(%lhs, %rhs) {
  comparison_direction = #stablehlo<comparison_direction LT>,
  compare_type = #stablehlo<comparison_type FLOAT>
} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
// %result: [true, false]

Mais exemplos

complexo

Semântica

Executa a conversão com elementos para um valor complexo a partir de um par de valores reais e imaginários, lhs e rhs, e produz um tensor result.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor de tipo f32 ou f64 (C1 a C3)
(I2) rhs tensor de tipo f32 ou f64 (C1)

Saídas

Nome Tipo Restrições
result tensor de tipo complexo (C2) e (C3)

Restrições

  • (C1) type(lhs) = type(rhs).
  • (C2) shape(result) = shape(lhs).
  • (C3) element_type(result) tem o tipo complex<E>, em que E = element_type(lhs).

Exemplos

// %lhs: [1.0, 3.0]
// %rhs: [2.0, 4.0]
%result = "stablehlo.complex"(%lhs, %rhs) : (tensor<2xf64>, tensor<2xf64>) -> tensor<2xcomplex<f64>>
// %result: [(1.0, 2.0), (3.0, 4.0)]

Mais exemplos

concatenate

Semântica

Concatena inputs com a dimensão dimension na mesma ordem que os argumentos fornecidos e produz um tensor result. Mais formalmente, result[i0, ..., id, ..., iR-1] = inputs[k][i0, ..., kd, ..., iR-1], em que:

  1. id = d0 + ... + dk-1 + kd.
  2. d é igual a dimension, e d0, ... são tamanhos de da dimensão de inputs.

Entradas

Rótulo Nome Tipo Restrições
(I1) inputs número variável de tensores ou tensores quantizados por tensor (C1 a C6)
(I2) dimension constante do tipo si64 (C2), (C4) (C6)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C5 a C6)

Restrições

  • (C1) same(element_type(inputs...)).
  • (C2) same(shape(inputs...)), exceto dim(inputs..., dimension).
  • (C3) 0 < size(inputs).
  • (C4) 0 <= dimension < rank(inputs[0]).
  • (C5) element_type(result) = element_type(inputs[0]).
  • (C6) shape(result) = shape(inputs[0]) exceto:
    • dim(result, dimension) = dim(inputs[0], dimension) + ....

Exemplos

// %input0: [[1, 2], [3, 4], [5, 6]]
// %input1: [[7, 8]]
%result = "stablehlo.concatenate"(%input0, %input1) {
  dimension = 0 : i64
} : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64>
// %result: [[1, 2], [3, 4], [5, 6], [7, 8]]

Mais exemplos

constante

Semântica

Produz um tensor output com base em um value constante.

Entradas

Rótulo Nome Tipo Restrições
(I1) value constante (C1)

Saídas

Nome Tipo Restrições
output tensor ou tensor quantizado (C1)

Restrições

  • (C1) type(value) = type(output).

Exemplos

%output = "stablehlo.constant"() {
  value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
} : () -> tensor<2x2xf32>
// %output: [[0.0, 1.0], [2.0, 3.0]]

Mais exemplos

fazer uma conversão

Semântica

Executa uma conversão com elementos de um tipo de elemento para outro no tensor operand e produz um tensor result.

Para conversões de boolean-to-any-supported-type, o valor false é convertido em zero, e o valor true é convertido em um. Para any-supported-type-to-boolean, um valor zero é convertido em false, e os valores diferentes de zero são convertidos em true. Confira abaixo como isso funciona para tipos complexos.

Para conversões que envolvem número inteiro a inteiro, número inteiro para ponto flutuante ou ponto flutuante para ponto flutuante, se o valor de origem puder ser representado exatamente no tipo de destino, o valor do resultado vai ser essa representação exata. Caso contrário, o comportamento será a definir (#180).

Para conversões que envolvem floating-point-to-integer, a parte fracionária fica truncada. Se o valor truncado não puder ser representado no tipo de destino, o comportamento será definido (#180).

A conversão que envolve complexo para complexo segue o mesmo comportamento das conversões de ponto flutuante para ponto flutuante na conversão de partes reais e imaginárias.

Para conversões de complex-to-any-other-type e complex-to-any-other-type, o valor imaginário de origem é ignorado, ou o valor imaginário de destino é zero, respectivamente. A conversão da parte real segue as conversões de ponto flutuante.

Em princípio, essa operação poderia expressar desquantização (conversão de tensores quantizados em tensores regulares), quantização (conversão de tensores regulares em tensores quantizados) e requantização (conversão entre tensores quantizados), mas no momento temos operações dedicadas para isso - uniform_dequantize para o primeiro e uniform_quantize para o segundo e o terceiro casos de uso. No futuro, essas duas operações podem ser mescladas em convert (#1576).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand Tensor (C1)

Saídas

Nome Tipo Restrições
result Tensor (C1)

Restrições

  • (C1) shape(operand) = shape(result).

Exemplos

// %operand: [-1, 0, 1]
%result = "stablehlo.convert"(%operand) : (tensor<3xi64>) -> tensor<3xcomplex<f64>>
// %result: [(-1.0, 0.0), (0.0, 0.0), (1.0, 0.0)]

Mais exemplos

convolução

Semântica

Computa produtos escalares entre janelas de lhs e frações de rhs e produz result. O diagrama a seguir mostra como os elementos em result são calculados a partir de lhs e rhs usando um exemplo concreto.

Mais formalmente, considere reformular as entradas abaixo em termos de lhs para poder expressar janelas de lhs:

  • lhs_window_dimensions = lhs_shape(dim(lhs, input_batch_dimension), dim(rhs, kernel_spatial_dimensions), dim(lhs, input_feature_dimension)).
  • lhs_window_strides = lhs_shape(1, window_strides, 1).
  • lhs_padding = lhs_shape([0, 0], padding, [0, 0]).
  • lhs_base_dilations = lhs_shape(1, lhs_dilation, 1).
  • lhs_window_dilations = lhs_shape(1, rhs_dilation, 1).

Esse reenquadramento usa as seguintes funções auxiliares:

  • lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]).
  • result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]).
  • permute([j0, j1, ..., jR-1], permutation) = [i0, i1, ..., iR-1], em que j[d] = i[permutation[d]].

Se for feature_group_count = 1 e batch_group_count = 1, para todos os output_spatial_index em index_space(dim(result, output_spatial_dimensions...)), result[result_shape(:, output_spatial_index, :)] = dot_product em que:

  • padding_value = constant(0, element_type(lhs)).
  • padded_lhs = pad(lhs, padding_value, lhs_padding[:, 0], lhs_padding[:, 1], lhs_base_dilations - 1).
  • lhs_window_start = lhs_shape(0, output_spatial_index, 0) * lhs_window_strides.
  • lhs_window = slice(padded_lhs, lhs_window_start, lhs_window_start + lhs_window_dimensions, lhs_window_dilations).
  • reversed_lhs_window = reverse(lhs_window, [input_spatial_dimensions[dim] for dim in range(size(window_reversal)) if window_reversal[dim] = true]). Esse recurso parece não ser usado, então planejamos removê-lo no futuro (#1181).
  • dot_product = dot_general(reversed_lhs_window, rhs, lhs_batching_dimensions=[], lhs_contracting_dimensions=input_spatial_dimensions + [input_feature_dimension], rhs_batching_dimensions=[], rhs_contracting_dimensions=kernel_spatial_dimensions + [kernel_input_feature_dimension]).

Se feature_group_count > 1:

  • lhses = split(lhs, feature_group_count, input_feature_dimension).
  • rhses = split(rhs, feature_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., feature_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension).

Se batch_group_count > 1:

  • lhses = split(lhs, batch_group_count, input_batch_dimension).
  • rhses = split(rhs, batch_group_count, kernel_output_feature_dimension).
  • results... = convolution(lhses..., rhses..., ..., batch_group_count=1, ...).
  • result = concatenate(results, output_feature_dimension)

Para tipos quantizados, executa dequantize_op_quantize( lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, window_reversal, input_batch_dimension, input_feature_dimension, input_spatial_dimensions, kernel_input_feature_dimension, kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimensions, feature_group_count, batch_group_count, precision_config), lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor ou tensor quantizado por tensor (C1), (C10-C11), (C14) (C25) (C27-C30)
(I2) rhs tensor ou tensor quantizado (C1), (C14-C16), (C25) (C27-C32)
(I3) window_strides constante do tensor unidimensional do tipo si64 (C2 a C3) (C25)
(I4) padding constante do tensor bidimensional do tipo si64 (C4) e (C25)
(I5) lhs_dilation constante do tensor unidimensional do tipo si64 (C5 a C6) (C25)
(I6) rhs_dilation constante do tensor unidimensional do tipo si64 (C7 a C8) (C25)
(I7) window_reversal constante do tensor unidimensional do tipo i1 (C9)
(I8) input_batch_dimension constante do tipo si64 (C10), (C13) (C25)
(I9) input_feature_dimension constante do tipo si64 (C11) e (C13 a C14).
(I10) input_spatial_dimensions constante do tensor unidimensional do tipo si64 (C12), (C13) (C25)
(I11) kernel_input_feature_dimension constante do tipo si64 (C14) e (C18)
(I12). kernel_output_feature_dimension constante do tipo si64 (C15 a C16), (C18), (C25) (C32)
(I13) kernel_spatial_dimensions constante do tensor unidimensional do tipo si64 (C17 a C18) (C25)
(I14) output_batch_dimension constante do tipo si64 (C20) e (C25)
(I15) output_feature_dimension constante do tipo si64 (C20), (C25) (C33)
(I16) output_spatial_dimensions constante do tensor unidimensional do tipo si64 (C19 a C20) (C25)
(I17) feature_group_count constante do tipo si64 (C11), (C14), (C16), (C21) (C23)
(I18) batch_group_count constante do tipo si64 (C10), (C15), (C22), (C23) (C25)
(I19). precision_config número variável de enumerações de DEFAULT, HIGH e HIGHEST. (C24)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado (C25 a C28), (C30 a C31) (C33)

Restrições

  • (C1) N = rank(lhs) = rank(rhs).
  • (C2) size(window_strides) = N - 2.
  • (C3) 0 < window_strides.
  • (C4) shape(padding) = [N - 2, 2].
  • (C5) size(lhs_dilation) = N - 2.
  • (C6) 0 < lhs_dilation.
  • (C7) size(rhs_dilation) = N - 2.
  • (C8) 0 < rhs_dilation.
  • (C9) size(window_reversal) = N - 2.
  • (C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
  • (C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.
  • (C12) size(input_spatial_dimensions) = N - 2.
  • (C13) Considerando input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]:
    • is_unique(input_dimensions).
    • 0 <= input_dimensions < N.
  • (C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
  • (C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
  • (C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.
  • (C17) size(kernel_spatial_dimensions) = N - 2.
  • (C18) Considerando kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]:
    • is_unique(kernel_dimensions).
    • 0 <= kernel_dimensions < N.
  • (C19) size(output_spatial_dimensions) = N - 2.
  • (C20) Considerando output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]:
    • is_unique(output_dimensions).
    • 0 <= output_dimensions < N.
  • (C21) 0 < feature_group_count.
  • (C22) 0 < batch_group_count.
  • (C23) feature_group_count = 1 or batch_group_count = 1.
  • (C24) size(precision_config) = 2.
  • (C25) dim(result, result_dim) é definido como:
    • dim(lhs, input_batch_dimension) / batch_group_count se result_dim = output_batch_dimension.
    • dim(rhs, kernel_output_feature_dimension) se result_dim = output_feature_dimension.
    • num_windows, em que:
    • output_spatial_dimensions[spatial_dim] = result_dim.
    • lhs_dim = input_spatial_dimensions[spatial_dim].
    • rhs_dim = kernel_spatial_dimensions[spatial_dim].
    • dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
    • padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
    • dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
    • is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
    • num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
  • (C26) rank(result) = N.
  • Se a operação usar tensores não quantizados:
    • (C27) element_type(lhs) = element_type(rhs) = element_type(result).
  • Se a operação usar tensores quantizados:
    • (C28) is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and is_quantized_tensor(result).
    • (C29) storage_type(lhs) = storage_type(rhs).
    • (C30) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C31) Se is_per_tensor_quantized(rhs), então is_per_tensor_quantized(result).
    • (C32) Se is_per_axis_quantized(rhs), então quantization_dimension(rhs) = kernel_output_feature_dimension.
    • (C33) Se is_per_axis_quantized(result), então quantization_dimension(result) = output_feature_dimension.

Exemplos

// %lhs: [[
//        [
//          [1], [2], [5], [6]
//        ],
//        [
//          [3], [4], [7], [8]
//        ],
//        [
//          [10], [11], [14], [15]
//        ],
//        [
//          [12], [13], [16], [17]
//        ]
//      ]]
//
// %rhs : [
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]],
//         [[[1]], [[1]], [[1]]]
//        ]
%result = "stablehlo.convolution"(%lhs, %rhs) {
  window_strides = dense<4> : tensor<2xi64>,
  padding = dense<0> : tensor<2x2xi64>,
  lhs_dilation = dense<2> : tensor<2xi64>,
  rhs_dilation = dense<1> : tensor<2xi64>,
  window_reversal = dense<false> : tensor<2xi1>,
  // In the StableHLO dialect, dimension numbers are encoded via:
  // `[<input dimensions>]x[<kernel dimensions>]->[output dimensions]`.
  // "b" is batch dimension, "f" is feature dimension,
  // "i" is input feature dimension, "o" is output feature dimension,
  // "0/1/etc" are spatial dimensions.
  dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
  feature_group_count = 1 : i64,
  batch_group_count = 1 : i64,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<1x4x4x1xi32>, tensor<3x3x1x1xi32>) -> tensor<1x2x2x1xi32>
// %result: [[
//            [[10], [26]],
//            [[46], [62]]
//          ]]

cosseno

Semântica

Executa a operação de cosseno com elementos no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: cos do IEEE-754.
  • Para números complexos: cosseno complexo.
  • Para tipos quantizados: dequantize_op_quantize(cosine, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.cosine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.0], [-1.0, 0.0]]

Mais exemplos

count_leading_zeros

Semântica

Executa a contagem elemento-chave do número de zero bits à esquerda no tensor operand e produz um tensor result.

Entradas

Rótulo Nome Tipo Restrições
(I1) operand Tensor do tipo de número inteiro (C1)

Saídas

Nome Tipo Restrições
result Tensor do tipo de número inteiro (C1)

Restrições

  • (C1) type(operand) = type(result).

Exemplos

// %operand: [[0, 1], [128, -1]]
%result = "stablehlo.count_leading_zeros"(%operand) : (tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[64, 63], [56, 0]]

Mais exemplos

custom_call

Semântica

Encapsula uma operação call_target_name definida pela implementação que usa inputs e called_computations e produz results. has_side_effect, backend_config e api_version podem ser usados para fornecer outros metadados definidos pela implementação.

No momento, essa operação contém uma coleção bastante desorganizada de metadados que reflete a evolução orgânica da operação correspondente no compilador XLA. No futuro, planejamos unificar esses metadados (#741).

Entradas

Rótulo Nome Tipo
(I1) inputs número variável de valores
(I2) call_target_name constante do tipo string
(I3) has_side_effect constante do tipo i1
(I4) backend_config constante do tipo string
(I5) api_version constante do tipo si32
(I6) called_computations número variável de constantes do tipo string

Saídas

Nome Tipo
results número variável de valores

Exemplos

%results = "stablehlo.custom_call"(%input0) {
  call_target_name = "foo",
  has_side_effect = false,
  backend_config = "bar",
  api_version = 1 : i32,
  called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>

dividir

Semântica

Executa a divisão elemento-chave dos tensores de dividendo lhs e divisor rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para números inteiros: divisão de números inteiros que produz o quociente algébico com qualquer parte fracionária descartada.
  • Para dados flutuantes: division do IEEE-754.
  • Para números complexos: divisão complexa.
  • Para tipos quantizados:
    • dequantize_op_quantize(divide, lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor de número inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)
(I2) rhs tensor de número inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemplos

// %lhs: [17.1, -17.1, 17.1, -17.1]
// %rhs: [3.0, 3.0, -3.0, -3.0]
%result = "stablehlo.divide"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// %result: [5.66666651, -5.66666651, -5.66666651, 5.66666651]

Mais exemplos

dot_general

Semântica

Computa produtos escalares entre frações de lhs e rhs e produz um tensor result.

Mais formalmente, result[result_index] = dot_product, em que:

  • lhs_result_dimensions = [d for d in axes(lhs) and d not in lhs_batching_dimensions and d not in lhs_contracting_dimensions].
  • rhs_result_dimensions = [d for d in axes(rhs) and d not in rhs_batching_dimensions and d not in rhs_contracting_dimensions].
  • result_batching_index + result_lhs_index + result_rhs_index = result_index, em que size(result_batching_index) = size(lhs_batching_dimensions), size(result_lhs_index) = size(lhs_result_dimensions) e size(result_rhs_index) = size(rhs_result_dimensions).
  • transposed_lhs = transpose(lhs, lhs_batching_dimensions + lhs_result_dimensions + lhs_contracting_dimensions).
  • transposed_lhs_slice = slice(transposed_lhs, result_batching_index + result_lhs_index + [:, ..., :]).
  • reshaped_lhs_slice = reshape(transposed_lhs_slice, dims(lhs, lhs_contracting_dimensions)).
  • transposed_rhs = transpose(rhs, rhs_batching_dimensions + rhs_result_dimensions + rhs_contracting_dimensions).
  • transposed_rhs_slice = slice(transposed_rhs, result_batching_index + result_rhs_index + [:, ..., :]).
  • reshaped_rhs_slice = reshape(transposed_rhs_slice, dims(rhs, rhs_contracting_dimensions)).
  • dot_product = reduce( inputs=[multiply(reshaped_lhs_slice, reshaped_rhs_slice)], init_values=[constant(0, element_type(result))], dimensions=range(size(lhs_contracting_dimensions)), body=lambda x, y: add(x, y))

Para tipos quantizados, executa dequantize_op_quantize( lambda lhs, rhs: dot_general(lhs, rhs, lhs_batching_dimensions, rhs_batching_dimensions, lhs_contracting_dimensions, rhs_contracting_dimensions, precision_config), lhs, rhs, type(result)).

Especifica apenas a semântica da quantização por tensor. A quantização por eixo está em desenvolvimento (#1574). Além disso, no futuro, poderemos considerar adicionar suporte à quantização híbrida (#1575).

precision_config controla a compensação entre velocidade e precisão para cálculos em back-ends de aceleradores. Pode ser um dos seguintes. No momento, a semântica desses valores de tipo enumerado está subespecificada, mas estamos planejando resolver isso em #755 (link em inglês):

  • DEFAULT: cálculo mais rápido, mas estimativa menos precisa do número original.
  • HIGH: cálculo mais lento, mas com aproximação mais precisa do número original.
  • HIGHEST: cálculo mais lento, mas aproximação mais precisa do número original.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor ou tensor quantizado por tensor (C5 a C6), (C9 a C10) (C12 a C16)
(I2) rhs tensor ou tensor quantizado por tensor (C7-C10) (C12)
(I3) lhs_batching_dimensions constante do tensor unidimensional do tipo si64 (C1), (C3), (C5), (C9) (C12)
(I4) rhs_batching_dimensions constante do tensor unidimensional do tipo si64 (C1), (C4), (C7) (C9)
(I5) lhs_contracting_dimensions constante do tensor unidimensional do tipo si64 (C2), (C3), (C6) (C10)
(I6) rhs_contracting_dimensions constante do tensor unidimensional do tipo si64 (C2), (C4), (C8) (C10)
(I7) precision_config número variável de enumerações de DEFAULT, HIGH e HIGHEST. (C11)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C12), (C14) (C16)

Restrições

  • (C1) size(lhs_batching_dimensions) = size(rhs_batching_dimensions).
  • (C2) size(lhs_contracting_dimensions) = size(rhs_contracting_dimensions).
  • (C3) is_unique(lhs_batching_dimensions + lhs_contracting_dimensions).
  • (C4) is_unique(rhs_batching_dimensions + rhs_contracting_dimensions).
  • (C5) 0 <= lhs_batching_dimensions < rank(lhs).
  • (C6) 0 <= lhs_contracting_dimensions < rank(lhs).
  • (C7) 0 <= rhs_batching_dimensions < rank(rhs).
  • (C8) 0 <= rhs_contracting_dimensions < rank(rhs).
  • (C9) dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).
  • (C10) dim(lhs, lhs_contracting_dimensions...) = dim(rhs, rhs_contracting_dimensions...).
  • (C11) size(precision_config) = 2.
  • (C12) shape(result) = dim(lhs, lhs_batching_dimensions) + dim(lhs, lhs_result_dimensions) + dim(rhs, rhs_result_dimensions).
  • Se a operação usar tensores não quantizados:
    • (C13) element_type(lhs) = element_type(rhs).
  • Se a operação usar tensores quantizados:
    • (C14) is_quantized(lhs) and is_quantized(rhs) and is_quantized(result).
    • (C15) storage_type(lhs) = storage_type(rhs).
    • (C16) expressed_type(lhs) = expressed_type(rhs) = expressed_type(result).
    • (C17) zero_points(rhs) = 0.

Exemplos

// %lhs: [
//        [[1, 2],
//         [3, 4]],
//        [[5, 6],
//         [7, 8]]
//       ]
// %rhs: [
//        [[1, 0],
//         [0, 1]],
//        [[1, 0],
//         [0, 1]]
//       ]
%result = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_batching_dimensions = [0],
    rhs_batching_dimensions = [0],
    lhs_contracting_dimensions = [2],
    rhs_contracting_dimensions = [1]
  >,
  precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
} : (tensor<2x2x2xi64>, tensor<2x2x2xi64>) -> tensor<2x2x2xi64>
// %result: [
//           [[1, 2],
//            [3, 4]],
//           [[5, 6],
//            [7, 8]]
//          ]

Mais exemplos

dynamic_slice

Semântica

Extrai uma fração do operand usando índices iniciais calculados dinamicamente e produz um tensor result. start_indices contém os índices iniciais da fração para cada dimensão sujeita a um possível ajuste, e slice_sizes contém os tamanhos da fatia para cada dimensão. Mais formalmente, result[result_index] = operand[operand_index], em que:

  • adjusted_start_indices = clamp(0, start_indices, shape(operand) - slice_sizes).
  • operand_index = adjusted_start_indices + result_index.

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1), (C2) e (C4)
(I2) start_indices número variado de tensores 0-dimensionais do tipo inteiro (C2) e (C3)
(I3) slice_sizes constante do tensor unidimensional do tipo si64 (C2), (C4) (C5)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1) e (C5)

Restrições

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(slice_sizes) = rank(operand).
  • (C3) same(type(start_indices...)).
  • (C4) 0 <= slice_sizes <= shape(operand).
  • (C5) shape(result) = slice_sizes.

Exemplos

// %operand: [
//            [0, 0, 1, 1],
//            [0, 0, 1, 1],
//            [0, 0, 0, 0],
//            [0, 0, 0, 0]
//           ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_slice"(%operand, %start_indices0, %start_indices1) {
  slice_sizes = dense<[2, 2]> : tensor<2xi64>
} : (tensor<4x4xi32>, tensor<i64>, tensor<i64>) -> tensor<2x2xi32>
// %result: [
//           [1, 1],
//           [1, 1]
//          ]

Mais exemplos

dynamic_update_slice

Semântica

Produz um tensor result que é igual ao tensor operand, exceto que a fração que começa em start_indices é atualizada com os valores em update. Mais formalmente, result[result_index] é definido como:

  • update[update_index] se 0 <= update_index < shape(update), em que:
    • adjusted_start_indices = clamp(0, start_indices, shape(operand) - shape(update)).
    • update_index = result_index - adjusted_start_indices.
  • Caso contrário, operand[result_index].

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1 a C4) (C6)
(I2) update tensor ou tensor quantizado por tensor (C2), (C3) e (C6)
(I3) start_indices número variado de tensores 0-dimensionais do tipo inteiro (C4) e (C5)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1)

Restrições

  • (C1) type(operand) = type(result).
  • (C2) element_type(update) = element_type(operand).
  • (C3) rank(update) = rank(operand).
  • (C4) size(start_indices) = rank(operand).
  • (C5) same(type(start_indices...)).
  • (C6) 0 <= shape(update) <= shape(operand).

Exemplos

// %operand: [
//            [1, 1, 0, 0],
//            [1, 1, 0, 0],
//            [1, 1, 1, 1],
//            [1, 1, 1, 1]
//           ]
// %update: [
//           [1, 1],
//           [1, 1]
//          ]
// %start_indices0: -1
// %start_indices1: 3
%result = "stablehlo.dynamic_update_slice"(%operand, %update, %start_indices0, %start_indices1)
  : (tensor<4x4xi32>, tensor<2x2xi32>, tensor<i64>, tensor<i64>) -> tensor<4x4xi32>
// %result: [
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1],
//           [1, 1, 1, 1]
//          ]

Mais exemplos

exponencial

Semântica

Executa a operação exponencial com elementos no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: exp do IEEE-754.
  • Para números complexos: exponencial complexa.
  • Para tipos quantizados: dequantize_op_quantize(exponential, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.exponential"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[1.0, 2.7182818284590451], [7.3890560989306504, 20.085536923187668]]

Mais exemplos

exponential_minus_one

Semântica

Executa exponencialmente exponencial menos uma operação no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: expm1 do IEEE-754.
  • Para números complexos: exponencial complexa menos um.
  • Para tipos quantizados: dequantize_op_quantize(exponential_minus_one, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [0.0, 1.0]
%result = "stablehlo.exponential_minus_one"(%operand) : (tensor<2xf64>) -> tensor<2xf64>
// %result: [0.0, 1.71828187]

Mais exemplos

fft

Semântica

Executa as transformações de Fourier direta e inversa para entradas/saídas reais e complexas.

fft_type é um destes:

  • FFT: encaminhamento de FFT complexo para complexo.
  • IFFT: FFT inverso de complexo para complexo.
  • RFFT: encaminhamento de FFT real para complexo.
  • IRFFT: FFT real para inverso (ou seja, usa dados complexos e retorna reais).

Mais formalmente, considerando a função fft, que usa tensores unidimensionais de tipos complexos como entrada, produz tensores unidimensionais do mesmo tipo que a saída e calcula a transformada discreta de Fourier:

Para fft_type = FFT, result é definido como o resultado final de uma série de cálculos L, em que L = size(fft_length). Por exemplo, para L = 3:

  • result1[i0, ..., :] = fft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Além disso, considerando a função ifft, que tem a mesma assinatura de tipo e calcula o inverso de fft:

Para fft_type = IFFT, result é definido como o inverso dos cálculos para fft_type = FFT. Por exemplo, para L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = ifft(result2[i0, ..., :]).

Além disso, considerando a função rfft, que usa tensores unidimensionais de tipos de ponto flutuante, produz tensores unidimensionais de tipos complexos da mesma semântica de ponto flutuante e funciona da seguinte maneira:

  • rfft(real_operand) = truncated_result onde
  • complex_operand... = (real_operand..., 0.0).
  • complex_result = fft(complex_operand).
  • truncated_result = complex_result[:(rank(complex_result) / 2 + 1)].

Quando a transformação de Fourier discreta é calculada para operandos reais, os primeiros elementos N/2 + 1 do resultado definem inequivocamente o restante do resultado, de modo que o resultado de rfft é truncado para evitar a computação de elementos redundantes.

Para fft_type = RFFT, result é definido como o resultado final de uma série de cálculos L, em que L = size(fft_length). Por exemplo, para L = 3:

  • result1[i0, ..., :] = rfft(operand[i0, ..., :]).
  • result2[i0, ..., :, iR-1] = fft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :, iR-2, iR-1] = fft(result2[i0, ..., :, iR-2, iR-1]).

Por fim, considerando a função irfft, que tem a mesma assinatura de tipo e calcula o inverso de rfft:

Para fft_type = IRFFT, result é definido como o inverso dos cálculos para fft_type = RFFT. Por exemplo, para L = 3:

  • result1[i0, ..., :, iR-2, iR-1] = ifft(operand[i0, ..., :, iR-2, iR-1]).
  • result2[i0, ..., :, iR-1] = ifft(result1[i0, ..., :, iR-1]).
  • result[i0, ..., :] = irfft(result2[i0, ..., :]).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo (C1), (C2), (C4) (C5)
(I2) fft_type enumeração de FFT, IFFT, RFFT e IRFFT (C2) e (C5)
(I3) fft_length constante do tensor unidimensional do tipo si64 (C1), (C3) e (C4)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo (C2), (C4) (C5)

Restrições

  • (C1) size(fft_length) <= rank(operand).
  • (C2) A relação entre os tipos de elemento operand e result varia:
    • Se fft_type = FFT, element_type(operand) e element_type(result) tiverem o mesmo tipo complexo.
    • Se fft_type = IFFT, element_type(operand) e element_type(result) tiverem o mesmo tipo complexo.
    • Se fft_type = RFFT, element_type(operand) será um tipo de ponto flutuante e element_type(result) será um tipo complexo da mesma semântica de ponto flutuante.
    • Se fft_type = IRFFT, element_type(operand) será um tipo complexo e element_type(result) será um tipo de ponto flutuante da mesma semântica de ponto flutuante.
  • (C3) 1 <= size(fft_length) <= 3.
  • (C4) Se entre operand e result, houver um tensor real de um tipo de ponto flutuante, depois shape(real)[-size(fft_length):] = fft_length.
  • (C5) shape(result) = shape(operand) exceto:
    • Se for fft_type = RFFT, dim(result, -1) = dim(operand, -1) = 0 ? 0 : dim(operand, -1) / 2 + 1.
    • Se for fft_type = IRFFT, dim(operand, -1) = dim(result, -1) = 0 ? 0 : dim(result, -1) / 2 + 1.

Exemplos

// %operand: [(1.0, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0)]
%result = "stablehlo.fft"(%operand) {
  fft_type = #stablehlo<fft_type FFT>,
  fft_length = dense<4> : tensor<1xi64>
} : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
// %result: [(1.0, 0.0), (1.0, 0.0), (1.0, 0.0), (1.0, 0.0)]

floor

Semântica

Executa o elemento mínimo do tensor operand e produz um tensor result. Implementa a operação roundToIntegralTowardNegative da especificação IEEE-754. Para tipos quantizados, executa dequantize_op_quantize(floor, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [-0.8166, -0.2530, 0.2530, 0.8166, 2.0]
%result = "stablehlo.floor"(%operand) : (tensor<5xf32>) -> tensor<5xf32>
// %result: [-1.0, -1.0, 0.0, 0.0, 2.0]

Mais exemplos

reunir

Semântica

Coleta fatias do tensor operand de deslocamentos especificados em start_indices e produz um tensor result.

O diagrama a seguir mostra como os elementos em result são mapeados em elementos em operand usando um exemplo concreto. O diagrama seleciona alguns índices de result de exemplo e explica em detalhes a quais índices operand eles correspondem.

Mais formalmente, result[result_index] = operand[operand_index], em que:

  • batch_dims = [d for d in axes(result) and d not in offset_dims].
  • batch_index = result_index[batch_dims...].
  • start_index é definido como:
    • start_indices[bi0, ..., :, ..., biN], em que bi são elementos individuais em batch_index, e : é inserido no índice de index_vector_dim, se index_vector_dim < rank(start_indices).
    • Caso contrário, [start_indices[batch_index]].
  • Para d_operand em axes(operand),
    • full_start_index[d_operand] = clamp(start_index[d_start], 0, dim(operand, d_operand) - slice_sizes[d_operand]) se d_operand = start_index_map[d_start].
    • Caso contrário, full_start_index[d_operand] = 0.
  • offset_index = result_index[offset_dims...].
  • full_offset_index = [oi0, ..., 0, ..., oiN], em que oi são elementos individuais em offset_index, e 0 é inserido em índices de collapsed_slice_dims.
  • operand_index = full_start_index + full_offset_index

Se indices_are_sorted for true, a implementação poderá presumir que start_indices estão classificados em relação a start_index_map. Caso contrário, o comportamento será indefinido. Mais formalmente, para todos os i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1), (C7), (C10-C12) (C14)
(I2) start_indices Tensor do tipo de número inteiro (C2), (C3) (C13)
(I3) offset_dims constante do tensor unidimensional do tipo si64 (C1), (C4-C5) (C13)
(I4) collapsed_slice_dims constante do tensor unidimensional do tipo si64 (C1), (C6 a C8) e (C13)
(I5) start_index_map constante do tensor unidimensional do tipo si64 (C3), (C9) (C10)
(I6) index_vector_dim constante do tipo si64 (C2), (C3) (C13)
(I7) slice_sizes constante do tensor unidimensional do tipo si64 (C8) (C11 a C13)
(I8) indices_are_sorted constante do tipo i1

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C5) (C13 a C14)

Restrições

  • (C1) rank(operand) = size(offset_dims) + size(collapsed_slice_dims).
  • (C2) 0 <= index_vector_dim <= rank(start_indices).
  • (C3) size(start_index_map) = index_vector_dim < rank(start_indices) ? dim(start_indices, index_vector_dim) : 1.
  • (C4) is_unique(offset_dims) and is_sorted(offset_dims).
  • (C5) 0 <= offset_dims < rank(result).
  • (C6) is_unique(collapsed_slice_dims) and is_sorted(collapsed_slice_dims).
  • (C7) 0 <= collapsed_slice_dims < rank(operand).
  • (C8) slice_sizes[collapsed_slice_dims...] <= 1.
  • (C9) is_unique(start_index_map).
  • (C10) 0 <= start_index_map < rank(operand).
  • (C11) size(slice_sizes) = rank(operand).
  • (C12) 0 <= slice_sizes <= shape(operand).
  • (C13) shape(result) = combine(batch_dim_sizes, offset_dim_sizes), em que:
    • batch_dim_sizes = shape(start_indices), exceto pelo fato de que o tamanho da dimensão start_indices correspondente a index_vector_dim não está incluído.
    • offset_dim_sizes = shape(slice_sizes), exceto pelo fato de que os tamanhos de dimensão em slice_sizes correspondentes a collapsed_slice_dims não estão incluídos.
    • combine coloca batch_dim_sizes em eixos correspondentes a batch_dims e offset_dim_sizes em eixos correspondentes a offset_dims.
  • (C14) element_type(operand) = element_type(result).

Exemplos

// %operand: [
//            [[1, 2], [3, 4], [5, 6], [7, 8]],
//            [[9, 10],[11, 12], [13, 14], [15, 16]],
//            [[17, 18], [19, 20], [21, 22], [23, 24]]
//           ]
// %start_indices: [
//                  [[0, 0], [1, 0], [2, 1]],
//                  [[0, 1], [1, 1], [0, 2]]
//                 ]
%result = "stablehlo.gather"(%operand, %start_indices) {
  dimension_numbers = #stablehlo.gather<
    offset_dims = [2, 3],
    collapsed_slice_dims = [0],
    start_index_map = [1, 0],
    index_vector_dim = 2>,
  slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>,
  indices_are_sorted = false
} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32>
// %result: [
//            [
//              [[1, 2], [3, 4]],
//              [[3, 4], [5, 6]],
//              [[13, 14], [15, 16]]
//            ],
//            [
//              [[9, 10], [11, 12]],
//              [[11, 12], [13, 14]],
//              [[17, 18], [19, 20]]
//            ]
//          ]

Mais exemplos

get_dimension_size

Semântica

Produz o tamanho do dimension especificado do operand. Mais formalmente, result = dim(operand, dimension).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand Tensor (C1)
(I2) dimension constante do tipo si64 (C1)

Saídas

Nome Tipo
result Tensor dimensional 0 do tipo si32

Restrições

  • (C1) 0 <= dimension < rank(operand).

Exemplos

// %operand: [[1, 2, 3], [4, 5, 6]]
%result = "stablehlo.get_dimension_size"(%operand) {
  dimension = 1 : i64
} : (tensor<2x3xi64>) -> tensor<i32>
// %result: 3

Mais exemplos

get_tuple_element

Semântica

Extrai o elemento na posição index da tupla operand e produz um result. Mais formalmente, result = operand[index].

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tuple (C1) e (C2)
(I2) index constante do tipo si32 (C1) e (C2)

Saídas

Nome Tipo Restrições
result qualquer tipo compatível (C2)

Restrições

  • (C1) 0 <= index < size(operand).
  • (C2) type(result) = tuple_element_types(operand)[index].

Exemplos

// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]

Mais exemplos

if

Semântica

Produz a saída da execução de exatamente uma função de true_branch ou false_branch, dependendo do valor de pred. Mais formalmente, result = pred ? true_branch() : false_branch().

Entradas

Rótulo Nome Tipo Restrições
(I1) pred Tensor dimensional 0 do tipo i1
(I2) true_branch função (C1 a C3)
(I3) false_branch função (C1) e (C2)

Saídas

Nome Tipo Restrições
results número variad de tensores, tensores quantizados ou tokens (C3)

Restrições

  • (C1) input_types(true_branch) = input_types(false_branch) = [].
  • (C2) output_types(true_branch) = output_types(false_branch).
  • (C3) type(results...) = output_types(true_branch).

Exemplos

// %result_true_branch: 10
// %result_false_branch: 11
// %pred: true
%result = "stablehlo.if"(%pred) ({
  "stablehlo.return"(%result_true_branch) : (tensor<i32>) -> ()
}, {
  "stablehlo.return"(%result_false_branch) : (tensor<i32>) -> ()
}) : (tensor<i1>) -> tensor<i32>
// %result: 10

Mais exemplos

Imagem

Semântica

Extrai a parte imaginária, por elemento, de operand e produz um tensor result. Mais formalmente, para cada elemento x: imag(x) = is_complex(x) ? imaginary_part(x) : constant(0, element_type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo (C1) e (C2)

Saídas

Nome Tipo Restrições
result tensor do tipo de ponto flutuante (C1) e (C2)

Restrições

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) é definida como:
    • complex_element_type(element_type(operand)) se is_complex(operand).
    • Caso contrário, element_type(operand).

Exemplos

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.imag"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [2.0, 4.0]

Mais exemplos

entrada

Semântica

Lê dados da entrada e produz results.

A semântica de infeed_config é definida pela implementação.

results consistem em valores de payload que vêm primeiro e um token que vem por último. No futuro, planejamos dividir o payload e o token em duas saídas separadas para melhorar a clareza (#670).

Entradas

Rótulo Nome Tipo
(I1) token token
(I2) infeed_config constante do tipo string

Saídas

Nome Tipo Restrições
results número variad de tensores, tensores quantizados ou tokens (C1 a C3)

Restrições

  • (C1) 0 < size(results).
  • (C2) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C3) is_token(type(results[-1])).

Exemplos

// %token: !stablehlo.token
// infeed_queue[0]: [[1, 2], [3, 4]]
// infeed_queue[1]: [[5, 6], [7, 8]]
%results0:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results0#0: [[1, 2], [3, 4]]
%results1:2 = "stablehlo.infeed"(%token) {
  infeed_config = ""
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)
// results1#0: [[5, 6], [7, 8]]

Mais exemplos

Iota

Semântica

Preenche um tensor output com valores em ordem crescente, começando de zero ao longo da dimensão iota_dimension. Mais formalmente,

output[result_index] = constant(is_quantized(output) ? quantize(result_index[iota_dimension], element_type(output)) : result_index[iota_dimension], element_type(output)).

Entradas

Rótulo Nome Tipo Restrições
(I1) iota_dimension si64 (C1)

Saídas

Nome Tipo Restrições
output tensor de número inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) 0 <= iota_dimension < rank(output).

Exemplos

%output = "stablehlo.iota"() {
  iota_dimension = 0 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 0, 0, 0, 0],
//           [1, 1, 1, 1, 1],
//           [2, 2, 2, 2, 2],
//           [3, 3, 3, 3, 3]
//          ]

%output = "stablehlo.iota"() {
  iota_dimension = 1 : i64
} : () -> tensor<4x5xi32>
// %output: [
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4],
//           [0, 1, 2, 3, 4]
//          ]

Mais exemplos

is_finite

Semântica

Executa elementos verificando se o valor em x é finito (ou seja, não é +Inf, -Inf ou NaN) e produz um tensor y. Implementa a operação isFinite da especificação IEEE-754. Para tipos quantizados, o resultado é sempre true.

Entradas

Rótulo Nome Tipo Restrições
(I1) x tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
y tensor de tipo booleano (C1)

Restrições

  • (C1) shape(x) = shape(y).

Exemplos

// Logical values: -Inf, +Inf, NaN, ...
// %x: [0xFFF0000000000000, 0x7FF0000000000000, 0x7FF8000000000000, -10.0, -0.0, 0.0, 10.0]
%y = "stablehlo.is_finite"(%x) : (tensor<7xf64) -> tensor<7xi1>
// %y: [false, false, false, true, true, true, true]

Mais exemplos

log

Semântica

Executa uma operação de logaritmo com elementos no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: log do IEEE-754.
  • Para números complexos: logaritmo complexo.
  • Para tipos quantizados: dequantize_op_quantize(log, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [[1.0, 2.0], [3.0, 4.0]]
%result = "stablehlo.log"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.0, 0.69314718055994529], [1.0986122886681098, 1.3862943611198906]]

Mais exemplos

log_plus_one

Semântica

Executa o logaritmo de elemento e uma operação no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: logp1 do IEEE-754.
  • Para números complexos: logaritmo complexo mais um.
  • Para tipos quantizados: dequantize_op_quantize(log_plus_one, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [0.0, -0.999, 7.0, 6.38905621, 15.0]
%result = "stablehlo.log_plus_one"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [0.0, -6.90776825, 2.07944155, 2.0, 2.77258873]

Mais exemplos

logística

Semântica

Executa a operação logística de elementos no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: division(1, addition(1, exp(-x))) do IEEE-754.
  • Para números complexos: logística complexa.
  • Para tipos quantizados: dequantize_op_quantize(logistic, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [[0.0, 1.0], [2.0, 3.0]]
%result = "stablehlo.logistic"(%operand) : (tensor<2x2xf64>) -> tensor<2x2xf64>
// %result: [[0.5, 0.73105858], [0.88079708, 0.95257413]]

Mais exemplos

mapa

Semântica

Aplica uma função de mapa computation a inputs ao longo do dimensions e produz um tensor result.

Mais formalmente, result[result_index] = computation(inputs...[result_index]). dimensions não é usado no momento e provavelmente será removido no futuro (#487).

Entradas

Rótulo Nome Tipo Restrições
(I1) inputs número variável de tensores ou tensores quantizados por tensor (C1 a C4)
(I2) dimensions constante do tensor unidimensional do tipo si64 (C3)
(I3) computation função (C4)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1) e (C4)

Restrições

  • (C1) shape(inputs...) = shape(result).
  • (C2) 0 < size(inputs) = N.
  • (C3) dimensions = range(rank(inputs[0])).
  • (C4) computation tem o tipo (tensor<E0>, ..., tensor<EN-1>) -> tensor<E'>, em que Ei = element_type(inputs[i]) e E' = element_type(result).

Exemplos

// %input0: [[0, 1], [2, 3]]
// %input1: [[4, 5], [6, 7]]
%result = "stablehlo.map"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %0 : tensor<i64>
}) {
  dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
// %result: [[0, 5], [12, 21]]

Mais exemplos

máximo

Semântica

Executa a operação máxima com elemento nos tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para booleanos: OR lógico.
  • Para números inteiros: máximo de números inteiros.
  • Para dados flutuantes: maximum do IEEE-754.
  • Para números complexos: máximo lexicográfico para o par (real, imaginary). Impor uma ordenação em números complexos envolve semânticas surpreendentes. Portanto, no futuro, planejamos remover o suporte a números complexos para essa operação (#560).
  • Para tipos quantizados:
    • dequantize_op_quantize(maximum, lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor ou tensor quantizado por tensor (C1)
(I2) rhs tensor ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemplos

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.maximum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 8]]

Mais exemplos

mínimo

Semântica

Executa uma operação mínima com elementos nos tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para booleanos: AND lógico.
  • Para números inteiros: mínimo para números inteiros.
  • Para dados flutuantes: minimum do IEEE-754.
  • Para números complexos: mínimo lexicográfico para o par (real, imaginary). Impor uma ordenação em números complexos envolve semânticas surpreendentes. Portanto, no futuro, planejamos remover o suporte a números complexos para essa operação (#560).
  • Para tipos quantizados:
    • dequantize_op_quantize(minimum, lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor ou tensor quantizado por tensor (C1)
(I2) rhs tensor ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemplos

// %lhs: [[1, 2], [7, 8]]
// %rhs: [[5, 6], [3, 4]]
%result = "stablehlo.minimum"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[1, 2], [3, 4]]

Mais exemplos

multiplicar

Semântica

Executa o produto com elementos de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para booleanos: AND lógico.
  • Para números inteiros: multiplicação de números inteiros.
  • Para dados flutuantes: multiplication do IEEE-754.
  • Para números complexos: multiplicação complexa.
  • Para tipos quantizados:
    • dequantize_op_quantize(multiply, lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor ou tensor quantizado por tensor (C1)
(I2) rhs tensor ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 12], [21, 32]]

Mais exemplos

negate

Semântica

Executa a negação com elementos do tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para números inteiros com assinatura: negação de números inteiros.
  • Para números inteiros sem assinatura: bitcast para número inteiro assinado, negação de número inteiro, bitcast de volta para número inteiro não assinado.
  • Para dados flutuantes: negate do IEEE-754.
  • Para números complexos: negação complexa.
  • Para tipos quantizados: dequantize_op_quantize(negate, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// Negation operation with integer Tensors
// %operand: [0, -2]
%result = "stablehlo.negate"(%operand) : (tensor<2xi32>) -> tensor<2xi32>
// %result: [0, 2]

// Negation operation with with complex tensors
// %operand: (2.5, 0.0)
%result = "stablehlo.negate"(%operand) : (tensor<1xcomplex<f32>>) -> tensor<1xcomplex<f32>>
// %result: [-2.5, -0.0]

Mais exemplos

não

Semântica

Executa NOT elemento do tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para booleanos: NOT lógico.
  • Para números inteiros: NOT bit a bit.

Argumentos

Nome Tipo Restrições
operand Tensor do tipo booleano ou inteiro (C1)

Saídas

Nome Tipo Restrições
result Tensor do tipo booleano ou inteiro (C1)

Restrições

  • (C1) type(operand) = type(result).

Exemplos

// Bitwise operation with with integer tensors
// %operand: [[1, 2], [3, 4]]
%result = "stablehlo.not"(%operand) : (tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[-2, -3], [-4, -5]]

// Bitwise operation with with boolean tensors
// %operand: [true, false]
%result = "stablehlo.not"(%operand) : (tensor<2xi1>) -> tensor<2xi1>
// %result: [false, true]

optimization_barrier

Semântica

Garante que as operações que produzem o operand sejam executadas antes de qualquer operação que dependa de result e impede que as transformações do compilador movam operações pela barreira. Fora isso, a operação é uma identidade, ou seja, result = operand.

Argumentos

Nome Tipo Restrições
operand número variável de tensores, tensores quantizados por tensor ou tokens (C1)

Saídas

Nome Tipo Restrições
result número variável de tensores, tensores quantizados por tensor ou tokens (C1)

Restrições

  • (C1) type(operand...) = type(result...).

Exemplos

// %operand0: 0.0
// %operand1: 1.0
%result0, %result1 = "stablehlo.optimization_barrier"(%operand0, %operand1) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// %result0: 0.0
// %result1: 1.0

Mais exemplos

ou

Semântica

Executa OR de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para booleanos: OR lógico.
  • Para números inteiros: OR bit a bit.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor de tipo inteiro ou booleano (C1)
(I2) rhs tensor de tipo inteiro ou booleano (C1)

Saídas

Nome Tipo Restrições
result tensor de tipo inteiro ou booleano (C1)

Restrições

  • (C1) type(lhs) = type(rhs) = type(result).

Exemplos

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 6], [7, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.or"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, true]]

saída

Semântica

Grava inputs na saída e produz um token result.

A semântica de outfeed_config é definida pela implementação.

Entradas

Rótulo Nome Tipo
(I1) inputs número variado de tensores ou tensores quantizados
(I2) token token
(I3) outfeed_config constante do tipo string

Saídas

Nome Tipo
result token

Exemplos

%result = "stablehlo.outfeed"(%inputs0, %token) {
  outfeed_config = ""
} : (tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token

Mais exemplos

bloco

Semântica

Expande operand pelo padding ao redor do tensor, bem como entre os elementos dele com o padding_value especificado.

edge_padding_low e edge_padding_high especificam a quantidade de padding adicionada na parte inferior (ao lado do índice 0) e na parte sofisticada (ao lado do índice mais alto) de cada dimensão, respectivamente. O padding pode ser negativo, já que o valor absoluto do padding negativo indica o número de elementos que vão ser removidos da dimensão especificada.

interior_padding especifica a quantidade de padding adicionada entre dois elementos em cada dimensão, que não podem ser negativos. O preenchimento interno ocorre antes do preenchimento da borda, de modo que o preenchimento da borda negativa remova elementos do operando com padding interno.

Mais formalmente, result[result_index] é definido como:

  • operand[operand_index] se result_index = edge_padding_low + operand_index * (interior_padding + 1).
  • Caso contrário, padding_value.

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1), (C2) e (C4)
(I2) padding_value tensor de 0 dimensional ou tensor quantizado por tensor (C1)
(I3) edge_padding_low constante do tensor unidimensional do tipo si64 (C1) e (C4)
(I4) edge_padding_high constante do tensor unidimensional do tipo si64 (C1) e (C4)
(I5) interior_padding constante do tensor unidimensional do tipo si64 (C2 a C4)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C3 a C6)

Restrições

  • (C1) element_type(operand) = element_type(padding_value) = element_type(result).
  • (C2) size(edge_padding_low) = size(edge_padding_high) = size(interior_padding) = rank(operand).
  • (C3) 0 <= interior_padding.
  • (C4) shape(result) = shape(operand) + edge_padding_low + max(shape(operand) - 1, 0) * interior_padding + edge_padding_high.

Exemplos

// %operand: [
//            [1, 2, 3],
//            [4, 5, 6]
//           ]
// %padding_value: 0
%result = "stablehlo.pad"(%operand, %padding_value) {
  edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
  edge_padding_high = dense<[2, 1]> : tensor<2xi64>,
  interior_padding = dense<[1, 2]> : tensor<2xi64>
} : (tensor<2x3xi32>, tensor<i32>) -> tensor<5x9xi32>
// %result: [
//           [0, 1, 0, 0, 2, 0, 0, 3, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 4, 0, 0, 5, 0, 0, 6, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0],
//           [0, 0, 0, 0, 0, 0, 0, 0, 0]
//          ]

Mais exemplos

partition_id

Semântica

Produz partition_id do processo atual.

Saídas

Nome Tipo
result Tensor dimensional 0 do tipo ui32

Exemplos

%result = "stablehlo.partition_id"() : () -> tensor<ui32>

Mais exemplos

Popcnt

Semântica

Executa a contagem elemento-chave do número de bits definidos no tensor operand e produz um tensor result.

Entradas

Rótulo Nome Tipo Restrições
(I1) operand Tensor do tipo de número inteiro (C1)

Saídas

Nome Tipo Restrições
result Tensor do tipo de número inteiro (C1)

Restrições

  • (C1) type(operand) = type(result).

Exemplos

// %operand: [0, 1, 2, 127]
%result = "stablehlo.popcnt"(%operand) : (tensor<4xi64>) -> tensor<4xi64>
// %result: [0, 1, 1, 7]

Mais exemplos

potência

Semântica

Executa a exponenciação baseada em elementos do tensor lhs por rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para números inteiros: exponenciação de números inteiros.
  • Para dados flutuantes: pow do IEEE-754.
  • Para números complexos: exponenciação complexa.
  • Para tipos quantizados: dequantize_op_quantize(power, lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)
(I2) rhs tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %lhs: [-2.0, -0.0, -36.0, 5.0, 3.0, 10000.0]
// %rhs: [2.0, 2.0, 1.1, 2.0, -1.0, 10.0]
%result = "stablehlo.power"(%lhs, %rhs) : (tensor<6xf64>, tensor<6xf64>) -> tensor<6xf64>
// %result: [4.0, 0.0, -nan, 25.0, 0.333333343, inf]

Mais exemplos

real

Semântica

Extrai a parte real, com elementos, de operand e produz um tensor result. Mais formalmente, para cada elemento x: real(x) = is_complex(x) ? real_part(x) : x.

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo (C1) e (C2)

Saídas

Nome Tipo Restrições
result tensor do tipo de ponto flutuante (C1) e (C2)

Restrições

  • (C1) shape(result) = shape(operand).
  • (C2) element_type(result) é definida como:
    • complex_element_type(element_type(operand)) se is_complex(operand).
    • Caso contrário, element_type(operand).

Exemplos

// %operand: [(1.0, 2.0), (3.0, 4.0)]
%result = "stablehlo.real"(%operand) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// %result: [1.0, 3.0]

Mais exemplos

recv

Semântica

Recebe dados de um canal com channel_id e produz results.

Se is_host_transfer for true, a operação transferirá dados do host. Caso contrário, ele transfere dados de outro dispositivo. Isso significa que é definido pela implementação. Essa flag duplica as informações fornecidas em channel_type. Portanto, no futuro, planejamos manter apenas uma delas (#666).

results consistem em valores de payload que vêm primeiro e um token que vem por último. No futuro, planejamos dividir o payload e o token em duas saídas separadas para melhorar a clareza (#670).

Entradas

Rótulo Nome Tipo Restrições
(I1) token token (C4)
(I2) channel_id constante do tipo si64
(I3) channel_type enumeração de DEVICE_TO_DEVICE e HOST_TO_DEVICE (C1)
(I4) is_host_transfer constante do tipo i1 (C1)

Saídas

Nome Tipo Restrições
results número variad de tensores, tensores quantizados ou tokens (C2 a C4)

Restrições

  • (C1) channel_type é definida como:
    • HOST_TO_DEVICE se is_host_transfer = true,
    • Caso contrário, DEVICE_TO_DEVICE.
  • (C2) 0 < size(results).
  • (C3) is_empty(result[:-1]) ou is_tensor(type(results[:-1])).
  • (C4) is_token(type(results[-1])).

Exemplos

%results0, %results1 = "stablehlo.recv"(%token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 3>,
  is_host_transfer = true
} : (!stablehlo.token) -> (tensor<2x2xi64>, !stablehlo.token)

Mais exemplos

reduce

Semântica

Aplica uma função de redução body a inputs e init_values ao longo da dimensions e produz tensores results.

A ordem de reduções é definida pela implementação, o que significa que body e init_values precisam formar um monoide para garantir que a operação produza os mesmos resultados para todas as entradas em todas as implementações. No entanto, essa condição não se aplica a muitas reduções conhecidas. Por exemplo, a adição de ponto flutuante para body e zero para init_values não formam um monoide porque a adição de ponto flutuante não é associativa.

Mais formalmente, results...[j0, ..., jR-1] = reduce(input_slices_converted), em que:

  • input_slices = inputs...[j0, ..., :, ..., jR-1], em que : são inseridos em dimensions.
  • input_slices_converted = to_destination_type(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...).
  • init_values_converted = to_destination_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...).
  • reduce(input_slices_converted) = exec(schedule) para alguma árvore binária schedule em que:
    • exec(node) = body(exec(node.left), exec(node.right)).
    • exec(leaf) = leaf.value.
  • schedule é uma árvore binária completa definida pela implementação cuja travessia em ordem consiste em:
    • Valores input_slices_converted...[index], para todos os index em index_space(input_slices_converted) na ordem lexicográfica crescente de index.
    • Intercalada com uma quantidade definida de init_values_converted em posições definidas pela implementação.

Entradas

Rótulo Nome Tipo Restrições
(I1) inputs número variável de tensores ou tensores quantizados por tensor (C1 a C4), (C6) e (C7)
(I2) init_values número variado de tensores 0-dimensionais ou tensores quantizados por tensor (C2) e (C3)
(I3) dimensions constante do tensor unidimensional do tipo si64 (C4), (C5) e (C7)
(I4) body função (C6)

Saídas

Nome Tipo Restrições
results número variável de tensores ou tensores quantizados por tensor (C3), (C7) (C8)

Restrições

  • (C1) same(shape(inputs...)).
  • (C2) element_type(inputs...) = element_type(init_values...).
  • (C3) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C4) 0 <= dimensions < rank(inputs[0]).
  • (C5) is_unique(dimensions).
  • (C6) body tem o tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), em que is_promotable(element_type(inputs[i]), Ei).
  • (C7) shape(results...) = shape(inputs...), exceto pelo fato de que os tamanhos de dimensão de inputs... correspondentes a dimensions não são incluídos.
  • (C8) element_type(results[i]) = Ei para todos os i em [0,N).

Exemplos

// %input = [[0, 1, 2, 3, 4, 5]]
// %init_value = 0
%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
// %result = [15]

Mais exemplos

reduce_precision

Semântica

Executa a conversão com elementos de operand para outro tipo de ponto flutuante que usa exponent_bits e mantissa_bits, de volta ao tipo de ponto flutuante original, e produz um tensor output.

Mais formalmente:

  • Os bits de mantissa do valor original são atualizados para arredondar o valor original para o valor mais próximo representável com mantissa_bits usando a semântica roundToIntegralTiesToEven.
  • Em seguida, se mantissa_bits for menor que o número de bits de mantissa do valor original, os bits de mantissa serão truncados para mantissa_bits.
  • Em seguida, se os bits expoentes do resultado intermediário não se encaixarem no intervalo fornecido por exponent_bits, o resultado intermediário vai ultrapassar o infinito usando o sinal original ou ir para zero com o sinal original.
  • Para tipos quantizados, executa dequantize_op_quantize( lambda operand: reduce_precision(operand, exponent_bits, mantissa_bits), operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)
(I2) exponent_bits constante do tipo si32 (C2)
(I3) mantissa_bits constante do tipo si32 (C3)

Saídas

Nome Tipo Restrições
output tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(output).
  • (C2) 1 <= exponent_bits.
  • (C3) 0 <= mantissa_bits.

Exemplos

// Logical values: +Inf, NaN, +Denormal, 0.0, 65519.0, 65520.0
// %operand: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0x0000000000000001, 0.0, 65519.0, 65520.0]
%output = "stablehlo.reduce_precision"(%operand) {
  exponent_bits = 5 : i32,
  mantissa_bits = 10 : i32
} : (tensor<6xf64>) -> tensor<6xf64>
// Logical values: +Inf, NaN, 0.0, 0.0, 65504.0, +Inf
// %output: [0x7FF0000000000000, 0x7FFFFFFFFFFFFFFF, 0.0, 0.0, 65504.0, 0x7FF0000000000000]

Mais exemplos

reduce_scatter

Semântica

Dentro de cada grupo de processos na grade de processos do StableHLO, realiza a redução, usando computations, sobre os valores do tensor operand de cada processo, divide o resultado da redução ao longo de scatter_dimension em partes e dispersa as partes divididas entre os processos para produzir o result.

A operação divide a grade do processo StableHLO em process_groups, que é definida da seguinte maneira:

  • cross_replica(replica_groups) se channel_id <= 0 and use_global_device_ids = false.
  • cross_replica_and_partition(replica_groups) se channel_id > 0 and use_global_device_ids = false.
  • flattened_ids(replica_groups) se channel_id > 0 and use_global_device_ids = true.

Em seguida, em cada process_group:

  • reduced_value = all_reduce(operand, replica_groups, channel_id, use_global_device_ids, computation).
  • parts@sender = split(reduced_value@sender, dim(process_groups, 1), scatter_dimension).
  • result@receiver = parts@sender[receiver_index] para todos os sender em process_group, em que receiver_index = process_group.index(receiver).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1), (C2), (C7) (C8)
(I2) scatter_dimension constante do tipo si64 (C1), (C2) e (C8)
(I3) replica_groups constante do tensor bidimensional do tipo si64 (C3 a C5)
(I4) channel_id constante do tipo si64 (C6)
(I5) use_global_device_ids constante do tipo i1 (C6)
(I6) computation função (C7)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C8 a C9)

Restrições

  • (C1) dim(operand, scatter_dimension) % dim(process_groups, 1) = 0.
  • (C2) 0 <= scatter_dimension < rank(operand).
  • (C3) is_unique(replica_groups).
  • (C4) size(replica_groups) é definida como:
    • num_replicas se cross_replica for usado.
    • num_replicas se cross_replica_and_partition for usado.
    • num_processes se flattened_ids for usado.
  • (C5) 0 <= replica_groups < size(replica_groups).
  • (C6) Se use_global_device_ids = true, então channel_id > 0.
  • (C7) computation tem o tipo (tensor<E>, tensor<E>) -> (tensor<E>), em que is_promotable(element_type(operand), E).
  • (C8) shape(result) = shape(operand) exceto:
    • dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1).
  • (C9) element_type(result) = E.

Exemplos

// num_replicas: 2
// num_partitions: 1
// %operand@(0, 0): [[1, 2, 3, 4],
//                   [5, 6, 7, 8]]
// %operand@(1, 0): [[9, 10, 11, 12],
//                   [13, 14, 15, 16]]
%result = "stablehlo.reduce_scatter"(%operand) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
  "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension = 1 : i64,
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
//
// %result@(0, 0): [[10, 12],
//                  [18, 20]]
// %result@(1, 0): [[14, 16],
//                  [22, 24]]

Mais exemplos

reduce_window

Semântica

Aplica uma função de redução body a janelas de inputs e init_values e produz results.

O diagrama a seguir mostra como os elementos em results... são calculados a partir de inputs... usando um exemplo concreto.

Mais formalmente, results...[result_index] = reduce(windows, init_values, axes(inputs...), body) (consulte reduzir), em que:

  • padded_inputs = pad(inputs..., init_values..., padding[:, 0], padding[:, 1], base_dilations - 1).
  • window_start = result_index * window_strides.
  • window_end = window_start + (window_dimensions - 1) * window_dilations + 1.
  • windows = slice(padded_inputs..., window_start, window_end, window_dilations).

Entradas

Rótulo Nome Tipo Restrições
(I1) inputs número variável de tensores ou tensores quantizados por tensor (C1 a C4), (C6), (C8), (C10), (C12), (C13) (C15)
(I2) init_values número variado de tensores 0-dimensionais ou tensores quantizados por tensor (C1) e (C13)
(I3) window_dimensions constante do tensor unidimensional do tipo si64 (C4), (C5) (C15)
(I4) window_strides constante do tensor unidimensional do tipo si64 (C6), (C7) (C15)
(I5) base_dilations constante do tensor unidimensional do tipo si64 (C8), (C9) (C15)
(I6) window_dilations constante do tensor unidimensional do tipo si64 (C10), (C11) (C15)
(I7) padding constante do tensor bidimensional do tipo si64 (C12) e (C15)
(I8) body função (C13)

Saídas

Nome Tipo Restrições
results número variável de tensores ou tensores quantizados por tensor (C1) e (C14 a C16).

Restrições

  • (C1) 0 < size(inputs) = size(init_values) = size(results) = N.
  • (C2) same(shape(inputs...)).
  • (C3) element_type(inputs...) = element_type(init_values...).
  • (C4) size(window_dimensions) = rank(inputs[0]).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(inputs[0]).
  • (C7) 0 < window_strides.
  • (C8) size(base_dilations) = rank(inputs[0]).
  • (C9) 0 < base_dilations.
  • (C10) size(window_dilations) = rank(inputs[0]).
  • (C11) 0 < window_dilations.
  • (C12) shape(padding) = [rank(inputs[0]), 2].
  • (C13) body tem o tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), em que is_promotable(element_type(inputs[i]), Ei).
  • (C14) same(shape(results...)).
  • (C15) shape(results[0]) = num_windows em que:
    • dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1.
    • padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1].
    • dilated_window_shape = (window_dimensions - 1) * window_dilations + 1.
    • is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1.
  • (C16) element_type(results[i]) = Ei para todos os i em [0,N).

Exemplos

// %input = [[1, 2], [3, 4], [5, 6]]
// %init_value = 0
%result = "stablehlo.reduce_window"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[2, 1]> : tensor<2xi64>,
  window_strides = dense<[4, 1]> : tensor<2xi64>,
  base_dilations = dense<[2, 1]> : tensor<2xi64>,
  window_dilations = dense<[3, 1]> : tensor<2xi64>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
// %result = [[0, 0], [3, 4]]

Mais exemplos

restante

Semântica

Executa o resto com elementos dos tensores de dividendo lhs e divisor rhs e produz um tensor result.

Mais formalmente, o sinal do resultado é tirado do dividendo, e o valor absoluto do resultado é sempre menor que o valor absoluto do divisor. O restante é calculado como lhs - d * rhs, em que d é fornecido por:

  • Para números inteiros: stablehlo.divide(lhs, rhs).
  • Para pontos flutuantes: division(lhs, rhs) do IEEE-754 com atributo de arredondamento roundTowardZero.
  • Para números complexos: a definir (#997).
  • Para tipos quantizados:
    • dequantize_op_quantize(remainder, lhs, rhs, type(result)).

Para tipos de elementos de ponto flutuante, essa operação é diferente da operação remainder da especificação IEEE-754, em que d é um valor integral mais próximo do valor exato de lhs/rhs com vínculos a pares.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor de número inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)
(I2) rhs tensor de número inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de número inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %lhs: [17, -17, 17, -17]
// %rhs: [3, 3, -3, -3]
%result = "stablehlo.remainder"(%lhs, %rhs) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi64>
// %result: [2, -2, 2, -2]

Mais exemplos

replica_id

Semântica

Produz replica_id do processo atual.

Saídas

Nome Tipo
result Tensor dimensional 0 do tipo ui32

Exemplos

%result = "stablehlo.replica_id"() : () -> tensor<ui32>

Mais exemplos

remodelar

Semântica

Executa a remodelação do tensor operand para um tensor result. Conceitualmente, isso significa manter a mesma representação canônica, mas possivelmente alterar a forma, por exemplo, de tensor<2x3xf32> para tensor<3x2xf32> ou tensor<6xf32>.

Mais formalmente, result[result_index] = operand[operand_index], em que result_index e operand_index têm a mesma posição na ordem lexicográfica de index_space(result) e index_space(operand).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado (C1 a C3)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado (C1 a C3)

Restrições

  • (C1) element_type(result) é fornecido por:
    • element_type(operand), se !is_per_axis_quantized(operand).
    • element_type(operand), exceto que quantization_dimension(operand) e quantization_dimension(result) podem ser diferentes, caso contrário.
  • (C2) size(operand) = size(result).
  • (C3) Se is_per_axis_quantized(operand):
    • reduce(dims(operand, [0, 1, ..., quantization_dimension(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [0, 1, ..., quantization_dimension(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).
    • dim(operand, quantization_dimension(operand)) = dim(result, quantization_dimension(result)).
    • reduce(dims(operand, [quantization_dimension(operand) + 1, ..., rank(operand) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y) = reduce(dims(result, [quantization_dimension(result) + 1, ..., rank(result) - 1]), init_values=1, dimensions=[0], body=lambda x, y: x * y).

Exemplos

// %operand: [[1, 2, 3], [4, 5, 6]]]
%result = "stablehlo.reshape"(%operand) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// %result: [[1, 2], [3, 4], [5, 6]]

Mais exemplos

anular

Semântica

Inverte a ordem dos elementos na operand ao longo do dimensions especificado e produz um tensor result. Mais formalmente, result[result_index] = operand[operand_index], em que:

  • operand_index[d] = dim(result, d) - result_index[d] - 1 se d em dimensions.
  • Caso contrário, operand_index[d] = result_index[d].

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1) e (C3)
(I2) dimensions constante do tensor unidimensional do tipo si64 (C2) e (C3)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1) e (C3)

Restrições

  • (C1) type(operand) = type(result).
  • (C2) is_unique(dimensions).
  • (C3) 0 <= dimensions < rank(result).

Exemplos

// %operand = [[1, 2], [3, 4], [5, 6]]
%result = "stablehlo.reverse"(%operand) {
  dimensions = dense<1> : tensor<1xi64>
} : (tensor<3x2xi32>) -> tensor<3x2xi32>
// %result: [[2, 1], [4, 3], [6, 5]]

Mais exemplos

rm

Semântica

Gera números aleatórios usando o algoritmo rng_distribution e produz um tensor result de uma determinada forma shape.

Se for rng_distribution = UNIFORM, os números aleatórios serão gerados seguindo a distribuição uniforme no intervalo [a, b). Se a >= b, o comportamento é indefinido.

Se for rng_distribution = NORMAL, os números aleatórios serão gerados seguindo a distribuição normal com média = a e desvio padrão = b. Se for b < 0, o comportamento será indefinido.

A maneira exata como os números aleatórios são gerados é definida pela implementação. Por exemplo, eles podem ou não ser determinísticos e podem ou não usar o estado oculto.

Em conversas com muitas partes interessadas, essa operação surgiu como sendo efetivamente descontinuada. Portanto, no futuro, planejamos removê-la (#597).

Entradas

Rótulo Nome Tipo Restrições
(I1) a Tensor dimensional de tipo inteiro, booleano ou ponto flutuante (C1) e (C2)
(I2) b Tensor dimensional de tipo inteiro, booleano ou ponto flutuante (C1) e (C2)
(I3) shape constante do tensor unidimensional do tipo si64 (C3)
(I4) rng_distribution enumeração de UNIFORM e NORMAL (C2)

Saídas

Nome Tipo Restrições
result Tensor do tipo inteiro, booleano ou ponto flutuante (C1 a C3)

Restrições

  • (C1) element_type(a) = element_type(b) = element_type(result).
  • (C2) Se rng_distribution = NORMAL, então is_float(a).
  • (C3) shape(result) = shape.

Exemplos

// %a = 0
// %b = 2
// %shape = [3, 3]
%result = "stablehlo.rng"(%a, %b, %shape) {
  rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<i32>, tensor<i32>, tensor<2xi64>) -> tensor<3x3xi32>
// %result: [
//           [1, 0, 1],
//           [1, 1, 1],
//           [0, 0, 0]
//          ]

rng_bit_generator

Semântica

Retorna um output preenchido com bits aleatórios uniformes e um estado de saída atualizado output_state usando o algoritmo gerador de números pseudoaleatórios rng_algorithm que tem um estado inicial initial_state. A saída é garantida como uma função determinística de initial_state, mas não é garantido que ela seja determinística entre as implementações.

rng_algorithm é um destes:

  • DEFAULT: algoritmo definido pela implementação.
  • THREE_FRY: variante definida pela implementação do algoritmo Threefry.*
  • PHILOX: variante definida pela implementação do algoritmo Philox.*

* Consulte: Salmon et al. SC 2011. Números aleatórios paralelos: 1, 2, 3.

Entradas

Rótulo Nome Tipo Restrições
(I1) rng_algorithm enumeração de DEFAULT, THREE_FRY e PHILOX (C2)
(I2) initial_state Tensor unidimensional do tipo ui64 (C1) e (C2)

Saídas

Nome Tipo Restrições
output_state Tensor unidimensional do tipo ui64 (C1)
output tensor de tipo inteiro ou de ponto flutuante

Restrições

  • (C1) type(initial_state) = type(output_state).
  • (C2) size(initial_state) é definida como:
    • definido pela implementação se rng_algorithm = DEFAULT.
    • 2 se rng_algorithm = THREE_FRY.
    • 2 ou 3 se rng_algorithm = PHILOX.

Exemplos

// %initial_state: [1, 2]
%output_state, %output = "stablehlo.rng_bit_generator"(%initial_state) {
  rng_algorithm = #stablehlo<rng_algorithm THREE_FRY>
} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
// %output_state: [1, 6]
// %output: [
//           [9236835810183407956, 16087790271692313299],
//           [18212823393184779219, 2658481902456610144]
//          ]

round_nearest_afz

Semântica

Executa arredondamento por elemento em direção ao número inteiro mais próximo, separando os vínculos de zero, no tensor operand, e produz um tensor result. Implementa a operação roundToIntegralTiesToAway da especificação IEEE-754. Para tipos quantizados, executa dequantize_op_quantize(round_nearest_afz, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_afz"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-3.0, 0.0, 1.0, 1.0, 3.0]

Mais exemplos

round_nearest_even

Semântica

Executa arredondamento por elemento em direção ao número inteiro mais próximo, rompendo vínculos com o número inteiro par, no tensor operand, e produz um tensor result. Implementa a operação roundToIntegralTiesToEven da especificação IEEE-754. Para tipos quantizados, executa dequantize_op_quantize(round_nearest_even, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de tipo de ponto flutuante ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
%result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// %result: [-2.0, 0.0, 0.0, 1.0, 2.0]

Mais exemplos

rsqrt

Semântica

Executa uma operação de raiz quadrada recíproca de elementos no tensor operand e gera um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: rSqrt do IEEE-754.
  • Para números complexos: raiz quadrada recíproca complexa.
  • Para tipos quantizados: dequantize_op_quantize(rsqrt, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [[1.0, 4.0], [9.0, 25.0]]
%result = "stablehlo.rsqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[1.0, 0.5], [0.33333343, 0.2]]

Mais exemplos

scatter

Semântica

Produz tensores results, que são iguais a inputs, mas várias frações especificadas por scatter_indices são atualizadas com os valores updates usando update_computation.

O diagrama a seguir mostra como os elementos em updates... são mapeados em elementos em results... usando um exemplo concreto. O diagrama seleciona alguns exemplos de índices de updates... e explica em detalhes a quais índices de results... eles correspondem.

Mais formalmente, para todos os update_index em index_space(updates[0]):

  • update_scatter_dims = [d for d in axes(updates[0]) and d not in update_window_dims].
  • update_scatter_index = update_index[update_scatter_dims...].
  • start_index é definido como:
    • scatter_indices[si0, ..., :, ..., siN], em que si são elementos individuais em update_scatter_index, e : é inserido no índice index_vector_dim, se index_vector_dim < rank(scatter_indices).
    • Caso contrário, [scatter_indices[update_scatter_index]].
  • Para d_input em axes(inputs[0]),
    • full_start_index[d_input] = start_index[d_start] se d_input = scatter_dims_to_operand_dims[d_start].
    • Caso contrário, full_start_index[d_input] = 0.
  • update_window_index = update_index[update_window_dims...].
  • full_window_index = [wi0, ..., 0, ..., wiN], em que wi são elementos individuais em update_window_index, e 0 é inserido em índices de inserted_window_dims.
  • result_index = full_start_index + full_window_index.

Sendo assim, results = exec(schedule, inputs), em que:

  • schedule é uma permutação definida pela implementação de index_space(updates[0]).
  • exec([update_index, ...], results) = exec([...], updated_results) em que:
    • Se result_index estiver dentro dos limites para shape(results...)
    • updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )
    • updated_values = update_computation(results...[result_index], updates_converted)
    • updated_results é uma cópia de results com results...[result_index] definido como updated_values....
    • Como alternativa, faça o seguinte:
    • updated_results = results.
  • exec([], results) = results.

Se indices_are_sorted for true, a implementação poderá presumir que scatter_indices estão classificados em relação a scatter_dims_to_operand_dims. Caso contrário, o comportamento será indefinido. Mais formalmente, para todos os i1 < i2 de indices(result), full_start_index(i1) <= full_start_index(i2).

Se unique_indices for true, a implementação vai poder presumir que todos os índices de result_index distribuídos são exclusivos. Se unique_indices for true, mas os índices sendo dispersos não forem exclusivos, o comportamento será indefinido.

Entradas

Rótulo Nome Tipo Restrições
(I1) inputs número variável de tensores ou tensores quantizados por tensor (C1), (C2), (C4-C6), (C10), (C13) (C15-C16)
(I2) scatter_indices Tensor do tipo de número inteiro (C4), (C11) (C14)
(I3) updates número variável de tensores ou tensores quantizados por tensor (C3 a C6)
(I4) update_window_dims constante do tensor unidimensional do tipo si64 (C2), (C4), (C7) (C8)
(I5) inserted_window_dims constante do tensor unidimensional do tipo si64 (C2), (C4), (C9) (C10)
(I6) scatter_dims_to_operand_dims constante do tensor unidimensional do tipo si64 (C11 a C13)
(I7) index_vector_dim constante do tipo si64 (C4), (C11) (C14)
(I8) indices_are_sorted constante do tipo i1
(I9) unique_indices constante do tipo i1
(I10) update_computation função (C15)

Saídas

Nome Tipo Restrições
results número variável de tensores ou tensores quantizados por tensor (C15 a C17)

Restrições

  • (C1) same(shape(inputs...)).
  • (C2) rank(inputs[0]) = size(update_window_dims) + size(inserted_window_dims).
  • (C3) same(shape(updates...)).
  • (C4) shape(updates[0]) = combine(update_scatter_dim_sizes, update_window_dim_sizes), em que:
    • update_scatter_dim_sizes = shape(scatter_indices), exceto pelo fato de que o tamanho da dimensão de scatter_indices correspondente a index_vector_dim não está incluído.
    • update_window_dim_sizes <= shape(inputs[0]), exceto pelo fato de que os tamanhos de dimensão em inputs[0] correspondentes a inserted_window_dims não estão incluídos.
    • combine coloca update_scatter_dim_sizes em eixos correspondentes a update_scatter_dims e update_window_dim_sizes em eixos correspondentes a update_window_dims.
  • (C5) 0 < size(inputs) = size(updates) = N.
  • (C6) element_type(updates...) = element_type(inputs...).
  • (C7) is_unique(update_window_dims) and is_sorted(update_window_dims).
  • (C8) 0 <= update_window_dims < rank(updates[0]).
  • (C9) is_unique(inserted_window_dims) and is_sorted(update_window_dims).
  • (C10) 0 <= inserted_window_dims < rank(inputs[0]).
  • (C11) size(scatter_dims_to_operand_dims) = index_vector_dim < rank(scatter_indices) ? dim(scatter_indices, index_vector_dim) : 1.
  • (C12) is_unique(scatter_dims_to_operand_dims).
  • (C13) 0 <= scatter_dims_to_operand_dims < rank(inputs[0]).
  • (C14) 0 <= index_vector_dim <= rank(scatter_indices).
  • (C15) update_computation tem o tipo (tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>), em que is_promotable(element_type(inputs[i]), Ei).
  • (C16) shape(inputs...) = shape(results...).
  • (C17) element_type(results[i]) = Ei para todos os i em [0,N).

Exemplos

// %input: [
//          [[1, 2], [3, 4], [5, 6], [7, 8]],
//          [[9, 10], [11, 12], [13, 14], [15, 16]],
//          [[17, 18], [19, 20], [21, 22], [23, 24]]
//         ]
// %scatter_indices: [[[0, 2], [1, 0], [2, 1]], [[0, 1], [1, 0], [0, 9]]]
// %update: [
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]],
//           [[[1, 1], [1, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]]
//          ]
%result = "stablehlo.scatter"(%input, %scatter_indices, %update) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  scatter_dimension_numbers = #stablehlo.scatter<
    update_window_dims = [2, 3],
    inserted_window_dims = [0],
    scatter_dims_to_operand_dims = [1, 0],
    index_vector_dim = 2>,
  indices_are_sorted = false,
  unique_indices = false
} : (tensor<3x4x2xi64>, tensor<2x3x2xi64>, tensor<2x3x2x2xi64>) -> tensor<3x4x2xi64>
// %result: [
//           [[1, 2], [5, 6], [7, 8], [7, 8]],
//           [[10, 11], [12, 13], [14, 15], [16, 17]],
//           [[18, 19], [20, 21], [21, 22], [23, 24]]
//          ]

Mais exemplos

select

Semântica

Produz um tensor result em que cada elemento é selecionado do tensor on_true ou on_false com base no valor do elemento correspondente de pred. Mais formalmente, result[result_index] = pred_element ? on_true[result_index] : on_false[result_index], em que pred_element = rank(pred) = 0 ? pred[] : pred[result_index]. Para tipos quantizados, executa dequantize_select_quantize(pred, on_true, on_false, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) pred tensor do tipo i1 (C1)
(I2) on_true tensor ou tensor quantizado por tensor (C1 a C2)
(I3) on_false tensor ou tensor quantizado por tensor (C2)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C2)

Restrições

  • (C1) rank(pred) = 0 or shape(pred) = shape(on_true).
  • (C2) baseline_type(on_true) = baseline_type(on_false) = baseline_type(result).

Exemplos

// %pred: [[false, true], [true, false]]
// %on_true: [[1, 2], [3, 4]]
// %on_false: [[5, 6], [7, 8]]
%result = "stablehlo.select"(%pred, %on_true, %on_false) : (tensor<2x2xi1>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[5, 2], [3, 8]]

Mais exemplos

select_and_scatter

Semântica

Dispersa os valores do tensor source usando scatter com base no resultado de reduce_window do tensor input usando select e produz um tensor result.

O diagrama a seguir mostra como os elementos em result são calculados a partir de operand e source usando um exemplo concreto.

Mais formalmente:

  • selected_values = reduce_window_without_init(...) com as seguintes entradas:

    • "inputs" = [opera].
    • window_dimensions, window_strides e padding, que são usados no estado em que se encontram.
    • base_dilations = windows_dilations = 1.
    • body é definido como:
    def body(arg0: tensor<E>, arg1: tensor<E>) -> tensor<E>:
      return select(arg0, arg1) ? arg0 : arg1;
    

    em que E = element_type(operand) e reduce_window_without_init funcionam exatamente como reduce_window, exceto que o schedule do reduce subjacente (consulte reduzir) não inclui valores init. Atualmente, não está especificado o que acontece se a janela correspondente não tiver valores (#731).

  • result[result_index] = reduce([source_values], [init_value], [0], scatter), em que:

    • source_values = [source[source_index] for source_index in source_indices].
    • selected_index(source_index) = operand_index se selected_values[source_index] tiver o elemento operand de operand_index.
    • source_indices = [source_index for source_index in indices(source) if selected_index(source_index) = result_index].

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1 a C4), (C6) e (C8 a C11)
(I2) source tensor ou tensor quantizado por tensor (C1) e (C2)
(I3) init_value tensor de 0 dimensional ou tensor quantizado por tensor (C3)
(I4) window_dimensions constante do tensor unidimensional do tipo si64 (C2), (C4) (C5)
(I5) window_strides constante do tensor unidimensional do tipo si64 (C2), (C6) (C7)
(I6) padding constante do tensor bidimensional do tipo si64 (C2) e (C8)
(I7) select função (C9)
(I8) scatter função (C10)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C11 a C12)

Restrições

  • (C1) element_type(operand) = element_type(source).
  • (C2) shape(source) = num_windows, em que:
    • padded_operand_shape = padding[:, 0] + shape(operand) + padding[:, 1].
    • is_empty_window = padded_operand_shape = 0 || window_dimensions > padded_operand_shape.
    • num_windows = is_empty_window ? 0 : floor((padded_operand_shape - window_dimensions) / window_strides) + 1.
  • (C3) element_type(init_value) = element_type(operand).
  • (C4) size(window_dimensions) = rank(operand).
  • (C5) 0 < window_dimensions.
  • (C6) size(window_strides) = rank(operand).
  • (C7) 0 < window_strides.
  • (C8) shape(padding) = [rank(operand), 2].
  • (C9) select tem o tipo (tensor<E>, tensor<E>) -> tensor<i1>, em que E = element_type(operand).
  • (C10) scatter tem o tipo (tensor<E>, tensor<E>) -> tensor<E>, em que is_promotable(element_type(operand), E).
  • (C11) shape(operand) = shape(result).
  • (C12) element_type(result) = E.

Exemplos

// %operand: [[1, 5], [2, 5], [3, 6], [4, 4]]
// %source: [[5, 6], [7, 8]]
// %init_value: 0
%result = "stablehlo.select_and_scatter"(%operand, %source, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GE>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%0) : (tensor<i1>) -> ()
}, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
  window_dimensions = dense<[3, 1]> : tensor<2xi64>,
  window_strides = dense<[2, 1]> : tensor<2xi64>,
  padding = dense<[[0, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<4x2xi64>, tensor<2x2xi64>, tensor<i64>) -> tensor<4x2xi64>
// %result: [[0, 0], [0, 0], [5, 14], [7, 0]]

Mais exemplos

send

Semântica

Envia inputs para um canal channel_id e produz um token result.

Se is_host_transfer for true, a operação transferirá dados para o host. Caso contrário, ele vai transferir dados para outro dispositivo. Isso significa que é definido pela implementação. Essa flag duplica as informações fornecidas em channel_type. Portanto, no futuro, planejamos manter apenas uma delas (#666).

Entradas

Rótulo Nome Tipo Restrições
(I1) inputs número variado de tensores ou tensores quantizados
(I2) token token
(I3) channel_id constante do tipo si64
(I4) channel_type enumeração de DEVICE_TO_DEVICE e DEVICE_TO_HOST (C1)
(I5) is_host_transfer constante do tipo i1 (C1)

Saídas

Nome Tipo
result token

Restrições

  • (C1) channel_type é definida como:
    • DEVICE_TO_HOST se is_host_transfer = true,
    • Caso contrário, DEVICE_TO_DEVICE.

Exemplos

%result = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 2>,
  is_host_transfer = true
} : (tensor<2x2xi64>, !stablehlo.token) -> !stablehlo.token

Mais exemplos

shift_left

Semântica

Executa o elemento de deslocamento para a esquerda no tensor lhs por número de rhs de bits e produz um tensor result.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs Tensor do tipo de número inteiro (C1)
(I2) rhs Tensor do tipo de número inteiro (C1)

Saídas

Nome Tipo Restrições
result Tensor do tipo de número inteiro (C1)

Restrições

  • (C1) type(lhs) = type(rhs) = type(result).

Exemplos

// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]

Mais exemplos

shift_right_arithmetic

Semântica

Executa a operação aritmética de deslocamento com elementos à direita no tensor lhs por número rhs de bits e produz um tensor result.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs Tensor do tipo de número inteiro (C1)
(I2) rhs Tensor do tipo de número inteiro (C1)

Saídas

Nome Tipo Restrições
result Tensor do tipo de número inteiro (C1)

Restrições

  • (C1) type(lhs) = type(rhs) = type(result).

Exemplos

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_arithmetic"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-1, 0, 1]

Mais exemplos

shift_right_logical

Semântica

Executa uma operação lógica de deslocamento para a direita com elementos no tensor lhs por número rhs de bits e produz um tensor result.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs Tensor do tipo de número inteiro (C1)
(I2) rhs Tensor do tipo de número inteiro (C1)

Saídas

Nome Tipo Restrições
result Tensor do tipo de número inteiro (C1)

Restrições

  • (C1) type(lhs) = type(rhs) = type(result).

Exemplos

// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]

Mais exemplos

de igual.

Semântica

Retorna o sinal de operand com elementos e produz um tensor result. Mais formalmente, para cada elemento x, a semântica pode ser expressa usando a sintaxe do Python da seguinte maneira:

def sign(x):
  if is_integer(x):
    if compare(x, 0, LT, SIGNED): return -1
    if compare(x, 0, EQ, SIGNED): return 0
    return 1
  elif is_float(x):
    if is_nan(x): return NaN
    if compare(x, -0.0, EQ, FLOAT): return -0.0
    if compare(x, +0.0, EQ, FLOAT): return +0.0
    if compare(x, 0.0, LT, FLOAT): return -1.0
    return 1.0
  elif is_complex(x):
    if is_nan(real(x)) or is_nan(imag(x)): return (NaN, NaN)
    if compare(x, (0.0, 0.0), EQ, FLOAT): return (0.0, 0.0)
    return divide(x, convert(abs(x), type(x)))

Para tipos quantizados, executa dequantize_op_quantize(sign, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de número inteiro assinado, ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de número inteiro assinado, ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// operand: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]
%result = "stablehlo.sign"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
// Logical values: +NaN, -1.0, -0.0, +0.0, 1.0
// %result: [0x7FFFFFFFFFFFFFFF, -1.0, -0.0, 0.0, 1.0]

Mais exemplos

seno

Semântica

Executa a operação de seno de elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: sin do IEEE-754.
  • Para números complexos: seno complexo.
  • Para tipos quantizados: dequantize_op_quantize(sine, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [
//            [0.0, 1.57079632],       // [0, pi/2]
//            [3.14159265, 4.71238898] // [pi, 3pi/2]
//           ]
%result = "stablehlo.sine"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [0.0, -1.0]]

Mais exemplos

slice

Semântica

Extrai uma fração do operand usando índices iniciais calculados estaticamente e produz um tensor result. start_indices contém os índices iniciais da fração para cada dimensão, limit_indices contém os índices finais (exclusivo) para a fatia de cada dimensão, e strides contém os passos de cada dimensão.

Mais formalmente, result[result_index] = operand[operand_index], em que operand_index = start_indices + result_index * strides.

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado por tensor (C1 a C3) e (C5)
(I2) start_indices constante do tensor unidimensional do tipo si64 (C2), (C3) e (C5)
(I3) limit_indices constante do tensor unidimensional do tipo si64 (C2), (C3) e (C5)
(I4) strides constante do tensor unidimensional do tipo si64 (C2) e (C4)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado por tensor (C1) e (C5)

Restrições

  • (C1) element_type(operand) = element_type(result).
  • (C2) size(start_indices) = size(limit_indices) = size(strides) = rank(operand).
  • (C3) 0 <= start_indices <= limit_indices <= shape(operand).
  • (C4) 0 < strides.
  • (C5) shape(result) = ceil((limit_indices - start_indices) / strides).

Exemplos

// %operand: [
//            [0, 0, 0, 0],
//            [0, 0, 1, 1],
//            [0, 0, 1, 1]
//           ]
%result = "stablehlo.slice"(%operand) {
  start_indices = dense<[1, 2]> : tensor<2xi64>,
  limit_indices = dense<[3, 4]> : tensor<2xi64>,
  strides = dense<1> : tensor<2xi64>
} : (tensor<3x4xi64>) -> tensor<2x2xi64>
// % result: [
//            [1, 1],
//            [1, 1]
//           ]

Mais exemplos

sort

Semântica

Classifica fatias unidimensionais de inputs ao longo da dimensão dimension juntas, de acordo com um comparator, e produz results.

Ao contrário de entradas semelhantes em outras operações, dimension permite valores negativos, com a semântica descrita abaixo. No futuro, isso pode ser proibido por motivos de consistência (#1377).

Se is_stable for verdadeiro, a classificação será estável, ou seja, a ordem relativa dos elementos considerados iguais pelo comparador será preservada. Para o caso em que há uma única entrada, dois elementos e1 e e2 serão considerados iguais pelo comparador apenas se comparator(e1, e2) = comparator(e2, e1) = false. Consulte a formalização abaixo para entender como isso se aplica a várias entradas.

Mais formalmente, para todos os result_index em index_space(results[0]):

  • adjusted_dimension = dimension >= 0 ? dimension : rank(inputs[0]) + dimension.
  • result_slice = [ri0, ..., :, ..., riR-1], em que riN são elementos individuais em result_index, e : é inserido em adjusted_dimension.
  • inputs_together = (inputs[0]..., ..., inputs[N-1]...).
  • results_together[result_slice] = sort(inputs_together[result_slice], comparator_together).
  • em que sort classifica uma fatia unidimensional em ordem não decrescente esperando que comparator_together retorne true se o argumento do lado esquerdo for menor que o segundo argumento à direita.
  • def comparator_together(lhs_together, rhs_together):
      args = []
      for (lhs_el, rhs_el) in zip(lhs_together, rhs_together):
        args.append(lhs_el)
        args.append(rhs_el)
      return comparator(*args)
    
  • (results[0]..., ..., results[N-1]...) = results_together.

Entradas

Rótulo Nome Tipo Restrições
(I1) inputs número variável de tensores ou tensores quantizados por tensor (C1 a C5)
(I2) dimension constante do tipo si64 (C4)
(I3) is_stable constante do tipo i1
(I4) comparator função (C5)

Saídas

Nome Tipo Restrições
results número variável de tensores ou tensores quantizados por tensor (C2) e (C3)

Restrições

  • (C1) 0 < size(inputs).
  • (C2) type(inputs...) = type(results...).
  • (C3) same(shape(inputs...) + shape(results...)).
  • (C4) -R <= dimension < R, em que R = rank(inputs[0]).
  • (C5) comparator tem o tipo (tensor<E1>, tensor<E1>, ..., tensor<EN-1>, tensor<EN-1>) -> tensor<i1>, em que Ei = element_type(inputs[i]).

Exemplos

// %input0 = [[1, 2, 3], [3, 2, 1]]
// %input1 = [[3, 2, 1], [1, 2, 3]]
%result0, %result1 = "stablehlo.sort"(%input0, %input1) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: tensor<i64>):
    %predicate = "stablehlo.compare"(%arg0, %arg1) {
      comparison_direction = #stablehlo<comparison_direction GT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "stablehlo.return"(%predicate) : (tensor<i1>) -> ()
}) {
  dimension = 0 : i64,
  is_stable = true
} : (tensor<2x3xi64>, tensor<2x3xi64>) -> (tensor<2x3xi64>, tensor<2x3xi64>)
// %result0 = [[3, 2, 3], [1, 2, 1]]
// %result1 = [[1, 2, 1], [3, 2, 3]]

Mais exemplos

sqrt

Semântica

Executa a operação de raiz quadrada com elemento no tensor operand e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: squareRoot do IEEE-754.
  • Para números complexos: raiz quadrada complexa.
  • Para tipos quantizados: dequantize_op_quantize(sqrt, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [[0.0, 1.0], [4.0, 9.0]]
%result = "stablehlo.sqrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// %result: [[0.0, 1.0], [2.0, 3.0]]

Mais exemplos

subtract

Semântica

Executa a subtração com elementos de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para números inteiros: subtração de números inteiros.
  • Para dados flutuantes: subtraction do IEEE-754.
  • Para números complexos: subtração complexa.
  • Para tipos quantizados:
    • dequantize_op_quantize(subtract, lhs, rhs, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)
(I2) rhs tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de inteiro, de ponto flutuante ou complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(lhs) = baseline_type(rhs) = baseline_type(result).

Exemplos

// %lhs: [[6, 8], [10, 12]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
// %result: [[1, 2], [3, 4]]

Mais exemplos

Tanh

Semântica

Executa uma operação tangente hiperbólica com elementos no tensor operand e gera um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para dados flutuantes: tanh do IEEE-754.
  • Para números complexos: tangente hiperbólica complexa.
  • Para tipos quantizados:
    • dequantize_op_quantize(tanh, operand, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_type(operand) = baseline_type(result).

Exemplos

// %operand: [-1.0, 0.0, 1.0]
%result = "stablehlo.tanh"(%operand) : (tensor<3xf32>) -> tensor<3xf32>
// %result: [-0.76159416, 0.0, 0.76159416]

Mais exemplos

transpor

Semântica

Altera as dimensões do tensor operand usando permutation e produz um tensor result. Mais formalmente, result[result_index] = operand[operand_index], em que result_index[d] = operand_index[permutation[d]].

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor ou tensor quantizado (C1 a C4)
(I2) permutation constante do tensor unidimensional do tipo si64 (C2 a C4)

Saídas

Nome Tipo Restrições
result tensor ou tensor quantizado (C1) e (C3 a C4)

Restrições

  • (C1) element_type(result) é fornecido por:
    • element_type(operand), se !is_per_axis_quantized(operand).
    • element_type(operand), exceto que quantization_dimension(operand) e quantization_dimension(result) podem ser diferentes, caso contrário.
  • (C2) permutation é uma permutação de range(rank(operand)).
  • (C3) shape(result) = dim(operand, permutation...).
  • (C4) Se is_per_axis_quantized(result), então quantization_dimension(operand) = permutation(quantization_dimension(result)).

Exemplos

// %operand: [
//            [[1,2], [3,4], [5,6]],
//            [[7,8], [9,10], [11,12]]
//           ]
%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
// %result: [
//           [[1,7], [3,9], [5,11]],
//           [[2,8], [4,10], [6,12]]
//          ]

Mais exemplos

triangular_solve

Semântica

Resolve lotes de sistemas de equações lineares com matrizes de coeficiente triangulares inferiores ou superiores.

Mais formalmente, considerando a e b, result[i0, ..., iR-3, :, :] é a solução para op(a[i0, ..., iR-3, :, :]) * x = b[i0, ..., iR-3, :, :] quando left_side é true ou x * op(a[i0, ..., iR-3, :, :]) = b[i0, ..., iR-3, :, :] quando left_side é false, resolvendo para a variável x em que op(a) é determinado por transpose_a, que pode ser um destes:

  • NO_TRANSPOSE: executa a operação usando a no estado em que se encontra.
  • TRANSPOSE: executa a operação na transposição de a.
  • ADJOINT: executa a operação na transposição conjugada de a.

Os dados de entrada serão lidos apenas do triângulo de baixo de a, se lower for true. Caso contrário, o triângulo superior de a. Os dados de saída são retornados no mesmo triângulo, e os valores no outro são definidos pela implementação.

Se unit_diagonal for verdadeiro, a implementação poderá presumir que os elementos diagonais de a são iguais a 1. Caso contrário, o comportamento será indefinido.

Para tipos quantizados, executa dequantize_op_quantize(lambda x, y: triangular_solve(x, y, left_side, lower, unit_diagonal, transpose_a), a, b, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) a tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1 a C3)
(I2) b tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1 a C4)
(I3) left_side constante do tipo i1 (C3)
(I4) lower constante do tipo i1
(I5) unit_diagonal constante do tipo i1
(I6) transpose_a enumeração de NO_TRANSPOSE, TRANSPOSE e ADJOINT

Saídas

Nome Tipo Restrições
result tensor de ponto flutuante ou tipo complexo ou tensor quantizado por tensor (C1)

Restrições

  • (C1) baseline_element_type(a) = baseline_element_type(b).
  • (C2) 2 <= rank(a) = rank(b) = R.
  • (C3) A relação entre shape(a) e shape(b) é definida da seguinte maneira:
    • shape(a)[:-3] = shape(b)[:-3].
    • dim(a, -2) = dim(a, -1) = dim(b, left_side ? -2 : -1).
  • (C4) baseline_type(b) = baseline_type(result).

Exemplos

// %a = [
//       [1.0, 0.0, 0.0],
//       [2.0, 4.0, 0.0],
//       [3.0, 5.0, 6.0]
//      ]
// %b = [
//       [2.0, 0.0, 0.0],
//       [4.0, 8.0, 0.0],
//       [6.0, 10.0, 12.0]
//      ]
%result = "stablehlo.triangular_solve"(%a, %b) {
  left_side = true,
  lower = true,
  unit_diagonal = false,
  transpose_a = #stablehlo<transpose NO_TRANSPOSE>
} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// %result: [
//           [2.0, 0.0, 0.0],
//           [0.0, 2.0, 0.0],
//           [0.0, 0.0, 2.0]
//          ]

tuple

Semântica

Produz uma tupla result a partir dos valores val.

Entradas

Rótulo Nome Tipo Restrições
(I1) val número variável de valores (C1)

Saídas

Nome Tipo Restrições
result tuple (C1)

Restrições

  • (C1) result tem o tipo tuple<E0, ..., EN-1>, em que Ei = type(val[i]).

Exemplos

// %val0: [1.0, 2.0]
// %val1: (3)
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
// %result: ([1.0, 2.0], (3))

Mais exemplos

uniform_dequantize

Semântica

Executa a conversão com elementos do tensor quantizado operand em um tensor de ponto flutuante result de acordo com os parâmetros de quantização definidos pelo tipo operand.

Mais formalmente, result = dequantize(operand).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor quantizado (C1) e (C2)

Saídas

Nome Tipo Restrições
result tensor do tipo de ponto flutuante (C1) e (C2)

Restrições

  • (C1) shape(operand) = shape(result).
  • (C2) element_type(result) = expressed_type(operand).

Exemplos

// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]

uniform_quantize

Semântica

Executa a conversão elemento-chave do tensor de ponto flutuante ou quantizado operand em um tensor quantizado result de acordo com os parâmetros de quantização definidos pelo tipo result.

Mais formalmente,

  • Se is_float(operand):
    • result = quantize(operand, type(result)).
  • Se is_quantized(operand):
    • float_result = dequantize(operand).
    • result = quantize(float_result, type(result)).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand tensor de ponto flutuante ou tipo quantizado (C1) e (C2)

Saídas

Nome Tipo Restrições
result tensor quantizado (C1) e (C2)

Restrições

  • (C1) shape(operand) = shape(result).
  • (C2) expressed_type(result) = is_float(operand) ? element_type(operand) : expressed_type(operand).

Exemplos

// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]

enquanto

Semântica

Produz a saída da execução da função body 0 ou mais vezes, enquanto a função cond gera true. Mais formalmente, a semântica pode ser expressa usando a sintaxe do Python da seguinte maneira:

internal_state = operand
while cond(*internal_state):
  internal_state = body(*internal_state)
results = internal_state

O comportamento de um loop infinito ainda será definido (383, link em inglês).

Entradas

Rótulo Nome Tipo Restrições
(I1) operand número variad de tensores, tensores quantizados ou tokens (C1 a C3)
(I2) cond função (C1)
(I3) body função (C2)

Saídas

Nome Tipo Restrições
results número variad de tensores, tensores quantizados ou tokens (C3)

Restrições

  • (C1) cond tem o tipo (T0, ..., TN-1) -> tensor<i1>, em que Ti = type(operand[i]).
  • (C2) body tem o tipo (T0, ..., TN-1) -> (T0, ..., TN-1), em que Ti = type(operand[i]).
  • (C3) type(results...) = type(operand...).

Exemplos

// %init_i: 1
// %init_sum: 0
// %one: 1
// %ten: 10
%results0, %results1 = "stablehlo.while"(%init_i, %init_sum) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %cond = "stablehlo.compare"(%arg0, %ten) {
      comparison_direction = #stablehlo<comparison_direction LT>
    } : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %cond : tensor<i1>
  }, {
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %new_sum = stablehlo.add %arg1, %one : tensor<i64>
    %new_i = stablehlo.add %arg0, %one : tensor<i64>
    stablehlo.return %new_i, %new_sum : tensor<i64>, tensor<i64>
}) : (tensor<i64>, tensor<i64>) -> (tensor<i64>, tensor<i64>)
// %results0: 10
// %results1: 10

Mais exemplos

Xor

Semântica

Executa XOR elemento de dois tensores lhs e rhs e produz um tensor result. Dependendo do tipo de elemento, faça o seguinte:

  • Para booleanos: XOR lógico.
  • Para números inteiros: XOR bit a bit.

Entradas

Rótulo Nome Tipo Restrições
(I1) lhs Tensor do tipo booleano ou inteiro (C1)
(I2) rhs Tensor do tipo booleano ou inteiro (C1)

Saídas

Nome Tipo Restrições
result Tensor do tipo booleano ou inteiro (C1)

Restrições

  • (C1) type(lhs) = type(rhs) = type(result).

Exemplos

// Bitwise operation with with integer tensors
// %lhs: [[1, 2], [3, 4]]
// %rhs: [[5, 6], [7, 8]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// %result: [[4, 4], [4, 12]]

// Logical operation with with boolean tensors
// %lhs: [[false, false], [true, true]]
// %rhs: [[false, true], [false, true]]
%result = "stablehlo.xor"(%lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
// %result: [[false, true], [true, false]]

Execução

Execução sequencial

Um programa StableHLO é executado fornecendo valores de entrada para a função main e calculando os valores de saída. Os valores de saída de uma função são calculados executando o gráfico de operações com acesso root na operação return correspondente.

A ordem de execução é definida pela implementação desde que esteja alinhada com o Dataflow, ou seja, se as operações forem executadas antes de serem usadas. No StableHLO, todas as operações com efeito colateral consomem um token e produzem um token (vários tokens podem ser multiplexados em um token via after_all), de modo que a ordem de execução dos efeitos colaterais também esteja alinhada com o Dataflow. As possíveis ordens de execução do programa de exemplo acima são %0%1%2%3%4return ou %3%0%1%2%4return.

Mais formalmente, um processo StableHLO é uma combinação de: 1) um programa StableHLO, 2) status de operação (ainda não executado, já executado) e 3) valores intermediários em que o processo está trabalhando. O processo começa com valores de entrada para a função main, progride no gráfico de operações que atualizam status de operações e valores intermediários e conclui com os valores de saída. Uma nova formalização ainda será definida (#484).

Execução paralela

Os programas StableHLO podem ser executados em paralelo, organizados em uma grade de processos 2D de num_replicas por num_partitions, que tem o tipo ui32.

Na grade de processos do StableHLO, num_replicas * num_partitions dos processos do StableHLO são executados ao mesmo tempo. Cada processo tem um process_id = (replica_id, partition_id) exclusivo, em que replica_id em replica_ids = range(num_replicas) e partition_id em partition_ids = range(num_partitions), que têm o tipo ui32.

O tamanho da grade de processos é conhecido estaticamente para cada programa. No futuro, planejamos torná-la uma parte explícita dos programas StableHLO #650, e a posição dentro da grade de processos é conhecida estaticamente para todos os processos. Cada processo tem acesso à própria posição na grade de processos com as operações replica_id e partition_id.

Dentro da grade de processos, os programas podem ser todos os mesmos (no estilo "Programa único, vários dados"), podem ser todos diferentes (no estilo "Vários programas, vários dados") ou algo semelhante. No futuro, estamos planejando incluir suporte para outras expressões idiomáticas de definição de programas StableHLO paralelos, incluindo o GSPMD (#619).

Dentro da grade de processos, os processos são, em sua maioria, independentes. Eles têm status de operação separados, valores de entrada/intermediário/saída separados e a maioria das operações é executada separadamente entre processos, com exceção de um pequeno número de operações coletivas descritas abaixo.

Considerando que a execução da maioria das operações usa apenas valores do mesmo processo, geralmente não há ambiguidade para se referir a esses valores pelos nomes. No entanto, ao descrever a semântica de operações coletivas, isso é insuficiente e gera a notação name@process_id para se referir ao valor name em um processo específico. Nessa perspectiva, o name não qualificado pode ser considerado como uma abreviação de name@(replica_id(), partition_id()).

A ordem de execução nos processos é definida pela implementação, exceto pela sincronização introduzida pela comunicação ponto a ponto e pelas operações coletivas, conforme descrito abaixo.

Comunicação ponto a ponto

Os processos do StableHLO podem se comunicar uns com os outros por meio de canais do StableHLO. Um canal é representado por um ID positivo do tipo si64. Por meio de várias operações, é possível enviar valores para canais e recebê-los de canais.

Ainda há mais formalização, por exemplo, de onde vêm esses IDs de canais, como os programas tomam conhecimento deles e que tipo de sincronização é introduzido por eles, ainda vai ser definido (#484).

Comunicação de streaming

Cada processo do StableHLO tem acesso a duas interfaces de streaming:

  • Infeed que pode ser lido.
  • Saída que pode ser gravada.

Ao contrário dos canais, que são usados para comunicação entre processos e, portanto, têm processos em ambas as extremidades, os feeds e as saídas têm a outra implementação definida.

Ainda vai ser definido, por exemplo, como a comunicação por streaming influencia a ordem de execução e o tipo de sincronização introduzido por ela, ainda vai ser definida (#484).

Operações coletivas

Há seis operações coletivas no StableHLO: all_gather, all_reduce, all_to_all, collective_broadcast, collective_permute e reduce_scatter. Todas essas operações dividem os processos na grade de processos do StableHLO em grupos de processos do StableHLO e executam uma computação conjunta em cada grupo de processos, independentemente de outros grupos.

Dentro de cada grupo de processo, as operações coletivas podem introduzir uma barreira de sincronização. Ainda mais formalização, por exemplo, descobrir quando exatamente essa sincronização acontece, como os processos chegam nessa barreira e o que acontece se não acontecerem, ainda será definida (#484).

Se o grupo de processos envolver a comunicação entre partições, ou seja, se houver processos no grupo com IDs de partição diferentes, a execução da operação coletiva precisará de um canal, e a operação coletiva precisará fornecer um channel_id positivo do tipo si64. A comunicação entre réplicas não precisa de canais.

Os cálculos realizados pelas operações coletivas são específicos para operações individuais e são descritos nas seções individuais acima. No entanto, as estratégias pelas quais a grade de processos é dividida em grupos de processos são compartilhadas entre essas operações e estão descritas nesta seção. Mais formalmente, o StableHLO é compatível com as quatro estratégias a seguir.

cross_replica

Apenas comunicações entre réplicas acontecem dentro de cada grupo de processos. Essa estratégia usa replica_groups, uma lista de listas de IDs de réplica, e calcula um produto cartesiano de replica_groups de partition_ids. replica_groups precisa ter elementos exclusivos e abranger todos os replica_ids. Mais formalmente, usando a sintaxe do Python:

def cross_replica(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    for partition_id in partition_ids:
      process_group = []
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Por exemplo, para replica_groups = [[0, 1], [2, 3]] e num_partitions = 2, cross_replica produzirá [[(0, 0), (1, 0)], [(0, 1), (1, 1)], [(2, 0), (3, 0)], [(2, 1), (3, 1)]].

cross_partition

Apenas comunicações entre partições acontecem dentro de cada grupo de processos. Essa estratégia usa partition_groups, uma lista de listas de IDs de partição, e calcula um produto cartesiano de partition_groups de replica_ids. partition_groups precisa ter elementos exclusivos e abranger todos os partition_ids. Mais formalmente, usando a sintaxe Python:

def cross_partition(partition_groups: List[List[PartitionId]]) -> List[List[ProcessId]]:
  for partition_group in partition_groups:
    for replica_id in replica_ids:
      process_group = []
      for partition_id in partition_group:
        process_group.append((replica_id, partition_id))
      yield process_group

Por exemplo, para partition_groups = [[0, 1]] e num_replicas = 4, cross_partition produzirá [[(0, 0), (0, 1)], [(1, 0), (1, 1)], [(2, 0), (2, 1)], [(3, 0), (3, 1)]].

cross_replica_and_partition

As comunicações entre réplicas e partições podem acontecer dentro de cada grupo de processos. Essa estratégia usa replica_groups, uma lista de listas de IDs de réplica, e calcula produtos cartesianos de cada replica_group por partition_ids. replica_groups precisa ter elementos exclusivos e abranger todos os replica_ids. Mais formalmente, usando a sintaxe Python:

def cross_replica_and_partition(replica_groups: List[List[ReplicaId]]) -> List[List[ProcessId]]:
  for replica_group in replica_groups:
    process_group = []
    for partition_id in partition_ids:
      for replica_id in replica_group:
        process_group.append((replica_id, partition_id))
    yield process_group

Por exemplo, para replica_groups = [[0, 1], [2, 3]] e num_partitions = 2, cross_replica_and_partition produzirá [[(0, 0), (1, 0), (0, 1), (1, 1)], [(2, 0), (3, 0), (2, 1), (3, 1)]].

flattened_ids

Essa estratégia usa flattened_id_groups, uma lista de listas de IDs de processos "nivelados" na forma de replica_id * num_partitions + partition_id, e as transforma em IDs de processo. flattened_id_groups precisa ter elementos exclusivos e abranger todos os process_ids. Mais formalmente, usando a sintaxe Python:

def flattened_ids(flattened_id_groups: List[List[ui32]]) -> List[List[ProcessId]]:
  for flattened_id_group in flattened_id_groups:
    process_group = []
    for flattened_id in flattened_id_group:
      replica_id = flattened_id // num_partitions
      partition_id = flattened_id % num_partitions
      process_group.append((replica_id, partition_id))
    yield process_group

Por exemplo, para flattened_id_groups = [[0, 1, 2, 3], [4, 5, 6, 7]], num_replicas = 4 e num_partitions = 2, flattened_ids vai produzir [[(0, 0), (0, 1), (1, 0), (1, 1)], [(2, 0), (2, 1), (3, 0), (3, 1)]].

Acurácia

No momento, o StableHLO não oferece garantias sobre precisão numérica, mas isso pode mudar no futuro (#1156).

Erros

Os programas StableHLO são validados por um amplo conjunto de restrições para operações individuais, o que exclui muitas classes de erros antes do tempo de execução. No entanto, as condições de erro ainda são possíveis, por exemplo, por estouros de números inteiros, acessos fora dos limites etc. A menos que chamados explicitamente, todos esses erros resultam em um comportamento definido pela implementação, mas isso pode mudar no futuro (#1157).

Como exceção a essa regra, as exceções de ponto flutuante nos programas StableHLO têm um comportamento bem definido. As operações que resultam em exceções definidas pelo padrão IEEE-754 (operação inválida, divisão por zero, estouro, subfluxo ou exceções imprecisos) produzem resultados padrão (conforme definido no padrão) e continuam a execução sem gerar a flag de status correspondente, semelhante ao processamento de exceções raiseNoFlag do padrão. Exceções para operações não padrão (por exemplo, aritmética complexa e determinadas funções transcendentais) são definidas pela implementação.

Notation

Para descrever a sintaxe, este documento usa a variação ISO modificada da sintaxe EBNF (ISO/IEC 14977:1996, Wikipédia), com duas modificações: 1) as regras são definidas usando ::= em vez de =.

2) a concatenação é expressa usando justaposição em vez de ,.

Para descrever a semântica, ou seja, nas seções "Tipos", "Constantes" e "Operações", usamos fórmulas baseadas na sintaxe do Python estendida com suporte para expressão concisa de operações de matriz, conforme descrito abaixo. Isso funciona bem para pequenos snippets de código, mas, em casos raros, quando são necessários snippets maiores, usamos a sintaxe Python básica, sempre introduzida explicitamente.

fórmulas

Vamos explorar como as fórmulas funcionam com base em um exemplo da especificação dot_general. Uma das restrições para essa operação é a seguinte: dim(lhs, lhs_batching_dimensions...) = dim(rhs, rhs_batching_dimensions...).

Os nomes usados nessa fórmula vêm de duas fontes: 1) funções globais, ou seja, dim, 2) definições de membros do elemento do programa correspondente, ou seja, entradas lhs, lhs_batching_dimensions, rhs e rhs_batching_dimensions definidas na seção "Entradas" de dot_general.

Como mencionado acima, a sintaxe dessa fórmula é baseada em Python com algumas extensões orientadas para concisão. Para que a fórmula faça sentido, vamos transformá-la em uma sintaxe básica do Python.

A) Nessas fórmulas, usamos = para representar a igualdade. Portanto, a primeira etapa para acessar a sintaxe do Python é substituir = por ==, da seguinte maneira: dim(lhs, lhs_batching_dimensions...) == dim(rhs, rhs_batching_dimensions...).

B) Além disso, essas fórmulas são compatíveis com elipses (...), que transformam expressões escalares em expressões de tensor. Em poucas palavras, f(xs...) significa "para cada x escalar no tensor xs, calcular um f(x) escalar e, em seguida, retornar todos esses resultados escalares juntos como um resultado do tensor". Na sintaxe comum do Python, nossa fórmula de exemplo se transforma em: [dim(lhs, dim1) for dim1 in lhs_batching_dimensions] == [dim(rhs, dim2) for dim2 in rhs_batching_dimensions].

Graças às elipses, muitas vezes é possível evitar o trabalho no nível de escalares individuais. No entanto, em alguns casos complicados, a sintaxe semiinformal de nível inferior pode ser usada, como na fórmula start_indices[bi0, ..., :, ..., biN] da especificação gather. No serviço da concisão, não fornecemos um formalismo exato para traduzir essa sintaxe para Python baunilha, na esperança de que ela ainda seja intuitivamente compreensível caso a caso. Informe-nos se algumas fórmulas específicas parecerem opacas, e tentaremos aprimorá-las.

Além disso, as fórmulas usam elipses para expandir todos os tipos de listas, incluindo tensores, listas de tensores (por exemplo, que podem surgir de um número variado de tensores) etc. Essa é outra área em que não fornecemos um formalismo exato (por exemplo, listas nem fazem parte do sistema do tipo StableHLO) e, em vez disso, dependemos da compreensão intuitiva.

C) O último veículo notável que empregamos é a transmissão implícita. Embora a opset StableHLO não ofereça suporte à transmissão implícita, as fórmulas também oferecem concisão. Em poucas palavras, se um escalar é usado em um contexto em que um tensor é esperado, o escalar é transmitido para a forma esperada.

Para continuar o exemplo de dot_general, aqui está outra restrição: 0 <= lhs_batching_dimensions < rank(lhs). Conforme definido na especificação dot_general, lhs_batching_dimensions é um tensor. No entanto, 0 e rank(lhs) são escalares. Depois de aplicar a transmissão implícita, a fórmula se tornará [0, ..., 0] <= lhs_batching_dimensions < [rank(lhs), ..., rank(lhs)].

Quando aplicada a uma operação dot_general específica, essa fórmula será avaliada como um tensor de booleanos. Quando as fórmulas são usadas como restrições, a restrição será mantida se a fórmula for avaliada como true ou para um tensor que tenha apenas elementos true.

Nomes

Em fórmulas, o escopo lexical inclui: 1) funções globais, 2) definições de membros,

3) definições locais. Confira abaixo a lista de funções globais. A lista de definições de elemento depende do elemento do programa a que a notação é aplicada:

  • Para operações, as definições de membro incluem nomes introduzidos nas seções "Entradas" e "Saídas".
  • Para todo o restante, as definições de membro incluem partes estruturais do elemento do programa, nomeadas com base nos não terminais de EBNF correspondentes. Na maioria das vezes, os nomes dessas partes estruturais são recebidos convertendo os nomes dos não terminais em snake-case (por exemplo, IntegerLiteral => integer_literal), mas às vezes os nomes são abreviados no processo (por exemplo, QuantizationStorageType => storage_type). Nesse caso, os nomes são introduzidos explicitamente de maneira explícita nas seções "Entradas" / "Saídas" nas especificações de operação.
  • Além disso, as definições de membro sempre incluem self para se referir ao elemento do programa correspondente.

Valores

Quando as fórmulas são avaliadas, elas trabalham com os seguintes tipos de valores: 1) Value (valores reais, por exemplo, dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>; sempre sabem os tipos), 2) Placeholder (valores futuros, por exemplo, lhs, rhs ou result; os valores reais ainda não são conhecidos, somente os tipos são conhecidos), 3) Type (tipos conforme definidos na seção "Tipos"), 4) Function (funções globais, como definido na seção "Funções").

Dependendo do contexto, os nomes podem se referir a valores diferentes. Mais especificamente, a seção "Semântica" para operações (e equivalentes para outros elementos do programa) define a lógica do ambiente de execução. Portanto, todas as entradas estão disponíveis como Value. Em contrapartida, a seção "Restrições" de operações (e equivalentes) define a lógica de "tempo de compilação", ou seja, algo que normalmente é executado antes da execução, para que apenas entradas constantes estejam disponíveis como Value, e outras entradas estão disponíveis apenas como Placeholder.

Nomes Em "Semântica" Em "Restrições"
Funções globais Function Function
Entradas constantes Value Value
Entradas não constantes Value Placeholder
Saídas Value Placeholder
Definições de locais Depende da definição. Depende da definição.

Vamos considerar um exemplo de operação transpose:

%result = "stablehlo.transpose"(%operand) {
  permutation = dense<[2, 1, 0]> : tensor<3xi64>
} : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>

Para essa operação, permutation é uma constante, por isso está disponível como um Value em semântica e restrições. Por outro lado, operand e result estão disponíveis como um Value na semântica, mas apenas como um Placeholder em restrições.

Funções

Construção de tipos

Não existem funções que possam ser usadas para construir tipos. Em vez disso, usamos diretamente a sintaxe de tipografia, porque ela geralmente é mais concisa. Por exemplo, (tensor<E>, tensor<E>) -> (tensor<E>) em vez de function_type( [tensor_type([], E), tensor_type([], E)], [tensor_type([], E)]).

Funções em tipos

  • element_type é definido em tipos de tensores e tipos de tensores quantizados e retorna, respectivamente, a parte TensorElementType ou QuantizedTensorElementType do TensorType ou QuantizedTensorType correspondente.
def element_type(x: Value | Placeholder | Type):
 if type(x) == TensorType:
    return tensor_element_type(x)
  if type(x) == QuantizedTensorType:
    return quantized_tensor_element_type(x)
  if type(x) is not Type:
    return element_type(type(x))
  • is_per_axis_quantized(x: Value | Placeholder | Type) -> Value é um atalho para is_quantized(x) and quantization_dimension(x) is not None.

  • is_per_tensor_quantized(x: Value | Placeholder | Type) -> Value é um atalho para is_quantized(x) and quantization_dimension(x) is None.

  • is_promotable(x: Type, y: Type) -> bool verifica se o tipo x pode ser promovido para o tipo y. Quando x e y forem QuantizedTensorElementTypes, a promoção será aplicada apenas ao storage_type. Essa versão específica da promoção é usada atualmente no contexto do cálculo de redução. Consulte RFC (em inglês) para mais detalhes.

def is_promotable(x: Type, y: Type) -> Value:
  is_same_type = (is_bool(x) and is_bool(y)) or
    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
    (is_complex(x) and is_complex(y)) or
    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))

  if is_same_type == False:
    return False

  if is_integer(x) or is_float(x):
    return bitwidth(x) <= bitwidth(y)

  if is_complex(x):
    return bitwidth(element_type(x)) <= bitwidth(element_type(y))

  if is_quantized(x):
    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
  • is_quantized(x: Value | Placeholder | Type) -> Value é um atalho para is_quantized_tensor_element_type(x).

  • is_type_name(x: Value | Placeholder | Type) -> Value. Disponível para todos os tipos. Por exemplo, is_float(x) retorna true se x for um FloatType. Se x for um valor ou marcador de posição, essa função será um atalho para is_type_name(type(x)).

  • max_value(x: Type) -> Value retorna o valor máximo de uma TensorElementType. Se x não for um TensorElementType, será retornado None.

  • min_value(x: Type) -> Value retorna o valor mínimo possível de uma TensorElementType. Se x não for um TensorElementType, será retornado None.

  • member_name(x: Value | Placeholder | Type) -> Any. Disponível para todas as definições de membro member_name de todos os tipos. Por exemplo, tensor_element_type(x) retorna a parte TensorElementType de uma TensorType correspondente. Se x for um valor ou marcador de posição, essa função será um atalho para member_name(type(x)). Se x não for um tipo que tenha um membro apropriado ou um valor ou um marcador desse tipo, retorna None.

Construção de valores

  • operation_name(*xs: Value | Type) -> Value. Disponível para todas as operações. Por exemplo, add(lhs, rhs) usa dois valores de tensor lhs e rhs e retorna a saída da avaliação da operação add com essas entradas. Para algumas operações, como broadcast_in_dim, os tipos de saídas são "suporte de carga", ou seja, necessários para avaliar uma operação. Nesse caso, a função usa esses tipos como argumentos.

Função em valores

  • Todos os operadores e funções do Python estão disponíveis. Por exemplo, as notações de assinatura e fração do Python estão disponíveis para indexação em tensores, tensores quantizados e tuplas.

  • to_destination_type(x: Value, destination_type: Type) -> Value é definido em tensores e retorna o valor convertido de x com base em type(x) e destination_type da seguinte maneira:

def to_destination_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x

  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  if is_quantized(type(x)):
    assert destination_type = expressed_type(type(x))
    return dequantize(type(x))

  return convert(x, destination_type)

Há uma discussão inicial sobre a mesclagem das operações convert, uniform_quantize e uniform_dequantize (#1576). Após a mesclagem, a função acima não é mais necessária e podemos usar o nome da operação para convert.

  • is_nan(x: Value) -> Value é definido nos tensores e retorna true se todos os elementos de x forem NaN ou false, caso contrário. Se x não for um tensor, retorna None.

  • is_sorted(x: Value) -> Value é definido em tensores e retorna true se os elementos de x estiverem classificados em ordem crescente em relação à ordem lexicográfica crescente de seus índices. Caso contrário, retorna false. Se x não for um tensor, None será retornado.

  • is_unique(x: Value) -> Value é definido nos tensores e retorna true se x não tiver elementos duplicados ou false caso contrário. Se x não for um tensor, retorna None.

  • member_name(x: Value) -> Any é definido para todas as definições de membros member_name de todos os valores. Por exemplo, real_part(x) retorna a parte RealPart de uma ComplexConstant correspondente. Se x não for um valor que tenha um membro apropriado, retorna None.

  • same(x: Value) -> Value é definido nos tensores e retorna true se os elementos de x forem todos iguais entre si ou false caso contrário. Se o tensor não tiver elementos, isso conta como "todos iguais entre si", ou seja, a função retorna true. Se x não for um tensor, retorna None.

  • split(x: Value, num_results: Value, axis: Value) -> Value é definido em tensores e retorna fatias num_results de x ao longo do eixo axis. Se x não for um tensor ou dim(x, axis) % num_results != 0, retornará None.

Computações de formas

  • axes(x: Value | Placeholder | Type) -> Value é um atalho para range(rank(x)).

  • dim(x: Value | Placeholder | Type, axis: Value) -> Value é um atalho para shape(x)[axis].

  • dims(x: Value | Placeholder | Type, axes: List) -> List é um atalho para list(map(lambda axis: dim(x, axis), axes)).

  • index_space(x: Value | Placeholder | Type) -> Value é definido em tensores e retorna índices size(x) para a TensorType correspondente classificada em ordem lexicográfica crescente, ou seja, [0, ..., 0], [0, ..., 1], ..., shape(x) - 1. Se x não for um tipo de tensor, um tipo de tensor quantizado, um valor ou um marcador de posição de um desses tipos, retorna None.

  • rank(x: Value | Placeholder | Type) -> Value é um atalho para size(shape(x)).

  • shape(x: Value | Placeholder | Type) -> Value é definido na seção "Funções em tipos" via member_name.

  • size(x: Value | Placeholder | Type) -> Value é um atalho para reduce(lambda x, y: x * y, shape(x)).

Computações de quantização

  • def baseline_element_type(x: Value | Placeholder | Type) -> Type é um atalho para element_type(baseline_type(x)).

  • baseline_type é definido em tipos de tensores e tipos de tensores quantizados e os transforma em um "valor de referência", ou seja, um tipo com a mesma forma, mas com os parâmetros de quantização do tipo de elemento redefinidos para os valores padrão. Isso é usado como um truque útil para comparar os dois tipos de tensor de maneira uniforme, o que é necessário com bastante frequência. Para tipos quantizados, isso permite comparar tipos, ignorando os parâmetros de quantização, ou seja, shape, storage_type, expressed_type, storage_min, storage_max e quantization_dimension (para o tipo quantizado por eixo) precisam ser correspondentes, mas scales e zero points podem ser diferentes.

def baseline_type(x: Value | Placeholder | Type) -> Type:
  if type(x) == TensorType:
    return x
  if type(x) == QuantizedTensorType:
    element_type = quantized_tensor_element_type(x)
    baseline_element_type = QuantizedTensorElementType(
      storage_type = storage_type(element_type),
      storage_min = storage_min(element_type),
      storage_max = storage_max(element_type),
      expressed_type = expressed_type(element_type),
      quantization_dimension = quantization_dimension(element_type),
      scales = [constant(1.0, expressed_type(element_type))] * dim(x, quantization_dimension(element_type)),
      zero_points = [constant(0, storage_type(element_type))] * dim(x, quantization_dimension(element_type)))
    return QuantizedTensorType(shape(x), baseline_element_type)
  if type(x) is not Type:
    return baseline_element_type(type(x))
  • dequantize é definido em tipos de tensores quantizados e os transforma em tipos de tensores de ponto flutuante. Isso acontece pela conversão de elementos quantizados, que representam valores inteiros do tipo de armazenamento em valores de ponto flutuante correspondentes do tipo expresso usando o ponto zero e a escala associados ao tipo de elemento quantizado.
def compute_zero_points(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(zero_point(quantized_type), storage_type(quantized_type)), [], result_type)
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      zero_points[i] = zero_points(quantized_type)[i[d]]
    return zero_points

def compute_scales(quantized_type, result_type):
  if is_per_tensor_quantized(quantized_type):
    return broadcast_in_dim(constant(scale(quantized_type), expressed_type(quantized_type)), [],
            type(result_type))
  if is_per_axis_quantized(quantized_type):
    for i in index_space(result_type):
      d = quantization_dimension(quantized_type)
      scales[i] = scales(quantized_type)[i[d]]
    return scales

def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  x_storage = bitcast_convert(x, storage_type(x))
  x_storage_sub = x_storage - compute_zero_points(type(x), type(x_storage))
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * compute_scales(type(x), type(x_expressed_sub))
  • quantize é definido em tipos de tensores de ponto flutuante e os transforma em tipos de tensores quantizados. Isso acontece pela conversão de valores de ponto flutuante do tipo expresso em valores inteiros correspondentes do tipo de armazenamento usando o ponto zero e a escala associados ao tipo de elemento quantizado.
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  x_expressed_rounded = round_nearest_even(x / compute_scales(type, type(x)))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  x_storage_add = x_storage_rounded + compute_zero_points(type, type(x_storage_rounded))
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
  • dequantize_op_quantize é usado para especificar cálculos elementos em tensores quantizados. Ela desquantiza, ou seja, transforma elementos quantizados nos tipos expressos, executa uma operação e quantiza, ou seja, transforma os resultados de volta nos tipos de armazenamento. No momento, essa função só funciona para quantização por tensor. A quantização por eixo está em desenvolvimento (#1574).
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]

  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  return quantize(float_result, output_type)

def dequantize_batch_norm_grad_or_training_quantize(op, *inputs_and_output_types):
  inputs = inputs_and_output_type[:-3]
  float_inputs = map(dequantize, inputs)
  float_results = op(*float_inputs)
  return map(quantize, float_results, inputs_and_output_type[-3:])

def dequantize_compare(lhs, rhs, comparison_direction):
  float_lhs = dequantize(lhs)
  float_rhs = dequantize(rhs)
  return compare(float_lhs, float_rhs, comparison_direction, FLOAT)

def dequantize_select_quantize(pred, on_true, on_false, output_type):
  float_on_true = dequantize(on_true)
  float_on_false = dequantize(on_false)
  float_result = select(pred, float_on_true, float_on_false)
  return quantize(float_result, output_type)

Cálculos em grade

  • cross_partition(replica_groups: Value) -> Value. Consulte a seção "cross_replica" acima.

  • cross_replica(replica_groups: Value) -> Value. Consulte a seção "cross_replica" acima.

  • cross_replica_and_partition(replica_groups: Value) -> Value. Consulte a seção "cross_replica_and_partition" acima.

  • flattened_ids(replica_groups: Value) -> Value. Consulte a seção "Flated_ids" acima.